# Language-Classifier / lc_infer.py
# (file uploaded via huggingface_hub by atharv-savarkar; commit fc2701b)
import os
import json
import argparse
import logging
from tqdm import tqdm
from typing import List, Dict
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ===========================
# PATH RESOLUTION (NO HARDCODE)
# ===========================
# All artifacts are resolved relative to this script's own directory so the
# folder can be copied anywhere without editing paths.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory holding the fine-tuned HF model (config, weights, tokenizer files).
MODEL_PATH = os.path.join(SCRIPT_DIR, "model")
# JSON mapping of language label -> integer class id, written at training time.
LABEL_MAP_PATH = os.path.join(SCRIPT_DIR, "label_to_id.json")
# ===========================
# Logging
# ===========================
def setup_logging(output_dir):
    """Configure root logging to both a file in *output_dir* and the console.

    Creates *output_dir* if it does not exist; the log file is always named
    ``language_classifier.log``.
    """
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "language_classifier.log")
    sinks = [logging.FileHandler(log_path), logging.StreamHandler()]
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(message)s",
        level=logging.INFO,
        handlers=sinks,
    )
    logging.info(f"Logging to: {log_path}")
# ===========================
# DDP SETUP
# ===========================
def setup_distributed():
    """Initialise torch.distributed when launched under torchrun.

    Returns ``(distributed, rank, world_size, local_rank)``. When the
    RANK/WORLD_SIZE environment variables are absent, falls back to a
    single-process configuration without touching the process group.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0
    dist.init_process_group(backend="nccl")
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    # Pin this process to its assigned GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
def is_main_process():
    """Return True on rank 0, or whenever no process group is initialised."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True
# ===========================
# Input Discovery
# ===========================
def find_all_jsonl_files(path: str) -> List[str]:
    """Return a sorted list of ``.jsonl`` files for *path*.

    *path* may be a single ``.jsonl`` file or a directory that is searched
    recursively. Raises ValueError for a non-jsonl file or a missing path,
    and RuntimeError when a directory contains no ``.jsonl`` files.
    """
    if os.path.isfile(path):
        if not path.endswith(".jsonl"):
            raise ValueError(f"Input file must be .jsonl: {path}")
        return [path]
    if not os.path.isdir(path):
        raise ValueError(f"Input path does not exist: {path}")
    found = [
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if name.endswith(".jsonl")
    ]
    if not found:
        raise RuntimeError(f"No .jsonl files found inside: {path}")
    return sorted(found)
# ===========================
# Dataset (Streaming, DDP-safe)
# ===========================
class JsonlIterableDataset(torch.utils.data.IterableDataset):
    """Stream records from .jsonl files, sharded across ranks and workers.

    Each (rank, dataloader-worker) pair claims every Nth line of every file
    (round-robin by line index, restarting per file), so no two workers
    ever emit the same record. Lines that are not valid JSON, or whose
    *text_key* value is missing/blank, are skipped. Yielded dicts carry the
    raw text under the ``__lc_text`` key for the collator.
    """

    def __init__(self, input_path: str, text_key: str, rank: int, world_size: int):
        self.files = find_all_jsonl_files(input_path)
        self.text_key = text_key
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        worker_id = info.id if info else 0
        workers_per_rank = info.num_workers if info else 1
        shard = self.rank * workers_per_rank + worker_id
        total_shards = self.world_size * workers_per_rank
        key = self.text_key
        for path in self.files:
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                for idx, line in enumerate(fh):
                    # Round-robin sharding: this worker owns every
                    # total_shards-th line, offset by its shard index.
                    if idx % total_shards != shard:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    text = record.get(key)
                    if isinstance(text, str) and text.strip():
                        record["__lc_text"] = text
                        yield record
