"""Distributed (DDP-aware) batch inference for a language classifier.

Streams .jsonl records from an input path, classifies each record's text with
a local HuggingFace sequence-classification model, and writes one output
.jsonl file per (predicted language, rank).
"""

import argparse
import json
import logging
import os
from typing import Dict, List

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ===========================
# PATH RESOLUTION (NO HARDCODE)
# ===========================
# Model artifacts are resolved relative to this script so it can be run from
# any working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(SCRIPT_DIR, "model")
LABEL_MAP_PATH = os.path.join(SCRIPT_DIR, "label_to_id.json")


# ===========================
# Logging
# ===========================
def setup_logging(output_dir):
    """Configure root logging to both a log file inside *output_dir* and the console."""
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "language_classifier.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler(),
        ],
    )
    logging.info(f"Logging to: {log_path}")


# ===========================
# DDP SETUP
# ===========================
def setup_distributed():
    """Initialize torch.distributed when launched under torchrun.

    Returns:
        Tuple ``(distributed, rank, world_size, local_rank)``. When the
        RANK/WORLD_SIZE environment variables are absent, the process runs
        standalone and ``(False, 0, 1, 0)`` is returned.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # nccl backend assumes one GPU per process; LOCAL_RANK selects it.
        dist.init_process_group(backend="nccl")
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        return True, rank, world_size, local_rank
    return False, 0, 1, 0


def is_main_process():
    """Return True for rank 0, or for any process when not running distributed."""
    return (
        not dist.is_available()
        or not dist.is_initialized()
        or dist.get_rank() == 0
    )


# ===========================
# Input Discovery
# ===========================
def find_all_jsonl_files(path: str) -> List[str]:
    """Return a sorted list of .jsonl files under *path* (a file or a directory).

    Raises:
        ValueError: if *path* does not exist, or is a file without a ``.jsonl``
            extension.
        RuntimeError: if *path* is a directory containing no .jsonl files.
    """
    if os.path.isfile(path):
        if not path.endswith(".jsonl"):
            raise ValueError(f"Input file must be .jsonl: {path}")
        return [path]
    if not os.path.isdir(path):
        raise ValueError(f"Input path does not exist: {path}")
    files = []
    for root, _, filenames in os.walk(path):
        for fn in filenames:
            if fn.endswith(".jsonl"):
                files.append(os.path.join(root, fn))
    if not files:
        # FIX: the original message contained a stray hard line break inside
        # the f-string (a syntax error); normalized to a single line.
        raise RuntimeError(f"No .jsonl files found inside: {path}")
    return sorted(files)


# ===========================
# Dataset (Streaming, DDP-safe)
# ===========================
class JsonlIterableDataset(torch.utils.data.IterableDataset):
    """Stream records from .jsonl files, sharded across ranks and workers.

    Each (rank, dataloader-worker) pair reads every ``world_size * num_workers``-th
    line of every file, so the full corpus is covered exactly once with no
    coordination between processes. Lines that are not valid JSON, or whose
    *text_key* value is missing/empty/non-string, are silently skipped.
    The record's text is mirrored into the ``__lc_text`` key for the collator.
    """

    def __init__(self, input_path: str, text_key: str, rank: int, world_size: int):
        self.files = find_all_jsonl_files(input_path)
        self.text_key = text_key
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        num_workers = worker_info.num_workers if worker_info else 1
        # Flatten (rank, worker) into a single global shard index.
        global_worker_id = self.rank * num_workers + worker_id
        global_num_workers = self.world_size * num_workers
        # Hoist lookups out of the per-line hot loop.
        json_loads = json.loads
        text_key = self.text_key
        for path in self.files:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                i = 0  # round-robin counter, reset per file
                for line in f:
                    if i == global_worker_id:
                        try:
                            obj = json_loads(line)
                        except json.JSONDecodeError:
                            pass  # best-effort: skip malformed lines
                        else:
                            text = obj.get(text_key)
                            if isinstance(text, str) and text.strip():
                                obj["__lc_text"] = text
                                yield obj
                    i += 1
                    if i == global_num_workers:
                        i = 0


# ===========================
# Collator
# ===========================
class Collator:
    """Tokenize a batch of records; returns the encoding plus the raw records."""

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        if not batch:
            return None
        texts = [x["__lc_text"] for x in batch]
        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {"enc": enc, "raw": batch}


# ===========================
# Main
# ===========================
def main():
    """CLI entry point: classify every record and write per-language .jsonl files."""
    parser = argparse.ArgumentParser("Language Classifier Inference")
    parser.add_argument("--input_path", required=True)
    parser.add_argument("--output_path", required=True)
    parser.add_argument("--text_key", required=True)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()

    setup_logging(args.output_path)

    # --------------------
    # DDP
    # --------------------
    distributed, rank, world_size, local_rank = setup_distributed()
    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    logging.info(f"Distributed={distributed} | World size={world_size}")

    # --------------------
    # Load label map
    # --------------------
    if not os.path.isfile(LABEL_MAP_PATH):
        raise RuntimeError(f"Missing label map: {LABEL_MAP_PATH}")
    with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
        label_map = json.load(f)
    # Invert {label_name: id} -> {id: label_name} for decoding predictions.
    id_to_label = {v: k for k, v in label_map.items()}

    # --------------------
    # Load model
    # --------------------
    if not os.path.isdir(MODEL_PATH):
        raise RuntimeError(f"Model directory not found: {MODEL_PATH}")
    logging.info(f"Loading model from {MODEL_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()

    # --------------------
    # Dataset & Loader
    # --------------------
    dataset = JsonlIterableDataset(
        args.input_path,
        args.text_key,
        rank=rank,
        world_size=world_size,
    )
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=Collator(tokenizer, args.max_length),
        pin_memory=True,
    )
    # FIX: persistent_workers / prefetch_factor are only valid when worker
    # processes exist; passing them with num_workers=0 raises ValueError.
    if args.num_workers > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 4
    dataloader = DataLoader(dataset, **loader_kwargs)

    # --------------------
    # Accumulators
    # --------------------
    # One in-memory bucket per class id. NOTE(review): this holds the whole
    # shard's output in RAM; for very large corpora consider streaming writes.
    outputs: Dict[int, List[dict]] = {k: [] for k in id_to_label.keys()}

    # --------------------
    # Inference
    # --------------------
    iterator = tqdm(dataloader, desc="Classifying") if is_main_process() else dataloader
    with torch.no_grad():
        for batch in iterator:
            if batch is None:
                continue
            try:
                enc = {k: v.to(device) for k, v in batch["enc"].items()}
                raw = batch["raw"]
                logits = model(**enc).logits
                preds = torch.argmax(logits, dim=-1).cpu().tolist()
                for obj, pred in zip(raw, preds):
                    # Copy before mutating so the collator's raw batch stays intact.
                    obj = dict(obj)
                    obj.pop("__lc_text", None)
                    obj["predicted_id"] = pred
                    obj["predicted_language"] = id_to_label[pred]
                    outputs[pred].append(obj)
            except Exception as e:
                # Best-effort: log the failed batch and keep going.
                logging.exception(f"Batch failed: {e}")

    # --------------------
    # Write outputs
    # --------------------
    # Each rank writes its own files; names are disambiguated by rank so
    # concurrent ranks never clobber each other.
    os.makedirs(args.output_path, exist_ok=True)
    for cls_id, cls_name in id_to_label.items():
        out_path = os.path.join(
            args.output_path,
            f"{cls_name}.rank{rank}.jsonl",
        )
        logging.info(f"Writing {len(outputs[cls_id])} samples to {out_path}")
        with open(out_path, "w", encoding="utf-8") as f:
            for obj in outputs[cls_id]:
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    if distributed:
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Language classification completed successfully.")


if __name__ == "__main__":
    main()