Pclanglais commited on
Commit
c6cd15e
·
verified ·
1 Parent(s): de97637

simplify predict.py: cleaner standalone CLI + library entry points

Browse files
Files changed (1) hide show
  1. predict.py +26 -20
predict.py CHANGED
@@ -1,14 +1,19 @@
1
- """Standalone CommonLingua predict — single file, no `commonlingua` package required.
 
2
 
3
- Drop this next to `model.py` and the checkpoint, then:
 
4
 
5
- python predict.py "Wikipedia is a free online encyclopedia, ..."
6
- python predict.py --file input.txt
7
- cat texts.tsv | python predict.py --stdin
 
8
 
9
- For the full Python API and parquet batch mode, install the package:
10
 
11
- pip install "git+https://github.com/PleIAs/bytehybrid-lid#egg=commonlingua[hub]"
 
 
12
  """
13
  import argparse
14
  import sys
@@ -17,7 +22,7 @@ from pathlib import Path
17
  import numpy as np
18
  import torch
19
 
20
- sys.path.insert(0, str(Path(__file__).parent))
21
  from model import ByteHybrid, CONFIGS # noqa: E402
22
 
23
 
@@ -26,8 +31,8 @@ def load(checkpoint, dtype="fp32", device=None):
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  device = torch.device(device)
28
  ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
29
- cfg = CONFIGS[ckpt["config"]]
30
- model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"], **cfg)
31
  model.load_state_dict(ckpt["model_state_dict"])
32
  if dtype == "bf16":
33
  model = model.to(torch.bfloat16)
@@ -49,10 +54,10 @@ def encode(texts, max_len):
49
 
50
  @torch.no_grad()
51
  def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
 
52
  out = []
53
  for i in range(0, len(texts), batch_size):
54
- chunk = texts[i:i + batch_size]
55
- b = encode(chunk, max_len).to(device)
56
  probs = torch.softmax(model(b).float(), dim=-1)
57
  top_p, top_idx = probs.topk(top_k, dim=-1)
58
  for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()):
@@ -60,12 +65,13 @@ def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
60
  return out
61
 
62
 
63
- def main():
64
- p = argparse.ArgumentParser()
65
- p.add_argument("text", nargs="*", help="Texts to classify (one per arg).")
 
66
  p.add_argument("--file", help="Read a single text from FILE.")
67
  p.add_argument("--stdin", action="store_true", help="One text per line from stdin.")
68
- p.add_argument("--checkpoint", default=str(Path(__file__).parent / "model.pt"))
69
  p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32")
70
  p.add_argument("--device", default=None)
71
  p.add_argument("--top-k", type=int, default=3)
@@ -82,15 +88,15 @@ def main():
82
  p.print_help()
83
  return
84
 
85
- model, idx2lang, max_len, device = load(args.checkpoint, dtype=args.dtype, device=args.device)
86
  print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}",
87
  file=sys.stderr)
88
- results = predict(model, texts, idx2lang, max_len, device, args.top_k, args.batch_size)
89
- for text, top in zip(texts, results):
90
  preview = text[:80].replace("\n", " ")
91
  others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:])
92
  print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}")
93
 
94
 
95
  if __name__ == "__main__":
96
- main()
 
1
+ """
2
+ CommonLingua — byte-level language identification.
3
 
4
+ Quick start (download the model from HF first, e.g. `huggingface-cli download
5
+ PleIAs/CommonLingua --local-dir ./CommonLingua`):
6
 
7
+ python predict.py "Wikipédia est une encyclopédie universelle, multilingue."
8
+ python predict.py --file paragraph.txt
9
+ cat lines.txt | python predict.py --stdin
10
+ python predict.py --dtype bf16 "..." # smaller, ~2x faster, same quality
11
 
12
+ Or use as a library:
13
 
14
+ from predict import load, predict
15
+ model, idx2lang, max_len, device = load("model.pt")
16
+ predict(model, ["..."], idx2lang, max_len, device)
17
  """
18
  import argparse
19
  import sys
 
22
  import numpy as np
23
  import torch
24
 
25
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
26
  from model import ByteHybrid, CONFIGS # noqa: E402
27
 
28
 
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  device = torch.device(device)
33
  ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
34
+ model = ByteHybrid(num_classes=ckpt["num_classes"], max_len=ckpt["max_len"],
35
+ **CONFIGS[ckpt["config"]])
36
  model.load_state_dict(ckpt["model_state_dict"])
37
  if dtype == "bf16":
38
  model = model.to(torch.bfloat16)
 
54
 
55
  @torch.no_grad()
56
  def predict(model, texts, idx2lang, max_len, device, top_k=3, batch_size=256):
57
+ """Returns a list of [(lang, prob), ...] (one list per text, top-k entries each)."""
58
  out = []
59
  for i in range(0, len(texts), batch_size):
60
+ b = encode(texts[i:i + batch_size], max_len).to(device)
 
61
  probs = torch.softmax(model(b).float(), dim=-1)
62
  top_p, top_idx = probs.topk(top_k, dim=-1)
63
  for p_row, idx_row in zip(top_p.cpu().tolist(), top_idx.cpu().tolist()):
 
65
  return out
66
 
67
 
68
+ def _main():
69
+ p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0],
70
+ formatter_class=argparse.RawDescriptionHelpFormatter)
71
+ p.add_argument("text", nargs="*", help="Text(s) to classify (one per arg).")
72
  p.add_argument("--file", help="Read a single text from FILE.")
73
  p.add_argument("--stdin", action="store_true", help="One text per line from stdin.")
74
+ p.add_argument("--checkpoint", default=str(Path(__file__).resolve().parent / "model.pt"))
75
  p.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32")
76
  p.add_argument("--device", default=None)
77
  p.add_argument("--top-k", type=int, default=3)
 
88
  p.print_help()
89
  return
90
 
91
+ model, idx2lang, max_len, device = load(args.checkpoint, args.dtype, args.device)
92
  print(f"# {len(idx2lang)} languages, max_len={max_len}, dtype={args.dtype}, device={device}",
93
  file=sys.stderr)
94
+ for text, top in zip(texts, predict(model, texts, idx2lang, max_len, device,
95
+ args.top_k, args.batch_size)):
96
  preview = text[:80].replace("\n", " ")
97
  others = " ".join(f"{lg}={p:.3f}" for lg, p in top[1:])
98
  print(f"{top[0][0]}\t{top[0][1]:.4f}\t{others}\t{preview}")
99
 
100
 
101
  if __name__ == "__main__":
102
+ _main()