# Provenance (from the hosting page): published as rc7 with the
# spec-driven scanner release — commit 32bcb86 (verified), author: temsa.
#!/usr/bin/env python3
import re
import torch
# A token is either a run of ASCII letters/digits or a single
# non-space, non-word character (punctuation).
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)


def tokenize_with_spans(text: str):
    """Split *text* into tokens with character offsets.

    Returns a list of ``(token, start, end)`` triples, one per
    ``TOKEN_RE`` match; whitespace is skipped entirely.
    """
    triples = []
    for match in TOKEN_RE.finditer(text):
        triples.append((match.group(0), match.start(), match.end()))
    return triples
def ppsn_label_ids(model) -> list[int]:
    """Return the sorted label ids whose label name ends with "PPSN".

    Tolerates the loose typing commonly seen in ``config.id2label``
    mappings: ids may be strings (coerced with ``int``) and labels may
    be ``None`` or padded with whitespace (coerced and stripped).
    """
    matching = [
        int(raw_id)
        for raw_id, raw_label in model.config.id2label.items()
        if str(raw_label or "").strip().endswith("PPSN")
    ]
    return sorted(matching)
def word_aligned_ppsn_spans(text: str, model, tokenizer, threshold: float) -> list[dict]:
    """Run the token-classification *model* over *text* and return PPSN spans.

    Each word (from ``tokenize_with_spans``) is scored with the maximum
    PPSN-label probability over its sub-word tokens; consecutive words
    scoring >= *threshold* are merged into a single character span.

    Args:
        text: Raw input string; returned offsets index into it.
        model: Token-classification model exposing ``config.id2label``,
            ``parameters()`` and ``logits`` output.
        tokenizer: Fast tokenizer supporting ``is_split_into_words`` and
            ``word_ids()`` alignment.
        threshold: Minimum per-word probability to include a word.

    Returns:
        List of dicts with keys ``start``, ``end``, ``score``, ``text``,
        ``label`` ("PPSN") and ``source`` ("model").

    Note: tokens beyond the tokenizer's truncation limit are never
    scored, so trailing words silently keep score 0.0.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return []

    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    # Capture the token -> word alignment before BatchEncoding is
    # replaced by a plain dict below (the dict has no word_ids()).
    word_ids = encoded.word_ids(batch_index=0)
    device = next(model.parameters()).device
    encoded = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits[0]
    probs = torch.softmax(logits, dim=-1)
    label_ids = ppsn_label_ids(model)

    # Score words in ONE pass over tokens (O(tokens)) instead of the
    # previous O(words x tokens) rescan; each token folds its best
    # PPSN probability into the word it belongs to. Behavior unchanged.
    word_scores = [0.0] * len(pieces)
    if label_ids:
        for token_index, wid in enumerate(word_ids):
            if wid is None:  # special tokens ([CLS]/[SEP]/padding)
                continue
            token_score = float(probs[token_index, label_ids].max())
            if token_score > word_scores[wid]:
                word_scores[wid] = token_score

    # Merge runs of consecutive above-threshold words into spans,
    # keeping the maximum word score per span.
    spans: list[dict] = []
    active: dict | None = None
    for (word, start, end), score in zip(pieces, word_scores):
        if score >= threshold:
            if active is None:
                active = {"start": start, "end": end, "score": score}
            else:
                active["end"] = end
                active["score"] = max(active["score"], score)
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)

    for span in spans:
        span["text"] = text[span["start"] : span["end"]]
        span["label"] = "PPSN"
        span["source"] = "model"
    return spans