"""
CommonLingua — byte-level language identification.

Quick start (download the model from HF first, e.g. `huggingface-cli download
PleIAs/CommonLingua --local-dir ./CommonLingua`):

    python predict.py "Wikipédia est une encyclopédie universelle, multilingue."
    python predict.py --file paragraph.txt
    cat lines.txt | python predict.py --stdin
    python predict.py --dtype bf16 "..."   # smaller, ~2x faster, same quality
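
Each output line is tab-separated: the top language, its probability, the
remaining top-k "lang=prob" pairs, and an 80-character preview of the input.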

Or use as a library:

    from predict import load, predict
    model, idx2lang, max_len, device = load("model.pt")
    predict(model, ["..."], idx2lang, max_len, device)
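    # -> one [(lang, prob), ...] list per text; labels depend on the
    # checkpoint's lang2idx, e.g. [[("fra", 0.98), ("eng", 0.01), ...]]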
"""
import argparse
import sys
from pathlib import Path

import numpy as np
import torch

sys.path.insert(0, str(Path(__file__).resolve().parent))
from model import ByteHybrid, CONFIGS  # noqa: E402


def load(checkpoint, dtype="fp32", device=None):
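    """Load a ByteHybrid checkpoint; return (model, idx2lang, max_len, device)."""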
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
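    # weights_only=False is needed to unpickle the checkpoint's non-tensor
    # fields (lang2idx, config name); only load checkpoints you trust.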
    ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
    model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"],
                       **CONFIGS[ckpt["config"]])
    model.load_state_dict(ckpt["model_state_dict"])
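    # bf16 halves weight memory; predict() upcasts logits to fp32 before the
    # softmax, so reported probabilities keep full precision.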
    if dtype == "bf16":
        model = model.to(torch.bfloat16)
    model.eval().to(device)
    idx2lang = {v: k for k, v in ckpt["lang2idx"].items()}
    return model, idx2lang, ckpt["max_len"], device


def encode(texts, max_len):
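    """UTF-8-encode texts into a (len(texts), max_len) int64 tensor.

    Each text is truncated to max_len bytes and right-padded with 256 (real
    bytes are 0-255), e.g. encode(["hi"], 4) -> [[104, 105, 256, 256]].
    """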
    out = np.full((len(texts), max_len), 256, dtype=np.int64)
    for i, t in enumerate(texts):
        if not isinstance(t, str):
            t = "" if t is None else str(t)
        raw = t.encode("utf-8", errors="replace")[:max_len]
        if raw:
            out[i, :len(raw)] = np.frombuffer(raw, dtype=np.uint8)
    return torch.from_numpy(out)


@torch.no_grad()
def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
    """Returns a list of [(lang, prob), ...] (one list per text, top-k entries each)."""
    out = []
    for i in range(0, len(texts), batch_size):
        b = encode(texts[i:i + batch_size], max_len).to(device)
        probs = torch.softmax(model(b).float(), dim=-1)
        # clamp in case the checkpoint has fewer classes than top_k
        top_p, top_idx = probs.topk(min(top_k, probs.size(-1)), dim=-1)
        for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()):
            out.append([(idx2lang[j], float(p)) for p, j in zip(p_row, idx_row)])
    return out


def _main():
    p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0],
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("text", nargs="*", help="Text(s) to classify (one per arg).")
    p.add_argument("--file", help="Read a single text from FILE.")
    p.add_argument("--stdin", action="store_true", help="One text per line from stdin.")
    p.add_argument("--checkpoint", default=str(Path(__file__).resolve().parent / "model.pt"))
    p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32")
    p.add_argument("--device", default=None)
    p.add_argument("--top-k", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=256)
    args = p.parse_args()

    if args.stdin:
        texts = [line.rstrip("\n") for line in sys.stdin if line.strip()]
    elif args.file:
        texts = [Path(args.file).read_text(encoding="utf-8")]
    elif args.text:
        texts = args.text
    else:
        p.print_help()
        return

    model, idx2lang, max_len, device = load(args.checkpoint, args.dtype, args.device)
    print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}",
          file=sys.stderr)
    for text, top in zip(texts, predict(model, texts, idx2lang, max_len, device,
                                        args.top_k, args.batch_size)):
        preview = text[:80].replace("\n", " ")
        others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:])
        print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}")


if __name__ == "__main__":
    _main()