| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoTokenizer |
|
|
| try: |
| from .config import ( |
| CALIBRATION_ARTIFACTS_DIR, |
| DECISION_PHASE_HEAD_CONFIG, |
| INTENT_HEAD_CONFIG, |
| MULTITASK_INTENT_MODEL_DIR, |
| SUBTYPE_HEAD_CONFIG, |
| ) |
| from .multitask_model import MultiTaskIntentModel, MultiTaskLabelSizes |
| except ImportError: |
| from config import ( |
| CALIBRATION_ARTIFACTS_DIR, |
| DECISION_PHASE_HEAD_CONFIG, |
| INTENT_HEAD_CONFIG, |
| MULTITASK_INTENT_MODEL_DIR, |
| SUBTYPE_HEAD_CONFIG, |
| ) |
| from multitask_model import MultiTaskIntentModel, MultiTaskLabelSizes |
|
|
|
|
def round_score(value: float) -> float:
    """Coerce *value* to ``float`` and round it to four decimal places."""
    decimals = 4
    as_float = float(value)
    return round(as_float, decimals)
|
|
|
|
# Registry of supported multitask heads: head name -> head configuration object.
# Insertion order (intent_type, intent_subtype, decision_phase) is preserved.
TASK_TO_CONFIG = dict(
    intent_type=INTENT_HEAD_CONFIG,
    intent_subtype=SUBTYPE_HEAD_CONFIG,
    decision_phase=DECISION_PHASE_HEAD_CONFIG,
)
|
|
# Head name -> key under which that head's logits appear in the model output.
TASK_TO_LOGIT_KEY = dict(
    intent_type="intent_type_logits",
    intent_subtype="intent_subtype_logits",
    decision_phase="decision_phase_logits",
)
|
|
|
|
@dataclass(frozen=True)
class CalibrationState:
    """Immutable calibration settings applied to one classification head."""

    # Whether calibration is considered to have been applied for the head.
    calibrated: bool
    # Divisor applied to the logits before the calibrated softmax.
    temperature: float
    # Minimum calibrated confidence for a prediction to meet the threshold.
    confidence_threshold: float
