| """ |
| CommonLingua — byte-level language identification. |
| |
Quick start (download the model from Hugging Face first, e.g. `huggingface-cli download
PleIAs/CommonLingua --local-dir ./CommonLingua`):
| |
| python predict.py "Wikipédia est une encyclopédie universelle, multilingue." |
| python predict.py --file paragraph.txt |
| cat lines.txt | python predict.py --stdin |
| python predict.py --dtype bf16 "..." # smaller, ~2x faster, same quality |
| |
| Or use as a library: |
| |
| from predict import load, predict |
| model, idx2lang, max_len, device = load("model.pt") |
| predict(model, ["..."], idx2lang, max_len, device) |
| """ |
| import argparse |
| import sys |
| from pathlib import Path |
|
| import numpy as np |
| import torch |
|
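# model.py (ByteHybrid, CONFIGS) is expected next to this script; putting the
# script directory on sys.path makes the import below work from any CWD.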
| sys.path.insert(0, str(Path(__file__).resolve().parent)) |
| from model import ByteHybrid, CONFIGS |
|
|
def load(checkpoint, dtype="fp32", device=None):
    """Load a checkpoint; returns (model, idx2lang, max_len, device)."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    # weights_only=False: the checkpoint bundles config and label maps alongside the weights.
    ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
    model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"],
                       **CONFIGS[ckpt["config"]])
    model.load_state_dict(ckpt["model_state_dict"])
    if dtype == "bf16":
        model = model.to(torch.bfloat16)
    model.eval().to(device)  # .to() mutates the module in place; eval() disables dropout
    idx2lang = {v: k for k, v in ckpt["lang2idx"].items()}  # class index -> language code
    return model, idx2lang, ckpt["max_len"], device
|
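# Illustrative round trip ("model.pt" as in the module docstring; the mapping
# shown is invented, real codes come from the checkpoint's lang2idx):
#
#   model, idx2lang, max_len, device = load("model.pt", dtype="bf16")
#   idx2lang  # e.g. {0: "eng", 1: "fra", ...}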
|
def encode(texts, max_len):
    """Encode texts as fixed-length rows of byte IDs; 256 (one past the byte range) is padding."""
    out = np.full((len(texts), max_len), 256, dtype=np.int64)
    for i, t in enumerate(texts):
        if not isinstance(t, str):
            t = "" if t is None else str(t)  # tolerate None and non-string inputs
        raw = t.encode("utf-8", errors="replace")[:max_len]  # truncate at the byte level
        if raw:
            out[i, :len(raw)] = np.frombuffer(raw, dtype=np.uint8)
    return torch.from_numpy(out)
|
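# Quick sanity check of the encoding (toy max_len=4 for brevity):
#
#   >>> encode(["hi"], 4)
#   tensor([[104, 105, 256, 256]])
#
# "hi" is two UTF-8 bytes (104, 105); unused positions keep the pad ID 256.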
|
| @torch.no_grad() |
| def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256): |
| """Returns a list of [(lang, prob), ...] (one list per text, top-k entries each).""" |
| out = [] |
| for i in range(0, len(texts), batch_size): |
| b = encode(texts[i:i + batch_size], max_len).to(device) |
        probs = torch.softmax(model(b).float(), dim=-1)  # .float() upcasts bf16 logits
| top_p, top_idx = probs.topk(top_k, dim=-1) |
| for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()): |
| out.append([(idx2lang[j], float(p)) for p, j in zip(p_row, idx_row)]) |
| return out |
|
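# Sketch of the return value with top_k=2; languages and probabilities here
# are invented, real labels come from idx2lang:
#
#   >>> predict(model, ["Wikipédia est une encyclopédie", "hello"], idx2lang,
#   ...         max_len, device, top_k=2)
#   [[('fra', 0.998), ('oci', 0.001)], [('eng', 0.97), ('sco', 0.01)]]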
|
| def _main(): |
| p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0], |
| formatter_class=argparse.RawDescriptionHelpFormatter) |
| p.add_argument("text", nargs="*", help="Text(s) to classify (one per arg).") |
| p.add_argument("--file", help="Read a single text from FILE.") |
| p.add_argument("--stdin", action="store_true", help="One text per line from stdin.") |
| p.add_argument("--checkpoint", default=str(Path(__file__).resolve().parent / "model.pt")) |
| p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32") |
| p.add_argument("--device", default=None) |
| p.add_argument("--top-k", type=int, default=3) |
| p.add_argument("--batch-size", type=int, default=256) |
| args = p.parse_args() |
|
| if args.stdin: |
| texts = [line.rstrip("\n") for line in sys.stdin if line.strip()] |
| elif args.file: |
| texts = [Path(args.file).read_text(encoding="utf-8")] |
| elif args.text: |
| texts = args.text |
| else: |
| p.print_help() |
| return |
|
| model, idx2lang, max_len, device = load(args.checkpoint, args.dtype, args.device) |
| print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}", |
| file=sys.stderr) |
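    # One tab-separated line per input: top-1 language, its probability, the
    # remaining top-k pairs, then an 80-character preview of the text.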
| for text, top in zip(texts, predict(model, texts, idx2lang, max_len, device, |
| args.top_k, args.batch_size)): |
| preview = text[:80].replace("\n", " ") |
| others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:]) |
| print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}") |
|
|
| if __name__ == "__main__": |
| _main() |