# Inference helper (use with model dir or Hub repo)
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Marker tokens wrapped around a candidate span before classification.
# NOTE(review): both entries are empty strings — the actual marker text
# (e.g. "<dm>" / "</dm>") looks like it was lost, so mark_first() currently
# inserts only the surrounding spaces. Confirm against the tokenizer's
# added special tokens and restore the real markers.
SPECIAL_TOKENS = ["", ""]


def load(model_or_repo: str):
    """Load a (model, tokenizer) pair from a local directory or a Hub repo id."""
    tok = AutoTokenizer.from_pretrained(model_or_repo, use_fast=True)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_or_repo)
    return mdl, tok


@torch.no_grad()
def classify_marked(model, tokenizer, marked_text: str):
    """Classify a text whose candidate span is already wrapped in SPECIAL_TOKENS.

    Assumes a binary head where index 0 = "not_dm" and index 1 = "dm" —
    TODO(review): confirm against the model's id2label config.

    Returns a dict with the winning label, the "dm" probability, and both
    class probabilities.
    """
    enc = tokenizer(marked_text, return_tensors="pt", truncation=True)
    out = model(**enc)
    probs = out.logits.softmax(-1).squeeze(0).tolist()
    return {
        "label": "dm" if probs[1] > probs[0] else "not_dm",
        "prob_dm": probs[1],
        "probs": {"not_dm": probs[0], "dm": probs[1]},
    }


def detect_candidates(text: str, gazetteer):
    """Find non-overlapping gazetteer matches in *text*.

    Longer entries take priority (ties broken alphabetically), so a longer
    phrase claims its characters before any shorter substring can.

    Returns (start, end, candidate) tuples sorted by start offset.
    """
    spans = []
    used = [False] * len(text)  # characters already claimed by a match
    # Longest-first so e.g. a long phrase wins over a substring of it.
    for cand in sorted(gazetteer, key=lambda s: (-len(s), s)):
        if not cand:
            # Fix: an empty entry made the original loop forever —
            # text.find("", start) always returns start, the zero-width span
            # never marks anything used, and start never advances.
            continue
        start = 0
        while True:
            i = text.find(cand, start)
            if i == -1:
                break
            j = i + len(cand)
            if not any(used[i:j]):
                spans.append((i, j, cand))
                for k in range(i, j):
                    used[k] = True
                start = j
            else:
                # Overlaps an earlier (longer) match; keep scanning.
                start = i + 1
    spans.sort(key=lambda x: x[0])
    return spans


def mark_first(text: str, cand: str):
    """Wrap the first occurrence of *cand* in the SPECIAL_TOKENS markers."""
    return text.replace(cand, f"{SPECIAL_TOKENS[0]} {cand} {SPECIAL_TOKENS[1]}", 1)


def load_gazetteer(model_or_repo: str):
    """Load the gazetteer item list from a local model dir, else from the Hub.

    Tries `<model_or_repo>/assets/gazetteer.json` on disk first; on any
    failure falls back to downloading the same file from the Hub repo
    (hence the deliberately broad except — local-first is best-effort).
    """
    try:
        with open(model_or_repo + "/assets/gazetteer.json", "r", encoding="utf-8") as f:
            return json.load(f)["items"]
    except Exception:
        from huggingface_hub import hf_hub_download

        p = hf_hub_download(repo_id=model_or_repo, filename="assets/gazetteer.json")
        # Fix: use a context manager so the downloaded file's handle is
        # closed — the original leaked the handle from a bare open().
        with open(p, "r", encoding="utf-8") as f:
            return json.load(f)["items"]