Spaces:
Sleeping
Sleeping
Create model_infer.py
Browse files- model_infer.py +44 -0
model_infer.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, torch
|
| 2 |
+
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
|
| 3 |
+
|
| 4 |
+
# Point these to your local files in the Space
|
| 5 |
+
MODEL_DIR = "." # or "models/aphasia_model"
|
| 6 |
+
MODEL_BIN = f"{MODEL_DIR}/pytorch_model.bin"
|
| 7 |
+
|
| 8 |
+
class AphasiaClassifier:
|
| 9 |
+
def __init__(self, model_dir: str = MODEL_DIR, device: str | None = None):
|
| 10 |
+
self.cfg = AutoConfig.from_pretrained(model_dir)
|
| 11 |
+
self.tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
|
| 12 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 13 |
+
model_dir,
|
| 14 |
+
config=self.cfg,
|
| 15 |
+
state_dict=torch.load(MODEL_BIN, map_location="cpu")
|
| 16 |
+
)
|
| 17 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
+
self.model.to(self.device)
|
| 19 |
+
self.model.eval()
|
| 20 |
+
self.id2label = getattr(self.cfg, "id2label", {i: str(i) for i in range(self.cfg.num_labels)})
|
| 21 |
+
|
| 22 |
+
def _prepare_text(self, json_path: str) -> str:
|
| 23 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 24 |
+
data = json.load(f)
|
| 25 |
+
# Example: concatenate utterances; customize to your feature logic
|
| 26 |
+
texts = [u["text"] for u in data.get("utterances", []) if u.get("text")]
|
| 27 |
+
return "\n".join(texts) if texts else ""
|
| 28 |
+
|
| 29 |
+
@torch.no_grad()
|
| 30 |
+
def predict_from_json(self, json_path: str) -> dict:
|
| 31 |
+
text = self._prepare_text(json_path)
|
| 32 |
+
if not text.strip():
|
| 33 |
+
return {"label": None, "score": 0.0, "probs": {}}
|
| 34 |
+
|
| 35 |
+
enc = self.tok(text, truncation=True, max_length=2048, return_tensors="pt")
|
| 36 |
+
enc = {k: v.to(self.device) for k, v in enc.items()}
|
| 37 |
+
logits = self.model(**enc).logits[0]
|
| 38 |
+
probs = torch.softmax(logits, dim=-1).cpu().tolist()
|
| 39 |
+
|
| 40 |
+
label_idx = int(torch.argmax(logits).item())
|
| 41 |
+
label = self.id2label.get(label_idx, str(label_idx))
|
| 42 |
+
probs_named = {self.id2label.get(i, str(i)): float(p) for i, p in enumerate(probs)}
|
| 43 |
+
|
| 44 |
+
return {"label": label, "score": float(max(probs)), "probs": probs_named}
|