Spaces:
Sleeping
Sleeping
File size: 4,550 Bytes
61ff229 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | #!/usr/bin/env python3
"""General-purpose TinyModel runtime utilities.
This module extends usage beyond plain classification by exposing:
- class probabilities
- sentence embeddings from the encoder
- semantic similarity scoring
- nearest-neighbor retrieval over a candidate set
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@dataclass
class RetrievalHit:
text: str
score: float
index: int
class TinyModelRuntime:
"""Inference helper around TinyModel classification checkpoints."""
def __init__(
self,
model_id_or_path: str,
*,
device: str | None = None,
max_length: int = 128,
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_id_or_path)
self.model.eval()
self.max_length = max_length
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
self.model.to(self.device)
def _encoder_backbone(self):
"""Return the base encoder (BERT, DistilBERT, RoBERTa, etc.)."""
m = self.model
for name in ("bert", "distilbert", "roberta", "electra", "camembert", "xlm_roberta"):
if hasattr(m, name):
return getattr(m, name)
raise AttributeError(
"Could not find a supported encoder backbone on this model; "
"embeddings need BERT/DistilBERT/RoBERTa-style checkpoints."
)
def classify(self, texts: Sequence[str]) -> list[dict[str, float]]:
"""Return per-label probabilities for each input text."""
encoded = self.tokenizer(
list(texts),
truncation=True,
padding=True,
max_length=self.max_length,
return_tensors="pt",
)
encoded = {k: v.to(self.device) for k, v in encoded.items()}
with torch.inference_mode():
logits = self.model(**encoded).logits
probs = F.softmax(logits, dim=-1).cpu()
id2label = self.model.config.id2label
out: list[dict[str, float]] = []
for row in probs:
item = {id2label[i]: float(row[i]) for i in range(row.shape[0])}
out.append(item)
return out
def embed(self, texts: Sequence[str], *, normalize: bool = True) -> torch.Tensor:
"""Generate pooled sentence embeddings from the transformer encoder ([CLS] / first token)."""
encoded = self.tokenizer(
list(texts),
truncation=True,
padding=True,
max_length=self.max_length,
return_tensors="pt",
)
encoded = {k: v.to(self.device) for k, v in encoded.items()}
with torch.inference_mode():
backbone = self._encoder_backbone()
# Only pass ids/mask so DistilBERT and BERT both accept the call.
hidden = backbone(
input_ids=encoded["input_ids"],
attention_mask=encoded["attention_mask"],
return_dict=True,
).last_hidden_state
cls = hidden[:, 0, :]
if normalize:
cls = F.normalize(cls, p=2, dim=1)
return cls.cpu()
def similarity(self, text_a: str, text_b: str) -> float:
"""Cosine similarity between two texts using encoder embeddings."""
embs = self.embed([text_a, text_b], normalize=True)
score = F.cosine_similarity(embs[0].unsqueeze(0), embs[1].unsqueeze(0))
return float(score.item())
def retrieve(
self,
query: str,
candidates: Sequence[str],
*,
top_k: int = 3,
) -> list[RetrievalHit]:
"""Return top-k semantically closest candidates to query."""
if not candidates:
return []
texts = [query, *candidates]
embs = self.embed(texts, normalize=True)
query_emb = embs[0:1]
cand_embs = embs[1:]
scores = (query_emb @ cand_embs.T).squeeze(0)
top_k = max(1, min(top_k, scores.shape[0]))
vals, idxs = torch.topk(scores, k=top_k)
hits: list[RetrievalHit] = []
for score, idx in zip(vals.tolist(), idxs.tolist()):
hits.append(RetrievalHit(text=candidates[idx], score=float(score), index=idx))
return hits
|