Spaces:
Runtime error
Runtime error
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import torch, json | |
| from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification | |
| from transformers.models.deberta_v2 import DebertaV2ForSequenceClassification | |
# Default model directory: the "final_model" folder located next to this source file.
MODEL_DIR_DEFAULT = os.path.join(os.path.dirname(__file__), "final_model")
| def _strip_wrappers(k: str) -> str: | |
| for p in ("model.", "module.", "net."): | |
| if k.startswith(p): return k[len(p):] | |
| return k | |
def _remap_keys(sd: dict) -> dict:
    """Return a copy of a state dict with its keys renamed.

    Each key is first passed through ``_strip_wrappers`` and then rewritten:

    * ``backbone.<rest>``                 -> ``deberta.<rest>``
    * ``head.`` / ``heads.`` / ``cls.`` / ``fc.`` + ``<rest>`` -> ``classifier.<rest>``
    * ``encoder.<rest>``                  -> ``deberta.encoder.<rest>``

    Any other key is kept as-is. Values are passed through untouched. If two
    source keys map to the same name, the later one wins.
    """
    head_prefixes = ("head.", "heads.", "cls.", "fc.")

    def _rename(name: str) -> str:
        name = _strip_wrappers(name)
        if name.startswith("backbone."):
            return "deberta." + name[len("backbone."):]
        if name.startswith(head_prefixes):
            # Drop only the first dotted component (the head's own name).
            return "classifier." + name.split(".", 1)[1]
        if name.startswith("encoder."):
            return "deberta." + name
        return name

    return {_rename(name): tensor for name, tensor in sd.items()}
