# leblanciii - commit 5dfadff (verified)
# Upload 6-head multi-task classifier (peril+severity+category+fire_sub+relevance+actionability)
import json, torch, torch.nn as nn
from pathlib import Path
from transformers import AutoModel, AutoTokenizer
# Label vocabularies for the classifier heads.  The position of a label in its
# list is the index of the corresponding output unit, so the order must not
# change without retraining.

# Perils are scored independently (one sigmoid score per label).
PERIL_LABELS = [
    "fire",
    "flood",
    "named_windstorm",
    "construction_theft",
    "transient_population",
    "civil_unrest",
    "earthquake",
]

# Single-choice severity scale.
SEVERITY_LABELS = [
    "low",
    "medium",
    "high",
    "critical",
]

# Single-choice document category.
CATEGORY_LABELS = [
    "incident_report",
    "trend",
    "regulatory",
    "research",
    "warning",
]

# Single-choice fire sub-category.
FIRE_SUBCATEGORY_LABELS = [
    "arson",
    "wildfire",
    "unknown_cause",
]

# Single-choice actionability level.
ACTIONABILITY_LABELS = [
    "irrelevant",
    "informational",
    "notable",
    "actionable",
]
class MultiTaskClassifier(nn.Module):
    """Transformer encoder shared by six linear classification heads.

    Args:
        model_name: Identifier or path handed to ``AutoModel.from_pretrained``.
        np: Output width of the peril head.
        ns: Output width of the severity head.
        nc: Output width of the category head.
        nf: Output width of the fire sub-category head.
        na: Output width of the actionability head (defaults to 4).

    The relevance head always has a single output unit.
    """

    def __init__(self, model_name, np, ns, nc, nf, na=4):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)

        # Attribute names double as state-dict keys and the creation order
        # fixes the order of random initialisation, so both are kept stable.
        head_widths = (
            ("peril_head", np),
            ("severity_head", ns),
            ("category_head", nc),
            ("fire_sub_head", nf),
            ("relevance_head", 1),
            ("actionability_head", na),
        )
        for attr_name, width in head_widths:
            setattr(self, attr_name, nn.Linear(hidden_size, width))

    def forward(self, input_ids, attention_mask=None):
        """Encode the batch and return a dict with one logits tensor per head."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state of the first token is used as the sequence summary.
        first_token = encoded.last_hidden_state[:, 0, :]
        features = self.dropout(first_token)

        heads = (
            ("peril_logits", self.peril_head),
            ("severity_logits", self.severity_head),
            ("category_logits", self.category_head),
            ("fire_sub_logits", self.fire_sub_head),
            ("relevance_logits", self.relevance_head),
            ("actionability_logits", self.actionability_head),
        )
        return {key: head(features) for key, head in heads}
class EndpointHandler:
    """Callable inference handler wrapping a ``MultiTaskClassifier``.

    An instance is called with a request dict and returns a dict of
    predictions for a single text.
    """

    _WEIGHTS_FILE = "pytorch_model.bin"
    _MAX_CHARS = 16000  # characters kept before tokenisation
    _MAX_TOKENS = 512  # tokenizer truncation length

    def __init__(self, path=""):
        """Load the tokenizer and classifier found under ``path``.

        Args:
            path: Directory (or model identifier) holding the tokenizer, the
                encoder and, optionally, ``pytorch_model.bin`` with the full
                classifier state dict.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = MultiTaskClassifier(
            path,
            len(PERIL_LABELS),
            len(SEVERITY_LABELS),
            len(CATEGORY_LABELS),
            len(FIRE_SUBCATEGORY_LABELS),
            len(ACTIONABILITY_LABELS),
        )
        weights = Path(path) / self._WEIGHTS_FILE
        if weights.exists():
            state = torch.load(str(weights), map_location=self.device, weights_only=True)
            self.model.load_state_dict(state)
        else:
            # Loading stays best-effort, but without this file the linear
            # heads keep their random initialisation, so say so loudly.
            import warnings

            warnings.warn(
                f"{weights} not found; classification heads are randomly "
                "initialised and predictions will be meaningless",
                RuntimeWarning,
                stacklevel=2,
            )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _extract_text(data):
        """Return the request text as a ``str``; only the first list item is used."""
        text = data.get("inputs", "")
        if isinstance(text, (list, tuple)):
            # An empty batch is treated like a missing "inputs" key.
            text = text[0] if text else ""
        if text is None:
            return ""
        return text if isinstance(text, str) else str(text)

    @staticmethod
    def _top_label(logits, labels):
        """Return the arg-max label of the first row and its softmax confidence."""
        probs = torch.softmax(logits, -1)[0].cpu().tolist()
        idx = int(logits.argmax(-1)[0].cpu())
        return {"label": labels[idx], "confidence": round(probs[idx], 4)}

    def __call__(self, data):
        """Classify one text.

        Args:
            data: Request dict. ``"inputs"`` is a string or a list whose first
                element is used. ``"parameters"`` may be omitted or ``None``;
                its ``include_relevance`` and ``include_actionability`` flags
                add the matching entries to the result.

        Returns:
            A dict with ``peril_scores`` (label -> sigmoid score) and
            ``severity``, ``category`` and ``fire_subcategory`` entries of the
            form ``{"label", "confidence"}``; plus ``relevance``
            (``{"score", "label"}``) and ``actionability`` when requested.
            All scores are rounded to four decimals.
        """
        text = self._extract_text(data)
        # "parameters": null must behave like an absent key.
        params = data.get("parameters") or {}
        include_relevance = params.get("include_relevance", False)
        include_actionability = params.get("include_actionability", False)

        inputs = self.tokenizer(
            text[: self._MAX_CHARS],
            truncation=True,
            max_length=self._MAX_TOKENS,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            # Pass only what forward() accepts; some tokenizers also emit
            # token_type_ids, which would raise TypeError if splatted in.
            out = self.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )

        peril_probs = torch.sigmoid(out["peril_logits"])[0].cpu().tolist()
        result = {
            "peril_scores": {
                label: round(score, 4)
                for label, score in zip(PERIL_LABELS, peril_probs)
            },
            "severity": self._top_label(out["severity_logits"], SEVERITY_LABELS),
            "category": self._top_label(out["category_logits"], CATEGORY_LABELS),
            "fire_subcategory": self._top_label(
                out["fire_sub_logits"], FIRE_SUBCATEGORY_LABELS
            ),
        }
        if include_relevance:
            rel_score = float(
                torch.sigmoid(out["relevance_logits"].squeeze(-1))[0].cpu()
            )
            result["relevance"] = {
                "score": round(rel_score, 4),
                "label": "relevant" if rel_score > 0.5 else "irrelevant",
            }
        if include_actionability:
            result["actionability"] = self._top_label(
                out["actionability_logits"], ACTIONABILITY_LABELS
            )
        return result