# ecombert-ner-v1 / infer.py
# (Hugging Face Hub metadata: uploaded by xinyacs using huggingface_hub,
# revision 7781e94, verified)
"""infer.py — load exported HF-style directory and run NER inference.
Usage:
python infer.py --model_dir checkpoints/hf_export --text "..."
Notes:
- This repo exports a lightweight HF-style folder:
config.json
pytorch_model.bin
tokenizer files (via transformers AutoTokenizer.save_pretrained)
- The model class is local (EcomBertNER in model.py).
"""
import argparse
from pathlib import Path
import torch
from transformers import AutoTokenizer
from model import EcomBertNER
def parse_args():
    """Collect command-line options for running NER inference on one input string."""
    parser = argparse.ArgumentParser(description="Inference with exported HF-style NER model")
    # Option table: (flag, add_argument keyword arguments).
    options = [
        ("--model_dir", dict(type=str, required=True, help="Path to HF export dir")),
        ("--text", dict(type=str, required=True, help="Input text")),
        ("--max_length", dict(type=int, default=256)),
        ("--threshold", dict(type=float, default=None,
                             help="Override threshold (default: config.json or 0.5)")),
        ("--device", dict(type=str, default=None, help="cuda / cpu; default auto")),
        ("--cache_dir", dict(type=str, default=None)),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def _resolve_device(requested):
    """Return the torch.device to run on; auto-selects CUDA when available."""
    if requested is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested)


def _decode_entities(probs, offsets, text, label_list, threshold):
    """Convert span probabilities into a sorted list of entity dicts.

    Args:
        probs: (C, L, L) tensor of sigmoid scores; cell (c, s, e) scores
            label c for the token span starting at token s, ending at token e.
        offsets: per-token (char_start, char_end) pairs from the tokenizer's
            offset mapping.
        text: the original input string, used to slice entity surface forms.
        label_list: index -> label-name list; indices past its end fall back
            to the stringified class index.
        threshold: minimum probability for a span to be emitted.

    Returns:
        List of {"label", "span", "text", "score"} dicts, sorted by
        descending score, then by character span position.
    """
    n_tokens = len(offsets)  # hoisted: invariant across all hits
    hits = (probs > threshold).nonzero(as_tuple=False)
    results = []
    for c, s, e in hits.tolist():
        # Guard: logits length and offsets length can disagree (e.g. after
        # truncation), so never index past the offset table.
        if s >= n_tokens or e >= n_tokens:
            continue
        char_s = offsets[s][0]
        char_e = offsets[e][1]
        # Special tokens ([CLS]/[SEP]/pad) map to (0, 0) offsets — skip them.
        if char_s == char_e == 0:
            continue
        # Skip empty or inverted character spans (end token before start).
        if char_s < 0 or char_e <= char_s:
            continue
        results.append({
            "label": label_list[c] if c < len(label_list) else str(c),
            "span": [char_s, char_e],
            "text": text[char_s:char_e],
            "score": float(probs[c, s, e].item()),
        })
    results.sort(key=lambda x: (-x["score"], x["span"][0], x["span"][1]))
    return results


@torch.no_grad()
def main():
    """Load an exported HF-style NER model and print entities found in --text.

    Raises:
        FileNotFoundError: if --model_dir does not exist.
    """
    args = parse_args()
    model_dir = Path(args.model_dir)
    if not model_dir.exists():
        raise FileNotFoundError(f"model_dir not found: {model_dir}")
    device = _resolve_device(args.device)
    model, cfg = EcomBertNER.from_pretrained(model_dir, device=device, cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=args.cache_dir)
    # CLI value wins; otherwise fall back to config.json, then 0.5.
    threshold = args.threshold
    if threshold is None:
        threshold = float(cfg.get("threshold", 0.5))
    # NOTE(review): return_offsets_mapping requires a fast (Rust) tokenizer;
    # a slow tokenizer raises here — confirm the export always ships one.
    enc = tokenizer(
        args.text,
        max_length=args.max_length,
        truncation=True,
        padding=False,
        return_tensors="pt",
        return_offsets_mapping=True,
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    offsets = enc["offset_mapping"][0].tolist()
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = out["logits"][0]  # (C, L, L): label x span-start x span-end
    probs = torch.sigmoid(logits)
    label_list = cfg.get("label_list")
    if not label_list:
        # No names in config: fall back to stringified class indices.
        label_list = [str(i) for i in range(int(cfg.get("num_labels", probs.size(0))))]
    results = _decode_entities(probs, offsets, args.text, label_list, threshold)
    print(f"device={device} threshold={threshold}")
    for r in results:
        print(f"{r['label']}: {r['text']} span={r['span']} score={r['score']:.4f}")


if __name__ == "__main__":
    main()