| """infer.py — load exported HF-style directory and run NER inference. |
| |
| Usage: |
| python infer.py --model_dir checkpoints/hf_export --text "..." |
| |
| Notes: |
| - This repo exports a lightweight HF-style folder: |
| config.json |
| pytorch_model.bin |
| tokenizer files (via transformers AutoTokenizer.save_pretrained) |
| - The model class is local (EcomBertNER in model.py). |
| """ |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoTokenizer |
|
|
| from model import EcomBertNER |
|
|
|
|
def parse_args():
    """Parse command-line options for the NER inference CLI.

    Returns an argparse.Namespace with: model_dir, text, max_length,
    threshold (None = use config.json or 0.5), device (None = auto),
    and cache_dir.
    """
    parser = argparse.ArgumentParser(
        description="Inference with exported HF-style NER model"
    )
    parser.add_argument(
        "--model_dir", type=str, required=True, help="Path to HF export dir"
    )
    parser.add_argument("--text", type=str, required=True, help="Input text")
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Override threshold (default: config.json or 0.5)",
    )
    parser.add_argument(
        "--device", type=str, default=None, help="cuda / cpu; default auto"
    )
    parser.add_argument("--cache_dir", type=str, default=None)
    return parser.parse_args()
|
|
|
|
def _decode_spans(probs, offsets, text, label_list, threshold):
    """Turn span probabilities into a sorted list of entity dicts.

    Args:
        probs: (num_labels, seq, seq) tensor of sigmoid scores where
            probs[c, s, e] scores token span s..e as an entity of label c.
            (Shape inferred from the indexing below — confirm against model.py.)
        offsets: per-token (char_start, char_end) pairs from the tokenizer.
        text: the original input string, used to slice entity surface forms.
        label_list: index -> label-name mapping; out-of-range indices fall
            back to the stringified index.
        threshold: minimum probability for a span to be emitted.

    Returns:
        List of {"label", "span", "text", "score"} dicts sorted by
        descending score, then by span start/end.
    """
    hits = (probs > threshold).nonzero(as_tuple=False)
    results = []
    for c, s, e in hits.tolist():
        # Guard against span indices past the tokenized length.
        if s >= len(offsets) or e >= len(offsets):
            continue
        char_s = offsets[s][0]
        char_e = offsets[e][1]
        # Rejects inverted/empty spans; this also covers special tokens,
        # whose offsets are (0, 0) and therefore satisfy char_e <= char_s.
        if char_s < 0 or char_e <= char_s:
            continue
        results.append({
            "label": label_list[c] if c < len(label_list) else str(c),
            "span": [char_s, char_e],
            "text": text[char_s:char_e],
            "score": float(probs[c, s, e].item()),
        })
    results.sort(key=lambda x: (-x["score"], x["span"][0], x["span"][1]))
    return results


@torch.no_grad()
def main():
    """CLI entry point: load exported model + tokenizer, tag --text, print entities."""
    args = parse_args()

    model_dir = Path(args.model_dir)
    if not model_dir.exists():
        raise FileNotFoundError(f"model_dir not found: {model_dir}")

    # Device: explicit flag wins; otherwise prefer GPU when available.
    if args.device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    model, cfg = EcomBertNER.from_pretrained(model_dir, device=device, cache_dir=args.cache_dir)
    # torch.no_grad() only disables autograd — it does NOT switch dropout or
    # batchnorm out of training mode. Force eval mode for deterministic output.
    model.eval()

    # A fast (Rust-backed) tokenizer is required for return_offsets_mapping
    # below; request it explicitly so a slow-tokenizer export fails loudly.
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, cache_dir=args.cache_dir, use_fast=True
    )

    # Threshold precedence: CLI override > exported config.json > 0.5.
    threshold = args.threshold
    if threshold is None:
        threshold = float(cfg.get("threshold", 0.5))

    enc = tokenizer(
        args.text,
        max_length=args.max_length,
        truncation=True,
        padding=False,
        return_tensors="pt",
        return_offsets_mapping=True,
    )

    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    # offset_mapping stays on CPU; it is only used for char-span decoding.
    offsets = enc["offset_mapping"][0].tolist()

    out = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = out["logits"][0]
    probs = torch.sigmoid(logits)

    # Fall back to numeric labels if the export has no label_list.
    label_list = cfg.get("label_list")
    if not label_list:
        label_list = [str(i) for i in range(int(cfg.get("num_labels", probs.size(0))))]

    results = _decode_spans(probs, offsets, args.text, label_list, threshold)

    print(f"device={device} threshold={threshold}")
    for r in results:
        print(f"{r['label']}: {r['text']} span={r['span']} score={r['score']:.4f}")
|
|
|
|
# Script entry: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|