# ===========================
# Collator
# ===========================
class Collator:
    """Batch collator: tokenize the ``__lc_text`` field of each record.

    Returns ``{"enc": tokenized_tensors, "raw": original_records}`` so the
    inference loop can attach predictions back onto the raw objects, or
    ``None`` for an empty batch.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        if not batch:
            return None
        encoded = self.tokenizer(
            [item["__lc_text"] for item in batch],
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {"enc": encoded, "raw": batch}
# ===========================
# Main
# ===========================
def main():
    """CLI entry point: classify the language of every record in the input.

    Streams .jsonl records, predicts a language class for each, and writes
    one output .jsonl file per predicted language (suffixed with the DDP
    rank so ranks never clobber each other). Raises RuntimeError when the
    bundled label map or model directory is missing.
    """
    parser = argparse.ArgumentParser("Language Classifier Inference")
    parser.add_argument("--input_path", required=True)
    parser.add_argument("--output_path", required=True)
    parser.add_argument("--text_key", required=True)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()

    setup_logging(args.output_path)

    # --------------------
    # DDP
    # --------------------
    distributed, rank, world_size, local_rank = setup_distributed()
    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    logging.info(f"Distributed={distributed} | World size={world_size}")

    # --------------------
    # Load label map
    # --------------------
    if not os.path.isfile(LABEL_MAP_PATH):
        raise RuntimeError(f"Missing label map: {LABEL_MAP_PATH}")
    with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
        label_map = json.load(f)
    # label_map is label -> id; invert it for decoding predictions.
    id_to_label = {v: k for k, v in label_map.items()}

    # --------------------
    # Load model
    # --------------------
    if not os.path.isdir(MODEL_PATH):
        raise RuntimeError(f"Model directory not found: {MODEL_PATH}")
    logging.info(f"Loading model from {MODEL_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()

    # --------------------
    # Dataset & Loader
    # --------------------
    dataset = JsonlIterableDataset(
        args.input_path,
        args.text_key,
        rank=rank,
        world_size=world_size,
    )
    # FIX: persistent_workers/prefetch_factor are only legal when
    # num_workers > 0; passing them unconditionally made --num_workers 0
    # raise ValueError inside DataLoader.
    worker_opts = {}
    if args.num_workers > 0:
        worker_opts = {"persistent_workers": True, "prefetch_factor": 4}
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=Collator(tokenizer, args.max_length),
        pin_memory=True,
        **worker_opts,
    )

    # --------------------
    # Accumulators: one output bucket per class id.
    # NOTE(review): buckets are held in memory until the end; assumes the
    # per-rank shard fits in RAM — confirm for very large inputs.
    # --------------------
    outputs: Dict[int, List[dict]] = {k: [] for k in id_to_label.keys()}

    # --------------------
    # Inference
    # --------------------
    iterator = tqdm(dataloader, desc="Classifying") if is_main_process() else dataloader
    with torch.no_grad():
        for batch in iterator:
            if batch is None:
                continue
            try:
                enc = {k: v.to(device) for k, v in batch["enc"].items()}
                raw = batch["raw"]
                logits = model(**enc).logits
                preds = torch.argmax(logits, dim=-1).cpu().tolist()
                for obj, pred in zip(raw, preds):
                    # Copy so we never mutate the collator's raw record.
                    obj = dict(obj)
                    obj.pop("__lc_text", None)
                    obj["predicted_id"] = pred
                    obj["predicted_language"] = id_to_label[pred]
                    outputs[pred].append(obj)
            except Exception as e:
                # Best-effort: log and keep going so one bad batch does not
                # kill a long run.
                logging.exception(f"Batch failed: {e}")

    # --------------------
    # Write outputs: one file per language, suffixed with this rank.
    # --------------------
    os.makedirs(args.output_path, exist_ok=True)
    for cls_id, cls_name in id_to_label.items():
        out_path = os.path.join(
            args.output_path,
            f"{cls_name}.rank{rank}.jsonl"
        )
        logging.info(f"Writing {len(outputs[cls_id])} samples to {out_path}")
        with open(out_path, "w", encoding="utf-8") as f:
            for obj in outputs[cls_id]:
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    if distributed:
        dist.barrier()
        dist.destroy_process_group()
    logging.info("Language classification completed successfully.")


if __name__ == "__main__":
    main()