|
|
|
|
class MultiTaskRuntime:
    """Lazily loaded tokenizer, weights and metadata for the multitask model.

    Nothing is read from disk at construction time. The ``metadata``,
    ``tokenizer`` and ``model`` properties each load their artifact from
    ``model_dir`` on first access and cache it on the instance.
    """

    def __init__(self, model_dir: Path):
        self.model_dir = model_dir
        self._tokenizer = None
        self._model = None
        self._metadata = None
        # Chunk size read by callers that split large inputs into batches.
        self._predict_batch_size = 32

    @property
    def metadata(self) -> dict:
        """Return the parsed ``metadata.json``, reading it on first access.

        Raises:
            FileNotFoundError: if ``metadata.json`` is absent from ``model_dir``.
        """
        if self._metadata is not None:
            return self._metadata
        metadata_path = self.model_dir / "metadata.json"
        if not metadata_path.exists():
            raise FileNotFoundError(
                f"Missing multitask metadata at {metadata_path}. Run python3 training/train_multitask_intent.py first."
            )
        self._metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        return self._metadata

    @property
    def tokenizer(self):
        """Return the tokenizer stored in ``model_dir``, loading it once."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(str(self.model_dir))
        return self._tokenizer

    @property
    def model(self) -> MultiTaskIntentModel:
        """Return the multitask model in eval mode, building it on first access.

        The head sizes come from the label lists in ``TASK_TO_CONFIG`` and the
        base model name from ``metadata``; the weights are loaded strictly.

        Raises:
            FileNotFoundError: if ``multitask_model.pt`` (or the metadata file
                needed to build the model) is absent from ``model_dir``.
        """
        if self._model is not None:
            return self._model
        weights_path = self.model_dir / "multitask_model.pt"
        if not weights_path.exists():
            raise FileNotFoundError(
                f"Missing multitask weights at {weights_path}. Run python3 training/train_multitask_intent.py first."
            )
        # NOTE(security): torch.load unpickles the checkpoint, so model_dir must
        # only ever contain trusted artifacts. Consider weights_only=True once
        # the minimum supported torch version and checkpoint contents allow it.
        payload = torch.load(weights_path, map_location="cpu")
        label_sizes = MultiTaskLabelSizes(
            intent_type=len(TASK_TO_CONFIG["intent_type"].labels),
            intent_subtype=len(TASK_TO_CONFIG["intent_subtype"].labels),
            decision_phase=len(TASK_TO_CONFIG["decision_phase"].labels),
        )
        loaded = MultiTaskIntentModel(self.metadata["base_model_name"], label_sizes)
        loaded.load_state_dict(payload["state_dict"], strict=True)
        loaded.eval()
        self._model = loaded
        return self._model

    def _encode(self, texts: list[str], max_length: int) -> dict[str, torch.Tensor]:
        """Tokenize *texts* into padded, truncated ``input_ids``/``attention_mask``."""
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=max_length,
        )
        # Forward only the two tensors the model is called with.
        return {name: encoded[name] for name in ("input_ids", "attention_mask")}

    def _predict_logits(self, task: str, texts: list[str]) -> torch.Tensor:
        """Return the raw logits of head *task* for *texts*.

        Inputs are truncated to that head's own ``max_length``. An empty
        *texts* list yields a ``(0, n_labels)`` tensor without touching the
        tokenizer or the model.
        """
        config = TASK_TO_CONFIG[task]
        if not texts:
            return torch.empty((0, len(config.labels)), dtype=torch.float32)
        inputs = self._encode(texts, config.max_length)
        with torch.inference_mode():
            outputs = self.model(**inputs)
        return outputs[TASK_TO_LOGIT_KEY[task]]

    def predict_all_heads_batch(
        self, texts: list[str]
    ) -> dict[str, torch.Tensor]:
        """Run a single forward pass and return the logits of every head.

        This is the hot-path entry point: compared with calling
        ``_predict_logits`` once per head it needs one model call instead of
        three. Inputs are truncated to the largest ``max_length`` configured
        across the heads.

        Returns
        -------
        dict with keys ``intent_type_logits``, ``intent_subtype_logits``,
        ``decision_phase_logits`` — raw (pre-softmax) tensors of shape
        ``(len(texts), n_classes_for_head)``. For an empty *texts* list each
        value is an empty ``(0, n_classes_for_head)`` float32 tensor.
        """
        if not texts:
            # Mirror the empty-input guards of the per-head batch helpers
            # instead of handing an empty batch to the tokenizer.
            return {
                TASK_TO_LOGIT_KEY[task]: torch.empty((0, len(cfg.labels)), dtype=torch.float32)
                for task, cfg in TASK_TO_CONFIG.items()
            }
        max_len = max(cfg.max_length for cfg in TASK_TO_CONFIG.values())
        inputs = self._encode(texts, max_len)
        with torch.inference_mode():
            outputs = self.model(**inputs)
        # Key set and order come from the shared table, not a second hard-coded list.
        return {key: outputs[key] for key in TASK_TO_LOGIT_KEY.values()}
|
|
|
|
class MultiTaskHeadProxy:
    """Per-head facade over the shared multitask runtime.

    Each proxy is bound to one task name from ``TASK_TO_CONFIG`` and turns that
    head's logits into raw and temperature-calibrated probabilities and into
    prediction dicts.
    """

    def __init__(self, task: str):
        """Bind the proxy to *task*.

        Raises:
            ValueError: if *task* is not a key of ``TASK_TO_CONFIG``.
        """
        if task not in TASK_TO_CONFIG:
            raise ValueError(f"Unsupported multitask head: {task}")
        self.task = task
        self.config = TASK_TO_CONFIG[task]
        self.runtime = get_multitask_runtime()
        self._calibration = None

    @property
    def tokenizer(self):
        """Tokenizer of the shared runtime."""
        return self.runtime.tokenizer

    @property
    def model(self):
        """Return a model-like view that exposes only this head.

        The view has a ``config.id2label`` attribute and is callable with
        ``input_ids``/``attention_mask``; the call returns an object whose
        ``logits`` attribute holds this head's logits. A new view is built on
        every access.
        """
        proxy = self

        class _TaskModelView:
            config = type("ConfigView", (), {"id2label": proxy.config.id2label})()

            def forward(self, input_ids=None, attention_mask=None, **kwargs):
                # Extra keyword arguments are accepted and ignored.
                with torch.inference_mode():
                    outputs = proxy.runtime.model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs[TASK_TO_LOGIT_KEY[proxy.task]]
                return type("OutputView", (), {"logits": logits})()

            __call__ = forward

        return _TaskModelView()

    @property
    def forward_arg_names(self) -> set[str]:
        """Names of the keyword arguments the model view consumes."""
        return {"input_ids", "attention_mask"}

    @property
    def calibration(self) -> CalibrationState:
        """Return this head's calibration settings, loading them once.

        Defaults are uncalibrated, temperature ``1.0`` and the head config's
        ``default_confidence_threshold``. If ``<task>.json`` exists under
        ``CALIBRATION_ARTIFACTS_DIR`` its values override the defaults; a file
        without a ``calibrated`` key counts as calibrated. The temperature is
        floored at ``1e-3`` and the threshold clamped to ``[0, 1]``.
        """
        if self._calibration is None:
            calibrated = False
            temperature = 1.0
            confidence_threshold = self.config.default_confidence_threshold
            calibration_path = CALIBRATION_ARTIFACTS_DIR / f"{self.task}.json"
            if calibration_path.exists():
                payload = json.loads(calibration_path.read_text(encoding="utf-8"))
                calibrated = bool(payload.get("calibrated", True))
                temperature = float(payload.get("temperature", 1.0))
                confidence_threshold = float(payload.get("confidence_threshold", confidence_threshold))
            self._calibration = CalibrationState(
                calibrated=calibrated,
                # Floor keeps the division in the calibrated softmax well-defined.
                temperature=max(temperature, 1e-3),
                confidence_threshold=self._clamp_unit(confidence_threshold),
            )
        return self._calibration

    @staticmethod
    def _clamp_unit(value: float) -> float:
        """Clamp *value* into the closed interval ``[0, 1]``."""
        return min(max(value, 0.0), 1.0)

    def _resolve_threshold(self, confidence_threshold: float | None) -> float:
        """Return the caller's threshold clamped to ``[0, 1]``, or the calibrated default."""
        if confidence_threshold is None:
            return self.calibration.confidence_threshold
        return self._clamp_unit(float(confidence_threshold))

    def _build_prediction(
        self, raw_row: torch.Tensor, calibrated_row: torch.Tensor, effective_threshold: float
    ) -> dict:
        """Build the prediction dict for one example from its probability rows."""
        # The label is chosen on the calibrated distribution; the raw
        # confidence is reported for that same label.
        pred_id = int(torch.argmax(calibrated_row).item())
        confidence = float(calibrated_row[pred_id].item())
        raw_confidence = float(raw_row[pred_id].item())
        return {
            "label": self.config.id2label[pred_id],
            "confidence": round_score(confidence),
            "raw_confidence": round_score(raw_confidence),
            "confidence_threshold": round_score(effective_threshold),
            "calibrated": self.calibration.calibrated,
            # Compared before rounding, so the reported values may look equal
            # while the flag is False.
            "meets_confidence_threshold": confidence >= effective_threshold,
        }

    def _chunks(self, texts: list[str]) -> list[list[str]]:
        """Split *texts* into consecutive slices of the runtime's batch size."""
        size = self.runtime._predict_batch_size
        return [texts[start : start + size] for start in range(0, len(texts), size)]

    def _predict_probs(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the runtime for this head and return ``(raw, calibrated)`` probabilities."""
        logits = self.runtime._predict_logits(self.task, texts)
        return self.predict_probs_from_logits(logits)

    def predict_probs_from_logits(
        self, logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute raw and calibrated probabilities from pre-computed logits.

        Hot-path helper for callers that already ran a shared encoder pass, so
        the proxy does not re-run the encoder. Softmax is taken over the last
        dimension; the calibrated variant divides the logits by the calibration
        temperature first.
        """
        with torch.inference_mode():
            raw_probs = torch.softmax(logits, dim=-1)
            calibrated_probs = torch.softmax(logits / self.calibration.temperature, dim=-1)
        return raw_probs, calibrated_probs

    def predict_from_logits(
        self, logits: torch.Tensor, confidence_threshold: float | None = None
    ) -> dict:
        """Return a single prediction dict from pre-computed logits.

        Args:
            logits: logits of one example, shaped ``(n_classes,)`` or
                ``(1, n_classes)``.
            confidence_threshold: optional override, clamped to ``[0, 1]``;
                ``None`` uses the calibrated threshold.

        Raises:
            ValueError: if *logits* does not describe exactly one example.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        if logits.dim() != 2 or logits.shape[0] != 1:
            raise ValueError(
                f"predict_from_logits expects logits for a single example, got shape {tuple(logits.shape)}"
            )
        effective_threshold = self._resolve_threshold(confidence_threshold)
        raw_probs, calibrated_probs = self.predict_probs_from_logits(logits)
        return self._build_prediction(raw_probs[0], calibrated_probs[0], effective_threshold)

    def predict_probs_batch(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(raw, calibrated)`` probability matrices for *texts*.

        Inputs are processed in chunks of the runtime's batch size and the
        per-chunk results concatenated on CPU. An empty input yields one empty
        ``(0, n_labels)`` float32 tensor, returned for both elements.
        """
        if not texts:
            empty = torch.empty((0, len(self.config.labels)), dtype=torch.float32)
            return empty, empty
        raw_chunks: list[torch.Tensor] = []
        calibrated_chunks: list[torch.Tensor] = []
        for batch in self._chunks(texts):
            raw, calibrated = self._predict_probs(batch)
            raw_chunks.append(raw.detach().cpu())
            calibrated_chunks.append(calibrated.detach().cpu())
        return torch.cat(raw_chunks, dim=0), torch.cat(calibrated_chunks, dim=0)

    def predict_batch(self, texts: list[str], confidence_threshold: float | None = None) -> list[dict]:
        """Return one prediction dict per text, in input order.

        An empty *texts* list returns ``[]`` without loading calibration or
        the model.
        """
        if not texts:
            return []
        effective_threshold = self._resolve_threshold(confidence_threshold)
        predictions: list[dict] = []
        for batch in self._chunks(texts):
            raw_probs, calibrated_probs = self._predict_probs(batch)
            predictions.extend(
                self._build_prediction(raw_row, calibrated_row, effective_threshold)
                for raw_row, calibrated_row in zip(raw_probs, calibrated_probs)
            )
        return predictions

    def predict(self, text: str, confidence_threshold: float | None = None) -> dict:
        """Return the prediction dict for a single *text*."""
        return self.predict_batch([text], confidence_threshold=confidence_threshold)[0]

    def status(self) -> dict:
        """Report this head's artifact paths, weight availability and calibration flag."""
        return {
            "head": self.task,
            "model_path": str(self.runtime.model_dir),
            "calibration_path": str(CALIBRATION_ARTIFACTS_DIR / f"{self.task}.json"),
            "ready": (self.runtime.model_dir / "multitask_model.pt").exists(),
            "calibrated": self.calibration.calibrated,
        }
|
|
|
|
@lru_cache(maxsize=1)
def get_multitask_runtime() -> MultiTaskRuntime:
    """Return the shared runtime for ``MULTITASK_INTENT_MODEL_DIR``.

    The ``lru_cache`` wrapper means the runtime is constructed on the first
    call and the same instance is handed back on later calls.
    """
    runtime = MultiTaskRuntime(MULTITASK_INTENT_MODEL_DIR)
    return runtime
|
|