import re

import torch

# Matches either a run of ASCII letters/digits or a single non-space,
# non-word character, so punctuation is emitted as individual tokens.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)


def tokenize_with_spans(text: str) -> list[tuple[str, int, int]]:
    """Split text into (token, start, end) triples with character offsets."""
    return [(m.group(0), m.start(), m.end()) for m in TOKEN_RE.finditer(text)]
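
# Worked example (offsets computed by hand for illustration):
#   tokenize_with_spans("PPSN: 1234567T")
#   -> [("PPSN", 0, 4), (":", 4, 5), ("1234567T", 6, 14)]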


def ppsn_label_ids(model) -> list[int]:
    """Return the sorted label ids whose names end in "PPSN" (e.g. B-PPSN, I-PPSN)."""
    ids = []
    for raw_id, raw_label in model.config.id2label.items():
        # id2label keys can be strings when a config was round-tripped through JSON.
        label_id = int(raw_id)
        label = str(raw_label or "").strip()
        if label.endswith("PPSN"):
            ids.append(label_id)
    return sorted(ids)
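
# For a hypothetical BIO-style config with id2label {0: "O", 1: "B-PPSN",
# 2: "I-PPSN"} (illustrative; not taken from a specific checkpoint), this
# returns [1, 2].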


def word_aligned_ppsn_spans(text: str, model, tokenizer, threshold: float) -> list[dict]:
    """Run the token-classification model over text and return merged PPSN spans."""
    pieces = tokenize_with_spans(text)
    if not pieces:
        return []

    words = [word for word, _, _ in pieces]
    # is_split_into_words=True keeps a token->word alignment; word_ids() is
    # only available on fast-tokenizer encodings, so capture it before the
    # BatchEncoding is replaced by a plain dict below.
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)

    device = next(model.parameters()).device
    encoded = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits[0]

    probs = torch.softmax(logits, dim=-1)
    label_ids = ppsn_label_ids(model)

    # Aggregate sub-token probabilities up to word level: a word's score is
    # the maximum PPSN probability over its sub-tokens. Words dropped by
    # truncation keep the default score of 0.0.
    word_scores = [0.0] * len(pieces)
    for token_index, wid in enumerate(word_ids):
        if wid is None:  # special tokens ([CLS], [SEP], ...) map to no word
            continue
        for label_id in label_ids:
            word_scores[wid] = max(word_scores[wid], float(probs[token_index, label_id]))

    # Merge consecutive above-threshold words into a single character span,
    # keeping the highest word score as the span score.
    spans: list[dict] = []
    active = None
    for (_, start, end), score in zip(pieces, word_scores):
        if score >= threshold:
            if active is None:
                active = {"start": start, "end": end, "score": score}
            else:
                active["end"] = end
                active["score"] = max(active["score"], score)
        elif active is not None:
            spans.append(active)
            active = None

    if active is not None:
        spans.append(active)

    for span in spans:
        span["text"] = text[span["start"] : span["end"]]
        span["label"] = "PPSN"
        span["source"] = "model"

    return spans
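

# Minimal usage sketch. "your-org/ppsn-ner" is a placeholder checkpoint name,
# not a real model id: any fine-tuned token-classification model whose labels
# end in "PPSN", together with its matching fast tokenizer, would fit here.
if __name__ == "__main__":
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("your-org/ppsn-ner")
    model = AutoModelForTokenClassification.from_pretrained("your-org/ppsn-ner")
    model.eval()  # disable dropout so the scores are deterministic

    sample = "Applicant PPSN 1234567T was verified on 2024-03-01."
    for span in word_aligned_ppsn_spans(sample, model, tokenizer, threshold=0.5):
        print(span)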