"""
CommonLingua — byte-level language identification.
Quick start (download the model from HF first, e.g. `huggingface-cli download
PleIAs/CommonLingua --local-dir ./CommonLingua`):
python predict.py "Wikipédia est une encyclopédie universelle, multilingue."
python predict.py --file paragraph.txt
cat lines.txt | python predict.py --stdin
python predict.py --dtype bf16 "..." # smaller, ~2x faster, same quality
Or use as a library:
from predict import load, predict
model, idx2lang, max_len, device = load("model.pt")
predict(model, ["..."], idx2lang, max_len, device)
"""
import argparse
import sys
from pathlib import Path

import numpy as np
import torch

sys.path.insert(0, str(Path(__file__).resolve().parent))
from model import ByteHybrid, CONFIGS  # noqa: E402


def load(checkpoint, dtype="fp32", device=None):
    """Load a checkpoint and return (model, idx2lang, max_len, device)."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
    model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"],
                       **CONFIGS[ckpt["config"]])
    model.load_state_dict(ckpt["model_state_dict"])
    if dtype == "bf16":
        model = model.to(torch.bfloat16)
    model.eval().to(device)
    idx2lang = {v: k for k, v in ckpt["lang2idx"].items()}
    return model, idx2lang, ckpt["max_len"], device
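
# For reference, load() expects the checkpoint dict to carry the keys read
# above: "config", "num_classes", "max_len", "lang2idx" and "model_state_dict".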


def encode(texts, max_len):
    """UTF-8 encode each text into a (len(texts), max_len) int64 tensor; 256 pads."""
    out = np.full((len(texts), max_len), 256, dtype=np.int64)
    for i, t in enumerate(texts):
        if not isinstance(t, str):
            t = "" if t is None else str(t)
        raw = t.encode("utf-8", errors="replace")[:max_len]
        if raw:
            out[i, :len(raw)] = np.frombuffer(raw, dtype=np.uint8)
    return torch.from_numpy(out)
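
# For example, encode(["hé"], max_len=6) returns
# tensor([[104, 195, 169, 256, 256, 256]]): "h" is one UTF-8 byte, "é" is two
# (0xC3 0xA9 = 195, 169), and 256 is used only for padding since it lies
# outside the 0-255 byte range.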


@torch.no_grad()
def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
    """Returns a list of [(lang, prob), ...] (one list per text, top-k entries each)."""
    out = []
    for i in range(0, len(texts), batch_size):
        b = encode(texts[i:i + batch_size], max_len).to(device)
        probs = torch.softmax(model(b).float(), dim=-1)
        top_p, top_idx = probs.topk(top_k, dim=-1)
        for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()):
            out.append([(idx2lang[j], float(p)) for p, j in zip(p_row, idx_row)])
    return out
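
# One entry per input text, e.g. with top_k=2 (language codes and
# probabilities here are illustrative; the real codes come from the
# checkpoint's lang2idx):
#   predict(model, ["Bonjour à tous"], idx2lang, max_len, device, top_k=2)
#   -> [[("fra", 0.98), ("oci", 0.01)]]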


def _main():
    p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0],
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("text", nargs="*", help="Text(s) to classify (one per arg).")
    p.add_argument("--file", help="Read a single text from FILE.")
    p.add_argument("--stdin", action="store_true", help="One text per line from stdin.")
    p.add_argument("--checkpoint", default=str(Path(__file__).resolve().parent / "model.pt"))
    p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32")
    p.add_argument("--device", default=None)
    p.add_argument("--top-k", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=256)
    args = p.parse_args()

    if args.stdin:
        texts = [line.rstrip("\n") for line in sys.stdin if line.strip()]
    elif args.file:
        texts = [Path(args.file).read_text(encoding="utf-8")]
    elif args.text:
        texts = args.text
    else:
        p.print_help()
        return

    model, idx2lang, max_len, device = load(args.checkpoint, args.dtype, args.device)
print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}",
file=sys.stderr)
    for text, top in zip(texts, predict(model, texts, idx2lang, max_len, device,
                                        args.top_k, args.batch_size)):
        preview = text[:80].replace("\n", " ")
        others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:])
        print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}")


if __name__ == "__main__":
    _main()