| | import os |
| | import json |
| | import argparse |
| | import logging |
| | from tqdm import tqdm |
| | from typing import List, Dict |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch.utils.data import DataLoader |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Paths are resolved relative to this script file (not the CWD) so the tool
# can be launched from anywhere.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory holding the bundled HF model + tokenizer files.
MODEL_PATH = os.path.join(SCRIPT_DIR, "model")
# JSON file mapping label name -> integer class id (inverted at load time).
LABEL_MAP_PATH = os.path.join(SCRIPT_DIR, "label_to_id.json")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def setup_logging(output_dir):
    """Configure root logging to a file inside *output_dir* plus stderr.

    Creates *output_dir* if it does not exist; the log file is named
    ``language_classifier.log``.
    """
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, "language_classifier.log")

    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=handlers,
    )

    logging.info(f"Logging to: {log_file}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def setup_distributed():
    """Initialise torch.distributed when launched under torchrun.

    Returns ``(distributed, rank, world_size, local_rank)``.  When the
    RANK/WORLD_SIZE environment variables are absent, no process group is
    created and a single-process configuration ``(False, 0, 1, 0)`` is
    returned.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0

    dist.init_process_group(backend="nccl")
    rank = int(env["RANK"])
    world = int(env["WORLD_SIZE"])
    local = int(env.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local)
    return True, rank, world, local
| |
|
| |
|
def is_main_process():
    """Return True on rank 0, or whenever torch.distributed is not in use."""
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
| |
|
| |
|
| | |
| | |
| | |
| |
|
def find_all_jsonl_files(path: str) -> List[str]:
    """Resolve *path* to a sorted list of ``.jsonl`` files.

    A file path must itself end in ``.jsonl``; a directory is walked
    recursively.  Raises ValueError for a non-.jsonl file or a missing
    path, and RuntimeError when a directory yields no .jsonl files.
    """
    if os.path.isfile(path):
        if path.endswith(".jsonl"):
            return [path]
        raise ValueError(f"Input file must be .jsonl: {path}")

    if not os.path.isdir(path):
        raise ValueError(f"Input path does not exist: {path}")

    found = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if name.endswith(".jsonl")
    )

    if not found:
        raise RuntimeError(f"No .jsonl files found inside: {path}")

    return found
| |
|
| |
|
| | |
| | |
| | |
| |
|
class JsonlIterableDataset(torch.utils.data.IterableDataset):
    """Streams JSON objects from .jsonl files, sharded across ranks and workers.

    Within each file, line *k* is handled by the global worker whose id is
    ``k % (world_size * num_workers)``.  Lines that are not valid JSON, or
    whose text field is missing/blank, are silently skipped (best-effort).
    Emitted records carry the text under the ``__lc_text`` key.
    """

    def __init__(self, input_path: str, text_key: str, rank: int, world_size: int):
        self.files = find_all_jsonl_files(input_path)
        self.text_key = text_key
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            wid, per_rank = 0, 1
        else:
            wid, per_rank = info.id, info.num_workers

        shard = self.rank * per_rank + wid
        total = self.world_size * per_rank
        key = self.text_key

        for path in self.files:
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                for idx, line in enumerate(fh):
                    if idx % total != shard:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # malformed line: skip, don't abort the stream
                    text = record.get(key)
                    if isinstance(text, str) and text.strip():
                        record["__lc_text"] = text
                        yield record
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Collator:
    """Collates raw records into a tokenized batch.

    Calling the collator on a list of records tokenizes their ``__lc_text``
    fields (padded/truncated to ``max_length``, PyTorch tensors) and returns
    ``{"enc": <tokenizer output>, "raw": <original records>}``.  An empty
    batch collates to None.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        if not batch:
            return None

        encoded = self.tokenizer(
            [record["__lc_text"] for record in batch],
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {"enc": encoded, "raw": batch}
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """CLI entry point: classify the language of every record in the input.

    Reads .jsonl input (a file or a directory walked recursively), runs the
    bundled sequence-classification model over the ``--text_key`` field of
    each record, and writes one output file per predicted label (and per
    distributed rank) into ``--output_path``.

    Raises:
        RuntimeError: if the bundled label map or model directory is missing.
    """
    parser = argparse.ArgumentParser("Language Classifier Inference")

    parser.add_argument("--input_path", required=True)
    parser.add_argument("--output_path", required=True)
    parser.add_argument("--text_key", required=True)

    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--num_workers", type=int, default=8)

    args = parser.parse_args()

    setup_logging(args.output_path)

    # --- distributed / device setup --------------------------------------
    distributed, rank, world_size, local_rank = setup_distributed()
    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    logging.info(f"Distributed={distributed} | World size={world_size}")

    # --- label map --------------------------------------------------------
    if not os.path.isfile(LABEL_MAP_PATH):
        raise RuntimeError(f"Missing label map: {LABEL_MAP_PATH}")

    with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
        label_map = json.load(f)

    # Invert label -> id into id -> label.  Assumes the ids coincide with the
    # model's output indices (argmax results) -- TODO confirm against the
    # training pipeline.
    id_to_label = {v: k for k, v in label_map.items()}

    # --- model ------------------------------------------------------------
    if not os.path.isdir(MODEL_PATH):
        raise RuntimeError(f"Model directory not found: {MODEL_PATH}")

    logging.info(f"Loading model from {MODEL_PATH}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()

    # --- data -------------------------------------------------------------
    dataset = JsonlIterableDataset(
        args.input_path,
        args.text_key,
        rank=rank,
        world_size=world_size,
    )

    # BUGFIX: persistent_workers / prefetch_factor are only valid when
    # num_workers > 0; passing them unconditionally made DataLoader raise
    # ValueError for --num_workers 0 (e.g. CPU debugging runs).
    loader_kwargs = {}
    if args.num_workers > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 4

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=Collator(tokenizer, args.max_length),
        pin_memory=True,
        **loader_kwargs,
    )

    # One bucket per known label id; all predictions for this rank are
    # buffered in memory until the end of the run.
    outputs: Dict[int, List[dict]] = {k: [] for k in id_to_label}

    # Progress bar only on the main process to keep multi-rank logs clean.
    iterator = tqdm(dataloader, desc="Classifying") if is_main_process() else dataloader

    with torch.no_grad():
        for batch in iterator:
            if batch is None:  # empty collated batch
                continue

            try:
                enc = {k: v.to(device) for k, v in batch["enc"].items()}
                raw = batch["raw"]

                logits = model(**enc).logits
                preds = torch.argmax(logits, dim=-1).cpu().tolist()

                for obj, pred in zip(raw, preds):
                    obj = dict(obj)  # copy: don't mutate the collated record
                    obj.pop("__lc_text", None)
                    obj["predicted_id"] = pred
                    obj["predicted_language"] = id_to_label[pred]
                    outputs[pred].append(obj)

            except Exception as e:
                # Best-effort: a failed batch is logged and dropped rather
                # than aborting a potentially multi-hour run.
                logging.exception(f"Batch failed: {e}")

    # --- write per-label, per-rank output files ---------------------------
    os.makedirs(args.output_path, exist_ok=True)

    for cls_id, cls_name in id_to_label.items():
        out_path = os.path.join(
            args.output_path,
            f"{cls_name}.rank{rank}.jsonl"
        )

        logging.info(f"Writing {len(outputs[cls_id])} samples to {out_path}")

        with open(out_path, "w", encoding="utf-8") as f:
            for obj in outputs[cls_id]:
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    if distributed:
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Language classification completed successfully.")
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
| |
|