File size: 1,859 Bytes
2572771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Inference helper (use with model dir or Hub repo)
import json, torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
SPECIAL_TOKENS = ["<cand>", "</cand>"]

def load(model_or_repo: str):
    """Load a sequence-classification model and its tokenizer.

    *model_or_repo* may be a local model directory or a Hugging Face Hub
    repo id; both loaders accept either form.
    Returns a ``(model, tokenizer)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_or_repo, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_or_repo)
    return model, tokenizer

@torch.no_grad()
def classify_marked(model, tokenizer, marked_text: str):
    """Run the classifier on a text whose candidate span is already marked.

    Returns a dict with the winning label ("dm" vs "not_dm"), the dm
    probability, and both class probabilities. Assumes label index 0 is
    "not_dm" and index 1 is "dm".
    """
    inputs = tokenizer(marked_text, return_tensors="pt", truncation=True)
    logits = model(**inputs).logits
    class_probs = logits.softmax(-1).squeeze(0).tolist()
    p_not_dm = class_probs[0]
    p_dm = class_probs[1]
    label = "dm" if p_dm > p_not_dm else "not_dm"
    return {
        "label": label,
        "prob_dm": p_dm,
        "probs": {"not_dm": p_not_dm, "dm": p_dm},
    }

def detect_candidates(text: str, gazetteer):
    """Find non-overlapping gazetteer matches in *text*.

    Candidates are tried longest-first (ties broken alphabetically) so
    longer entries claim characters before shorter overlapping ones.

    Returns a list of ``(start, end, candidate)`` tuples sorted by start
    offset; ``end`` is exclusive.
    """
    spans = []
    # Per-character flags marking positions already claimed by a match.
    used = [False] * len(text)
    for cand in sorted(gazetteer, key=lambda s: (-len(s), s)):
        if not cand:
            # Bug fix: an empty candidate matches at every offset and the
            # cursor (start = j = i) never advances -> infinite loop. Skip.
            continue
        start = 0
        while True:
            i = text.find(cand, start)
            if i == -1:
                break
            j = i + len(cand)
            if not any(used[i:j]):
                spans.append((i, j, cand))
                for k in range(i, j):
                    used[k] = True
                start = j
            else:
                # Overlaps a previously claimed (longer) match; resume the
                # search one character later so other occurrences are found.
                start = i + 1
    spans.sort(key=lambda x: x[0])
    return spans

def mark_first(text: str, cand: str):
    """Wrap the first occurrence of *cand* in *text* with the span markers.

    Only the first occurrence is marked; the rest of the text is unchanged.
    """
    open_tok, close_tok = SPECIAL_TOKENS
    marked = f"{open_tok} {cand} {close_tok}"
    return text.replace(cand, marked, 1)

def load_gazetteer(model_or_repo: str):
    """Return the gazetteer item list for a model directory or Hub repo.

    Tries ``<model_or_repo>/assets/gazetteer.json`` on local disk first;
    if that file is missing or unreadable, treats *model_or_repo* as a
    Hub repo id and downloads the asset instead.

    Raises whatever ``hf_hub_download`` raises if the Hub fallback also
    fails (e.g. repo not found, no network).
    """
    try:
        with open(model_or_repo + "/assets/gazetteer.json", "r", encoding="utf-8") as f:
            return json.load(f)["items"]
    # OSError: no such local file/dir; ValueError: malformed JSON
    # (JSONDecodeError subclasses it); KeyError: missing "items" key.
    # Any of these means the local copy is unusable -> fall back to the Hub.
    except (OSError, ValueError, KeyError):
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id=model_or_repo, filename="assets/gazetteer.json")
        # Bug fix: the original used json.load(open(...)) and leaked the
        # file handle; a context manager guarantees it is closed.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)["items"]