class UrgencyModel:
    """Urgency classifier backed by a sequence-classification model on disk.

    The tokenizer, config and weights are read from ``model_dir`` with
    ``local_files_only=True``. Checkpoint keys are renamed with
    ``_remap_keys`` before being loaded into the model, and ``predict``
    scores a single piece of text.
    """

    # Substrings searched for (in lower-cased text) by ``_cheap_rationale``.
    _RATIONALE_KEYS = (
        "shot", "shooting", "gun", "stabbing", "blood", "not breathing", "unconscious",
        "heart", "chest pain", "stroke", "seizure", "screaming", "help now", "immediate",
        "fire", "trapped", "domestic", "assault", "weapon",
    )

    def __init__(self, model_dir=MODEL_DIR_DEFAULT, device=None, threshold=0.5):
        """Load tokenizer, model and optional metadata from ``model_dir``.

        Args:
            model_dir: Directory holding the tokenizer/config files and either
                ``pytorch_model.bin`` or ``model.safetensors``.
            device: Torch device string; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            threshold: Decision threshold on the urgency score. Overridden by
                the ``"urgency"`` entry of ``thresholds.json`` if that file
                exists in ``model_dir`` and can be parsed.

        Raises:
            FileNotFoundError: If neither weights file exists in ``model_dir``.
        """
        import warnings

        self.model_dir = model_dir
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = self._load_threshold(model_dir, threshold)
        self.id2label = self._load_id2label(model_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        cfg = AutoConfig.from_pretrained(model_dir, local_files_only=True)
        if getattr(cfg, "model_type", None) == "deberta-v2":
            self.model = DebertaV2ForSequenceClassification(cfg)
        else:
            self.model = AutoModelForSequenceClassification.from_config(cfg)

        sd = _remap_keys(self._read_state_dict(model_dir))
        # strict=False tolerates naming differences, but a silent mismatch
        # leaves layers randomly initialised, so surface what did not line up.
        result = self.model.load_state_dict(sd, strict=False)
        missing = list(getattr(result, "missing_keys", ()))
        unexpected = list(getattr(result, "unexpected_keys", ()))
        if missing or unexpected:
            warnings.warn(
                "Checkpoint keys did not fully match the model: "
                f"{len(missing)} missing (e.g. {missing[:3]}), "
                f"{len(unexpected)} unexpected (e.g. {unexpected[:3]})."
            )
        self.model.to(self.device).eval()

    @staticmethod
    def _load_threshold(model_dir, default):
        """Return the ``"urgency"`` threshold from ``thresholds.json``, else ``default``."""
        import warnings

        thr_path = os.path.join(model_dir, "thresholds.json")
        if not os.path.exists(thr_path):
            return default
        try:
            with open(thr_path, encoding="utf-8") as fh:
                return float(json.load(fh).get("urgency", default))
        except (OSError, ValueError, TypeError, AttributeError) as err:
            # Best effort: a broken thresholds file must not prevent start-up.
            warnings.warn(f"Ignoring unreadable {thr_path}: {err}")
            return default

    @staticmethod
    def _load_id2label(model_dir):
        """Return the urgency ``id -> label`` map from ``label_spaces.json``.

        Falls back to ``{0: "Non-Urgent", 1: "Urgent"}`` when the file is
        missing or cannot be parsed.
        """
        import warnings

        fallback = {0: "Non-Urgent", 1: "Urgent"}
        spaces_path = os.path.join(model_dir, "label_spaces.json")
        try:
            with open(spaces_path, encoding="utf-8") as fh:
                spaces = json.load(fh)
            return {int(k): v for k, v in spaces.get("id2label", {}).get("urgency", {}).items()}
        except FileNotFoundError:
            # The file is optional; its absence is not worth a warning.
            return fallback
        except (OSError, ValueError, TypeError, AttributeError) as err:
            warnings.warn(f"Ignoring unreadable {spaces_path}: {err}")
            return fallback

    @staticmethod
    def _read_state_dict(model_dir):
        """Read raw weights from ``pytorch_model.bin`` or ``model.safetensors``.

        ``pytorch_model.bin`` takes precedence when both exist. A checkpoint
        dict that nests its weights under ``"state_dict"`` is unwrapped.

        Raises:
            FileNotFoundError: If neither file exists.
        """
        binp = os.path.join(model_dir, "pytorch_model.bin")
        safep = os.path.join(model_dir, "model.safetensors")
        if os.path.exists(binp):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source (consider weights_only=True
            # if the installed torch version supports it).
            sd = torch.load(binp, map_location="cpu")
            if isinstance(sd, dict) and isinstance(sd.get("state_dict"), dict):
                sd = sd["state_dict"]
            return sd
        if os.path.exists(safep):
            from safetensors.torch import load_file
            return load_file(safep)
        raise FileNotFoundError("No model weights found.")

    def predict(self, text: str):
        """Score ``text`` for urgency.

        Returns:
            A dict with ``"urgency_score"`` (float rounded to 4 places),
            ``"urgent_label"`` (label chosen by comparing the score with
            ``self.threshold``) and ``"rationale"`` (keyword summary).
            Empty or whitespace-only input yields a score of ``0.0`` and the
            ``"Non-Urgent"`` label without running the model.
        """
        if not text or not text.strip():
            return {"urgency_score": 0.0, "urgent_label": "Non-Urgent", "rationale": "Empty input."}
        inputs = self.tokenizer(text, truncation=True, max_length=1024, return_tensors="pt").to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        if logits.shape[-1] == 1:
            # Single-logit head: sigmoid gives the urgency probability.
            score = torch.sigmoid(logits.squeeze(-1)).item()
        else:
            # Multi-logit head: take the softmax probability of class index 1.
            score = torch.softmax(logits, dim=-1).squeeze(0)[1].item()
        is_urgent = score >= self.threshold
        label = self.id2label.get(int(is_urgent), "Urgent" if is_urgent else "Non-Urgent")
        return {"urgency_score": round(float(score), 4), "urgent_label": label, "rationale": self._cheap_rationale(text)}

    def _cheap_rationale(self, text: str, top_n: int = 3):
        """Return a ``"Keywords: ..."`` string listing up to ``top_n`` matched keywords.

        Matching is a plain substring test on the lower-cased text, in the
        order the keywords are declared; ``"none detected"`` is reported when
        nothing matches.
        """
        t = text.lower()
        hits = [k for k in self._RATIONALE_KEYS if k in t][:top_n]
        return "Keywords: " + (", ".join(hits) if hits else "none detected")