| from __future__ import annotations |
|
|
| import torch |
| from transformers import AutoModelForTokenClassification, AutoTokenizer |
|
|
| from .schemas import Span |
|
|
# Hugging Face model identifier loaded when no explicit model is requested.
MODEL_ID = "lumicero/Joint-Uniform-BioNER"
# Torch device string the model and its input tensors are moved to.
DEVICE = "cpu"
# Number of sentences tokenized and sent through the model per forward pass.
BATCH_SIZE = 8
# Token budget per sentence; longer inputs are truncated by the tokenizer.
MAX_LENGTH = 256


# Cache slot for the shared predictor; filled on first use by get_predictor().
# The annotation is left unquoted because `from __future__ import annotations`
# already defers its evaluation (quoting it would stringify it twice).
_predictor: NERPredictor | None = None
|
|
|
|
class NERPredictor:
    """Run a token-classification model and decode entity spans from it.

    The tokenizer and model are loaded once in the constructor. ``predict``
    then tags sentences in mini-batches and merges the per-token predictions
    into character-level ``Span`` objects.
    """

    # Entity types kept by the decoder; every other predicted label is dropped.
    _VALID_TYPES = frozenset({"Chemical", "Disease", "Virus", "Gene"})

    def __init__(
        self,
        model_id: str = MODEL_ID,
        *,
        device: str = DEVICE,
        batch_size: int = BATCH_SIZE,
        max_length: int = MAX_LENGTH,
    ):
        """Load the tokenizer and model and prepare them for inference.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            device: Torch device the model and input tensors are moved to.
            batch_size: Number of sentences processed per forward pass.
            max_length: Maximum tokenized length; longer inputs are truncated.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForTokenClassification.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()
        self.id2label = self.model.config.id2label

    def predict(self, sentences: list[str]) -> list[list[Span]]:
        """Tag every sentence and return one list of spans per sentence.

        Args:
            sentences: Raw sentence strings; may be empty.

        Returns:
            A list aligned with ``sentences``; element ``k`` holds the spans
            decoded from ``sentences[k]``.
        """
        results: list[list[Span]] = []
        for start in range(0, len(sentences), self.batch_size):
            chunk = sentences[start : start + self.batch_size]
            results.extend(self._predict_batch(chunk))
        return results

    @torch.no_grad()
    def _predict_batch(self, batch: list[str]) -> list[list[Span]]:
        """Run one forward pass over ``batch`` and decode each sentence.

        Gradient tracking is disabled here so the method is also safe when
        called directly rather than through ``predict``.

        Args:
            batch: Sentences to tokenize and tag together.

        Returns:
            One list of spans per sentence, in input order.
        """
        if not batch:
            # Nothing to tag; avoid handing the tokenizer an empty batch.
            return []

        # NOTE: offset mappings are only produced by fast tokenizers
        # (see the Hugging Face tokenizer documentation).
        enc = self.tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_offsets_mapping=True,
        )
        # The model does not accept offsets, so remove them before the call.
        offsets_batch = enc.pop("offset_mapping").tolist()
        enc = {k: v.to(self.device) for k, v in enc.items()}

        logits = self.model(**enc).logits
        # Per token: probability and id of the highest-scoring label.
        max_probs, pred_ids = torch.softmax(logits, dim=-1).max(dim=-1)
        max_probs = max_probs.cpu().tolist()
        pred_ids = pred_ids.cpu().tolist()

        id2label = self.id2label
        return [
            self._decode_bio(
                sentence,
                offsets,
                [id2label[p] for p in ids],
                confs,
            )
            for sentence, offsets, ids, confs in zip(
                batch, offsets_batch, pred_ids, max_probs
            )
        ]

    def _decode_bio(
        self,
        sentence: str,
        offsets: list[list[int]],
        labels: list[str],
        confs: list[float],
    ) -> list[Span]:
        """Merge per-token labels into character-level entity spans.

        Tokens labelled with one of the valid entity types are kept.
        Consecutive kept tokens of the same type are merged as long as the
        text between them contains no alphanumeric character.

        Args:
            sentence: The original sentence the offsets index into.
            offsets: ``[start, end]`` character offsets, one pair per token.
            labels: Predicted label string per token.
            confs: Probability of the predicted label per token.

        Returns:
            Spans in order of appearance; each span's confidence is the mean
            of its tokens' confidences.
        """
        # NOTE(review): the B-/I- prefix is discarded, so two distinct
        # same-type entities separated only by whitespace or punctuation are
        # merged into a single span. Confirm this is the intended behaviour.
        token_spans: list[tuple[int, int, str, float]] = []
        for (s, e), label, conf in zip(offsets, labels, confs):
            if s == 0 and e == 0:
                # Zero offsets: presumably special or padding tokens.
                continue
            if label is None or label == "O":
                continue
            # Drop a leading "B-"/"I-" style prefix, if there is one.
            _, sep, tail = label.partition("-")
            ent = tail if sep else label
            if ent in self._VALID_TYPES:
                token_spans.append((s, e, ent, conf))

        spans: list[Span] = []
        total = len(token_spans)
        i = 0
        while i < total:
            start, end, ent, conf = token_spans[i]
            merged_confs = [conf]
            j = i + 1
            while j < total:
                next_start, next_end, next_ent, next_conf = token_spans[j]
                if next_ent != ent:
                    break
                # An alphanumeric character in the gap means an untagged word
                # sits between the two tokens, so the entity ends here.
                if any(ch.isalnum() for ch in sentence[end:next_start]):
                    break
                end = next_end
                merged_confs.append(next_conf)
                j += 1
            text = sentence[start:end]
            if text.strip():
                spans.append(
                    Span(
                        start=start,
                        end=end,
                        type=ent,
                        text=text,
                        confidence=float(sum(merged_confs) / len(merged_confs)),
                    )
                )
            i = j
        return spans
|
|
|
|
def get_predictor() -> NERPredictor:
    """Return the shared ``NERPredictor``, building it on the first call.

    The instance is stored in the module-level ``_predictor`` variable, so
    later calls reuse it instead of loading the model again.
    """
    global _predictor
    if _predictor is not None:
        return _predictor
    _predictor = NERPredictor()
    return _predictor
|
|