File size: 2,031 Bytes
8ea95f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import json, torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

# Location of the model assets inside the Space; edit to match your layout.
MODEL_DIR = "."  # alternatively "models/aphasia_model"
# Weights file path, built once at import time from MODEL_DIR (reassigning
# MODEL_DIR afterwards does not update it).
MODEL_BIN = MODEL_DIR + "/pytorch_model.bin"

class AphasiaClassifier:
    """Classify a JSON transcript with a Hugging Face sequence-classification model.

    Config, tokenizer and model are loaded from ``model_dir`` through the
    ``transformers`` ``Auto*`` classes. When ``<model_dir>/pytorch_model.bin``
    exists it is loaded explicitly and handed to the model as its ``state_dict``.
    """

    def __init__(
        self,
        model_dir: str = MODEL_DIR,
        device: str | None = None,
        *,
        max_length: int = 2048,
    ):
        """Load config, tokenizer and weights, then move the model to ``device``.

        Args:
            model_dir: Directory (or any identifier ``from_pretrained`` accepts)
                containing the config, tokenizer and weights.
            device: Torch device string. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            max_length: Upper bound on the number of tokens fed to the model.
                The effective limit is additionally capped by the tokenizer's
                ``model_max_length``.
        """
        self.cfg = AutoConfig.from_pretrained(model_dir)
        self.tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_dir,
            config=self.cfg,
            # Weights must come from the same directory as config/tokenizer,
            # not from a module-level path that ignores ``model_dir``.
            state_dict=self._load_state_dict(model_dir),
        )
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        self.id2label = getattr(self.cfg, "id2label", {i: str(i) for i in range(self.cfg.num_labels)})
        self.max_length = max_length

    @staticmethod
    def _load_state_dict(model_dir: str) -> dict | None:
        """Return the state dict in ``<model_dir>/pytorch_model.bin``, or ``None`` if absent.

        Returning ``None`` leaves ``from_pretrained`` to locate the weights on
        its own, since ``None`` is its default for ``state_dict``.
        """
        weights_path = f"{model_dir}/pytorch_model.bin"
        try:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot run arbitrary code.
            return torch.load(weights_path, map_location="cpu", weights_only=True)
        except FileNotFoundError:
            return None

    def _prepare_text(self, json_path: str) -> str:
        """Read ``json_path`` and join the non-empty ``text`` fields of its utterances.

        Expects a top-level object whose ``"utterances"`` entry is a list of
        objects with a ``"text"`` field. Any other shape contributes no text.
        Customize this method to match your feature logic.

        Returns:
            The texts joined with newlines, or ``""`` when nothing was found.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Tolerate a non-object top level or an explicit ``"utterances": null``.
        utterances = (data.get("utterances") or []) if isinstance(data, dict) else []
        texts = [
            str(u["text"])
            for u in utterances
            if isinstance(u, dict) and u.get("text")
        ]
        return "\n".join(texts)

    @torch.no_grad()
    def predict_from_json(self, json_path: str) -> dict:
        """Classify the transcript stored in ``json_path``.

        Returns:
            ``{"label": str, "score": float, "probs": {label: probability}}``,
            where ``score`` is the softmax probability of the predicted label.
            If the file yields no text, returns
            ``{"label": None, "score": 0.0, "probs": {}}``.
        """
        text = self._prepare_text(json_path)
        if not text.strip():
            return {"label": None, "score": 0.0, "probs": {}}

        # Never ask for more tokens than the tokenizer/model supports. Inputs
        # longer than the position embeddings typically fail at inference time.
        max_length = min(self.max_length, self.tok.model_max_length)
        enc = self.tok(text, truncation=True, max_length=max_length, return_tensors="pt")
        enc = {k: v.to(self.device) for k, v in enc.items()}
        logits = self.model(**enc).logits[0]  # single-example batch
        probs = torch.softmax(logits, dim=-1).cpu().tolist()

        label_idx = int(torch.argmax(logits).item())
        label = self.id2label.get(label_idx, str(label_idx))
        probs_named = {self.id2label.get(i, str(i)): float(p) for i, p in enumerate(probs)}

        return {"label": label, "score": float(max(probs)), "probs": probs_named}