import json  # noqa: F401 -- kept: may be relied on by the serving framework
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Label vocabularies for each classification head.  Index i of a head's logit
# vector corresponds to index i of the matching list, so the order must not
# change without retraining the model.
PERIL_LABELS = ["fire", "flood", "named_windstorm", "construction_theft",
                "transient_population", "civil_unrest", "earthquake"]
SEVERITY_LABELS = ["low", "medium", "high", "critical"]
CATEGORY_LABELS = ["incident_report", "trend", "regulatory", "research", "warning"]
FIRE_SUBCATEGORY_LABELS = ["arson", "wildfire", "unknown_cause"]
ACTIONABILITY_LABELS = ["irrelevant", "informational", "notable", "actionable"]

# Character cap applied before tokenization.  The tokenizer truncates to 512
# tokens regardless; this only bounds tokenizer work on pathological inputs.
_MAX_INPUT_CHARS = 16000


class MultiTaskClassifier(nn.Module):
    """Shared transformer encoder with six task-specific linear heads.

    Heads: multi-label peril scores (sigmoid), single-label severity /
    category / fire sub-category / actionability (softmax), and a scalar
    binary relevance logit.

    Args:
        model_name: Hugging Face model id or local path for the encoder.
        np: number of peril labels (multi-label head width).
        ns: number of severity classes.
        nc: number of category classes.
        nf: number of fire sub-category classes.
        na: number of actionability classes (default 4).
    """

    def __init__(self, model_name, np, ns, nc, nf, na=4):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        h = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.peril_head = nn.Linear(h, np)
        self.severity_head = nn.Linear(h, ns)
        self.category_head = nn.Linear(h, nc)
        self.fire_sub_head = nn.Linear(h, nf)
        self.relevance_head = nn.Linear(h, 1)
        self.actionability_head = nn.Linear(h, na)

    def forward(self, input_ids, attention_mask=None):
        """Encode the input and return a dict of raw logits, one per head.

        Pools by taking the first ([CLS]) token's hidden state, then applies
        dropout before every head.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.last_hidden_state[:, 0, :])
        return {
            "peril_logits": self.peril_head(pooled),
            "severity_logits": self.severity_head(pooled),
            "category_logits": self.category_head(pooled),
            "fire_sub_logits": self.fire_sub_head(pooled),
            "relevance_logits": self.relevance_head(pooled),
            "actionability_logits": self.actionability_head(pooled),
        }


class EndpointHandler:
    """Inference entry point: loads the multi-task model and serves requests.

    Expected request shape: ``{"inputs": <str or [str, ...]>,
    "parameters": {"include_relevance": bool, "include_actionability": bool}}``.
    """

    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = MultiTaskClassifier(
            path,
            len(PERIL_LABELS),
            len(SEVERITY_LABELS),
            len(CATEGORY_LABELS),
            len(FIRE_SUBCATEGORY_LABELS),
            len(ACTIONABILITY_LABELS),
        )
        # Fine-tuned weights live alongside the tokenizer/config; if the file
        # is absent we silently keep the base encoder weights (best effort,
        # matching previous behavior).
        weights = Path(path) / "pytorch_model.bin"
        if weights.exists():
            self.model.load_state_dict(
                torch.load(str(weights), map_location=self.device, weights_only=True)
            )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _top(logits, labels):
        """Return ``{"label", "confidence"}`` for the argmax of a softmax head.

        Operates on the first (only) row of a batch of logits; confidence is
        the softmax probability rounded to 4 decimals.
        """
        probs = torch.softmax(logits, -1)[0].cpu().tolist()
        idx = int(logits.argmax(-1)[0].cpu())
        return {"label": labels[idx], "confidence": round(probs[idx], 4)}

    def __call__(self, data):
        """Classify one text and return per-head predictions.

        Args:
            data: request dict with ``"inputs"`` (str, or list whose first
                element is used) and optional ``"parameters"`` flags.

        Returns:
            dict with ``peril_scores`` (per-label sigmoid scores),
            ``severity``, ``category``, ``fire_subcategory`` (argmax label +
            confidence each), plus optional ``relevance`` and
            ``actionability`` entries when the corresponding flag is set.
        """
        text = data.get("inputs", "")
        # BUGFIX: an empty list used to raise IndexError; treat it as "".
        if isinstance(text, list):
            text = text[0] if text else ""
        # BUGFIX: non-string payloads used to fail with an opaque TypeError
        # when sliced below; coerce so the handler degrades gracefully.
        if not isinstance(text, str):
            text = "" if text is None else str(text)

        params = data.get("parameters", {})
        include_relevance = params.get("include_relevance", False)
        include_actionability = params.get("include_actionability", False)

        inputs = self.tokenizer(
            text[:_MAX_INPUT_CHARS],
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)

        # Peril is multi-label: independent sigmoid score per label.
        peril_probs = torch.sigmoid(out["peril_logits"])[0].cpu().tolist()
        result = {
            "peril_scores": {l: round(s, 4) for l, s in zip(PERIL_LABELS, peril_probs)},
            "severity": self._top(out["severity_logits"], SEVERITY_LABELS),
            "category": self._top(out["category_logits"], CATEGORY_LABELS),
            "fire_subcategory": self._top(out["fire_sub_logits"], FIRE_SUBCATEGORY_LABELS),
        }
        if include_relevance:
            rel_score = float(torch.sigmoid(out["relevance_logits"].squeeze(-1))[0].cpu())
            result["relevance"] = {
                "score": round(rel_score, 4),
                "label": "relevant" if rel_score > 0.5 else "irrelevant",
            }
        if include_actionability:
            result["actionability"] = self._top(
                out["actionability_logits"], ACTIONABILITY_LABELS
            )
        return result