"""Inference wrapper for a binary urgency text classifier.

Loads a tokenizer, config and checkpoint from a local model directory and
exposes :class:`UrgencyModel`, whose ``predict`` method returns an urgency
score, a label and a keyword-based rationale for a piece of text.
"""

import json
import logging
import os

# Set before ``transformers`` is imported, preserving the original ordering.
# NOTE(review): presumably silences the tokenizers parallelism/fork warning —
# confirm before moving this line.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch  # noqa: E402
from transformers import (  # noqa: E402
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.models.deberta_v2 import (  # noqa: E402
    DebertaV2ForSequenceClassification,
)

_logger = logging.getLogger(__name__)

MODEL_DIR_DEFAULT = os.path.join(os.path.dirname(__file__), "final_model")

# Labels used when ``label_spaces.json`` is absent, unreadable or empty.
_DEFAULT_ID2LABEL = {0: "Non-Urgent", 1: "Urgent"}

# Token limit passed to the tokenizer in ``UrgencyModel.predict``.
_MAX_TOKENS = 1024

# Failures treated as "optional config file is unusable, use the default".
_CONFIG_ERRORS = (OSError, ValueError, TypeError, AttributeError)


def _strip_wrappers(k: str) -> str:
    """Return *k* without one leading wrapper prefix.

    Only the first matching prefix among ``model.``, ``module.`` and ``net.``
    is removed; nested wrappers are not unwrapped repeatedly.
    """
    for prefix in ("model.", "module.", "net."):
        if k.startswith(prefix):
            return k[len(prefix):]
    return k


def _remap_keys(sd: dict) -> dict:
    """Rename checkpoint keys to the names used by the classification model.

    After stripping one wrapper prefix from each key:

    * ``backbone.X`` becomes ``deberta.X``;
    * ``head.X`` / ``heads.X`` / ``cls.X`` / ``fc.X`` become ``classifier.X``;
    * ``encoder.X`` becomes ``deberta.encoder.X``;
    * every other key is kept as is.

    If two source keys map to the same name, the later one wins.
    """
    remapped = {}
    for key, value in sd.items():
        key = _strip_wrappers(key)
        if key.startswith("backbone."):
            key = "deberta." + key[len("backbone."):]
        elif key.startswith(("head.", "heads.", "cls.", "fc.")):
            key = "classifier." + key.split(".", 1)[1]
        elif key.startswith("encoder."):
            key = "deberta." + key
        remapped[key] = value
    return remapped


def _read_json(path):
    """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _load_threshold(model_dir, default):
    """Return the ``urgency`` threshold from ``thresholds.json``.

    Falls back to *default* when the file is missing or cannot be used.
    """
    thr_path = os.path.join(model_dir, "thresholds.json")
    if not os.path.exists(thr_path):
        return default
    try:
        return float(_read_json(thr_path).get("urgency", default))
    except _CONFIG_ERRORS as exc:
        # Best effort: a broken optional file must not prevent model loading.
        _logger.warning("Ignoring unusable %s: %s", thr_path, exc)
        return default


def _load_id2label(model_dir):
    """Return the ``urgency`` id-to-label mapping from ``label_spaces.json``.

    Falls back to the default two-class mapping when the file is missing,
    cannot be used, or holds no entries for ``urgency``.
    """
    spaces_path = os.path.join(model_dir, "label_spaces.json")
    mapping = {}
    if os.path.exists(spaces_path):
        try:
            raw = _read_json(spaces_path).get("id2label", {}).get("urgency", {})
            mapping = {int(k): v for k, v in raw.items()}
        except _CONFIG_ERRORS as exc:
            _logger.warning("Ignoring unusable %s: %s", spaces_path, exc)
            mapping = {}
    return mapping or dict(_DEFAULT_ID2LABEL)


def _load_checkpoint(model_dir):
    """Load the raw state dict from *model_dir*.

    ``pytorch_model.bin`` takes precedence over ``model.safetensors``. A
    ``.bin`` checkpoint that is a dict with a dict-valued ``state_dict``
    entry is unwrapped to that entry.

    Raises:
        FileNotFoundError: if neither weights file exists.
    """
    bin_path = os.path.join(model_dir, "pytorch_model.bin")
    safe_path = os.path.join(model_dir, "model.safetensors")
    if os.path.exists(bin_path):
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only point this at trusted checkpoints; consider
        # ``weights_only=True`` or shipping safetensors instead.
        state = torch.load(bin_path, map_location="cpu")
        if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
            state = state["state_dict"]
        return state
    if os.path.exists(safe_path):
        from safetensors.torch import load_file

        return load_file(safe_path)
    raise FileNotFoundError("No model weights found.")


class UrgencyModel:
    """Binary urgency classifier loaded from a local model directory."""

    # Substrings searched for (case-insensitively) by ``_cheap_rationale``.
    _RATIONALE_KEYWORDS = (
        "shot", "shooting", "gun", "stabbing", "blood", "not breathing",
        "unconscious", "heart", "chest pain", "stroke", "seizure",
        "screaming", "help now", "immediate", "fire", "trapped", "domestic",
        "assault", "weapon",
    )

    def __init__(self, model_dir=MODEL_DIR_DEFAULT, device=None, threshold=0.5):
        """Load tokenizer, config and weights from *model_dir*.

        Args:
            model_dir: Directory holding the tokenizer, config and weights,
                plus the optional ``thresholds.json`` and
                ``label_spaces.json`` files.
            device: Torch device string; defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.
            threshold: Decision threshold, used unless ``thresholds.json``
                supplies an ``urgency`` value.

        Raises:
            FileNotFoundError: if no weights file is found in *model_dir*.
        """
        self.model_dir = model_dir
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = _load_threshold(model_dir, threshold)
        self.id2label = _load_id2label(model_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        cfg = AutoConfig.from_pretrained(model_dir, local_files_only=True)
        if getattr(cfg, "model_type", None) == "deberta-v2":
            self.model = DebertaV2ForSequenceClassification(cfg)
        else:
            self.model = AutoModelForSequenceClassification.from_config(cfg)

        state = _remap_keys(_load_checkpoint(model_dir))
        # strict=False tolerates key mismatches, so report them: any missing
        # parameter keeps its initial value rather than a trained one.
        result = self.model.load_state_dict(state, strict=False)
        if result.missing_keys:
            _logger.warning(
                "%d model parameter(s) not found in checkpoint, e.g. %s",
                len(result.missing_keys), result.missing_keys[:5],
            )
        if result.unexpected_keys:
            _logger.warning(
                "%d checkpoint key(s) unused by the model, e.g. %s",
                len(result.unexpected_keys), result.unexpected_keys[:5],
            )
        self.model.to(self.device).eval()

    @torch.inference_mode()
    def predict(self, text: str):
        """Score *text* for urgency.

        Returns:
            A dict with ``urgency_score`` (rounded to 4 decimals),
            ``urgent_label`` and ``rationale``. Empty or whitespace-only
            input yields a score of 0.0 without running the model.
        """
        if not text or not text.strip():
            return {
                "urgency_score": 0.0,
                "urgent_label": "Non-Urgent",
                "rationale": "Empty input.",
            }

        inputs = self.tokenizer(
            text, truncation=True, max_length=_MAX_TOKENS, return_tensors="pt"
        ).to(self.device)
        logits = self.model(**inputs).logits

        if logits.shape[-1] == 1:
            # Single-logit head: sigmoid gives the score directly.
            score = torch.sigmoid(logits.squeeze(-1)).item()
        else:
            # Multi-logit head: take the softmax probability at index 1.
            score = torch.softmax(logits, dim=-1).squeeze(0)[1].item()

        is_urgent = score >= self.threshold
        label = self.id2label.get(
            int(is_urgent), "Urgent" if is_urgent else "Non-Urgent"
        )
        return {
            "urgency_score": round(float(score), 4),
            "urgent_label": label,
            "rationale": self._cheap_rationale(text),
        }

    def _cheap_rationale(self, text: str, top_n: int = 3):
        """Return a string listing up to *top_n* keywords found in *text*.

        Matching is a case-insensitive substring test, so a keyword also
        matches inside longer words. Hits keep the keyword-list order.
        """
        lowered = text.lower()
        hits = [kw for kw in self._RATIONALE_KEYWORDS if kw in lowered][:top_n]
        return "Keywords: " + (", ".join(hits) if hits else "none detected")