Ellie5757575757 committed on
Commit
a857c0f
·
verified ·
1 Parent(s): f869b0b

Create model_infer.py

Browse files
Files changed (1) hide show
  1. model_infer.py +44 -0
model_infer.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, torch
2
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
3
+
4
+ # Point these to your local files in the Space
5
+ MODEL_DIR = "." # or "models/aphasia_model"
6
+ MODEL_BIN = f"{MODEL_DIR}/pytorch_model.bin"
7
+
8
+ class AphasiaClassifier:
9
+ def __init__(self, model_dir: str = MODEL_DIR, device: str | None = None):
10
+ self.cfg = AutoConfig.from_pretrained(model_dir)
11
+ self.tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
12
+ self.model = AutoModelForSequenceClassification.from_pretrained(
13
+ model_dir,
14
+ config=self.cfg,
15
+ state_dict=torch.load(MODEL_BIN, map_location="cpu")
16
+ )
17
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
18
+ self.model.to(self.device)
19
+ self.model.eval()
20
+ self.id2label = getattr(self.cfg, "id2label", {i: str(i) for i in range(self.cfg.num_labels)})
21
+
22
+ def _prepare_text(self, json_path: str) -> str:
23
+ with open(json_path, "r", encoding="utf-8") as f:
24
+ data = json.load(f)
25
+ # Example: concatenate utterances; customize to your feature logic
26
+ texts = [u["text"] for u in data.get("utterances", []) if u.get("text")]
27
+ return "\n".join(texts) if texts else ""
28
+
29
+ @torch.no_grad()
30
+ def predict_from_json(self, json_path: str) -> dict:
31
+ text = self._prepare_text(json_path)
32
+ if not text.strip():
33
+ return {"label": None, "score": 0.0, "probs": {}}
34
+
35
+ enc = self.tok(text, truncation=True, max_length=2048, return_tensors="pt")
36
+ enc = {k: v.to(self.device) for k, v in enc.items()}
37
+ logits = self.model(**enc).logits[0]
38
+ probs = torch.softmax(logits, dim=-1).cpu().tolist()
39
+
40
+ label_idx = int(torch.argmax(logits).item())
41
+ label = self.id2label.get(label_idx, str(label_idx))
42
+ probs_named = {self.id2label.get(i, str(i)): float(p) for i, p in enumerate(probs)}
43
+
44
+ return {"label": label, "score": float(max(probs)), "probs": probs_named}