""" CommonLingua — byte-level language identification. Quick start (download the model from HF first, e.g. `huggingface-cli download PleIAs/CommonLingua --local-dir ./CommonLingua`): python predict.py "Wikipédia est une encyclopédie universelle, multilingue." python predict.py --file paragraph.txt cat lines.txt | python predict.py --stdin python predict.py --dtype bf16 "..." # smaller, ~2x faster, same quality Or use as a library: from predict import load, predict model, idx2lang, max_len, device = load("model.pt") predict(model, ["..."], idx2lang, max_len, device) """ import argparse import sys from pathlib import Path import numpy as np import torch sys.path.insert(0, str(Path(__file__).resolve().parent)) from model import ByteHybrid, CONFIGS # noqa: E402 def load(checkpoint, dtype="fp32", device=None): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False) model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"], **CONFIGS[ckpt["config"]]) model.load_state_dict(ckpt["model_state_dict"]) if dtype == "bf16": model = model.to(torch.bfloat16) model.eval().to(device) idx2lang = {v: k for k, v in ckpt["lang2idx"].items()} return model, idx2lang, ckpt["max_len"], device def encode(texts, max_len): out = np.full((len(texts), max_len), 256, dtype=np.int64) for i, t in enumerate(texts): if not isinstance(t, str): t = "" if t is None else str(t) raw = t.encode("utf-8", errors="replace")[:max_len] if raw: out[i, :len(raw)] = np.frombuffer(raw, dtype=np.uint8) return torch.from_numpy(out) @torch.no_grad() def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256): """Returns a list of [(lang, prob), ...] 


@torch.no_grad()
def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
    """Returns a list of [(lang, prob), ...] (one list per text, top-k entries each)."""
    out = []
    for i in range(0, len(texts), batch_size):
        b = encode(texts[i:i + batch_size], max_len).to(device)
        probs = torch.softmax(model(b).float(), dim=-1)
        top_p, top_idx = probs.topk(top_k, dim=-1)
        for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()):
            out.append([(idx2lang[j], float(p)) for p, j in zip(p_row, idx_row)])
    return out


def _main():
    p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0],
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("text", nargs="*", help="Text(s) to classify (one per arg).")
    p.add_argument("--file", help="Read a single text from FILE.")
    p.add_argument("--stdin", action="store_true", help="One text per line from stdin.")
    p.add_argument("--checkpoint", default=str(Path(__file__).resolve().parent / "model.pt"))
    p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32")
    p.add_argument("--device", default=None)
    p.add_argument("--top-k", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=256)
    args = p.parse_args()

    if args.stdin:
        texts = [line.rstrip("\n") for line in sys.stdin if line.strip()]
    elif args.file:
        texts = [Path(args.file).read_text(encoding="utf-8")]
    elif args.text:
        texts = args.text
    else:
        p.print_help()
        return

    model, idx2lang, max_len, device = load(args.checkpoint, args.dtype, args.device)
    print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}",
          file=sys.stderr)
    for text, top in zip(texts, predict(model, texts, idx2lang, max_len, device,
                                        args.top_k, args.batch_size)):
        preview = text[:80].replace("\n", " ")
        others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:])
        print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}")


if __name__ == "__main__":
    _main()
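
# Example output (tab-separated: top-1 language, its probability, the remaining
# top-k entries, then an 80-char preview). The codes and scores shown here are
# hypothetical; the real labels depend on the checkpoint's lang2idx mapping:
#   fra   0.9987   oci=0.0006 cat=0.0003   Wikipédia est une encyclopédie ...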