"""
Inference pipeline.

Loads ONLY the trained CriterionWiseScoringSystem checkpoint
(best_model_v5.pt) locally. The CodeT5 encoder lives in a separate
Hugging Face Space — we call it over HTTP for hidden states. The Qwen
explanation LLM is called via the Hugging Face Inference API. See
model/encoder_client.py and model/explainer.py.

Net effect: the deployment server only needs the small .pt checkpoint plus
the cross-attention layers in memory (~150–200 MB). No transformers /
no CodeT5 / no Qwen on the server.
"""

import logging
import math
import os
import time

import torch

from .architecture import CriterionWiseScoringSystem
from .signals import extract_signals_from_attn
from .explainer import build_explanation, normalize_signals
from .encoder_client import encode_text

logger = logging.getLogger("inference")

# ── Config (overridable via env vars) ──────────────────────────────────────
CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "best_model_v5.pt")

# Optional: pull the checkpoint from a Hugging Face Hub model repo at boot.
# Useful for deployments where you don't want the .pt baked into the git repo.
# If set, takes precedence over CHECKPOINT_PATH.
CHECKPOINT_HF_REPO = os.getenv("CHECKPOINT_HF_REPO", "").strip()
CHECKPOINT_HF_FILENAME = os.getenv("CHECKPOINT_HF_FILENAME", "best_model_v5.pt")
CHECKPOINT_HF_REVISION = os.getenv("CHECKPOINT_HF_REVISION", "main")

EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1024"))
N_HEADS = int(os.getenv("N_HEADS", "8"))
N_CROSS_LAYERS = int(os.getenv("N_CROSS_LAYERS", "2"))
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")

# ── Module-level singletons (initialised by load_pipeline) ─────────────────
_model = None
_device = None


def _resolve_checkpoint_path() -> str:
    """
    Return a local filesystem path to the checkpoint, downloading from the
    Hugging Face Hub if CHECKPOINT_HF_REPO is set. Otherwise falls back to
    the local CHECKPOINT_PATH.

    huggingface_hub caches by content hash under ~/.cache/huggingface/hub,
    so repeated boots inside the same container don't re-download.

    Raises:
        FileNotFoundError: CHECKPOINT_HF_REPO is unset and CHECKPOINT_PATH
            does not exist on disk.
    """
    if CHECKPOINT_HF_REPO:
        # Lazy import: only needed when the env var is actually set, so local
        # dev (with the .pt sitting next to app.py) doesn't pay the import cost.
        from huggingface_hub import hf_hub_download

        # An empty HF_TOKEN is treated the same as an unset one.
        token = os.getenv("HF_TOKEN") or None
        logger.info(
            "Downloading checkpoint from HF Hub: repo=%s file=%s rev=%s",
            CHECKPOINT_HF_REPO, CHECKPOINT_HF_FILENAME, CHECKPOINT_HF_REVISION,
        )
        t0 = time.perf_counter()
        path = hf_hub_download(
            repo_id=CHECKPOINT_HF_REPO,
            filename=CHECKPOINT_HF_FILENAME,
            revision=CHECKPOINT_HF_REVISION,
            token=token,
        )
        logger.info(
            "Checkpoint ready at %s (fetch/cache took %.1fs)",
            path, time.perf_counter() - t0,
        )
        return path

    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(
            f"Checkpoint not found at '{CHECKPOINT_PATH}' and CHECKPOINT_HF_REPO is unset. "
            f"Either place best_model_v5.pt in the project root, set CHECKPOINT_PATH, "
            f"or set CHECKPOINT_HF_REPO to a HF model repo (e.g. 'user/pseudoscorex-checkpoint')."
        )
    return CHECKPOINT_PATH


def load_pipeline() -> None:
    """
    Load the trained scoring model. Call once at boot.

    Builds a CriterionWiseScoringSystem from the module-level config, loads
    the checkpoint's state dict into it, switches it to eval mode and stores
    it (with its device) in the module-level singletons. Calling it again
    after a successful load is a no-op.
    """
    global _model, _device
    if _model is not None:
        return  # already loaded

    t0 = time.perf_counter()
    device = torch.device(DEVICE)
    logger.info("Loading scoring head on device: %s", device)

    model = CriterionWiseScoringSystem(
        embedding_dim=EMBEDDING_DIM,
        n_heads=N_HEADS,
        n_cross_layers=N_CROSS_LAYERS,
        dropout=0.1,
    ).to(device)

    checkpoint_path = _resolve_checkpoint_path()
    logger.info("Loading checkpoint: %s", checkpoint_path)
    # SECURITY NOTE(review): torch.load unpickles the file, and the checkpoint
    # may have been downloaded from the HF Hub. Only point CHECKPOINT_HF_REPO /
    # CHECKPOINT_PATH at trusted sources; consider passing weights_only=True
    # once the minimum supported torch version is confirmed to accept it.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Publish the singletons only after the load fully succeeded, so a failed
    # boot does not leave a half-initialised model behind.
    _model = model
    _device = device
    logger.info("Pipeline ready in %.1fs total.", time.perf_counter() - t0)


def _validate_payload(question, answer, criteria) -> None:
    """
    Validate the arguments of predict() before any remote call is made.

    Args:
        question: must be a non-empty (non-whitespace) string.
        answer: must be a non-empty (non-whitespace) string.
        criteria: must be a non-empty list of dicts, each with a string
            'name', a 'max_score' convertible to a finite, non-negative
            float, and optionally a string 'description'.

    Raises:
        ValueError: if any of the above does not hold.
    """
    if not isinstance(question, str) or not question.strip():
        raise ValueError("'question' must be a non-empty string")
    if not isinstance(answer, str) or not answer.strip():
        raise ValueError("'answer' must be a non-empty string")
    if not isinstance(criteria, list) or not criteria:
        raise ValueError("'criteria' must be a non-empty list")

    for i, c in enumerate(criteria):
        if not isinstance(c, dict):
            raise ValueError(f"criteria[{i}] must be an object")
        if "name" not in c:
            raise ValueError(f"criteria[{i}] missing 'name'")
        # The name is concatenated into text and joined into the overall
        # explanation, so a non-string would crash later with a TypeError.
        if not isinstance(c["name"], str):
            raise ValueError(f"criteria[{i}].name must be a string")
        # predict() uses `description or name`, so only a truthy description
        # is ever concatenated; that one has to be a string.
        description = c.get("description")
        if description and not isinstance(description, str):
            raise ValueError(f"criteria[{i}].description must be a string")
        if "max_score" not in c:
            raise ValueError(f"criteria[{i}] missing 'max_score'")
        try:
            max_score = float(c["max_score"])
        except (TypeError, ValueError, OverflowError):
            # OverflowError: an int too large for a float is not a usable score.
            raise ValueError(f"criteria[{i}].max_score must be numeric") from None
        # float() accepts "nan"/"inf"; those (and negatives) would silently
        # corrupt the totals and the percentage.
        if not math.isfinite(max_score) or max_score < 0:
            raise ValueError(
                f"criteria[{i}].max_score must be a finite, non-negative number"
            )


# ── Public API ─────────────────────────────────────────────────────────────
def predict(question, answer, criteria):
    """
    Score one (question, answer) pair against a list of rubric criteria.

    Same signature and return shape as the original implementation; the
    encoder is now remote and the LLM uses the HF Inference API.

    Args:
        question: the question text.
        answer: the answer text to score.
        criteria: list of dicts with 'name', 'max_score' and an optional
            'description' (the name is used when it is missing or empty).

    Returns:
        A dict with 'question', 'answer', 'total_score', 'max_total_score',
        'percentage', 'results' (one entry per criterion) and
        'overall_explanation'.

    Raises:
        RuntimeError: load_pipeline() has not been called yet.
        ValueError: the payload fails validation.
    """
    if _model is None:
        raise RuntimeError("Pipeline not loaded. Call load_pipeline() first.")
    _validate_payload(question, answer, criteria)

    t_total = time.perf_counter()
    logger.info(
        "Starting prediction: question=%d chars, answer=%d chars, %d criteria",
        len(question), len(answer), len(criteria),
    )

    # Encode question and answer ONCE — they are shared across criteria
    t_enc = time.perf_counter()
    # The question mask is returned by the encoder but not consumed by the model call.
    q_hidden, _q_mask, question_tokens = encode_text(question, _device)
    a_hidden, a_mask, answer_tokens = encode_text(answer, _device)
    logger.info("Encoded question + answer via Space (%.2fs)", time.perf_counter() - t_enc)

    criterion_results = []
    total_score = 0.0
    max_total = 0.0
    n_crit = len(criteria)

    for idx, crit in enumerate(criteria, start=1):
        t_crit = time.perf_counter()
        name = crit["name"]
        desc = crit.get("description") or name
        max_s = float(crit["max_score"])
        logger.info("[%d/%d] Criterion '%s' (max=%.0f) → scoring…", idx, n_crit, name, max_s)

        # Notebook always prepends the sentinel — the Space's
        # tokenizer knows about it (we add it during Space boot).
        # NOTE(review): the literal prepended here is a single space; confirm
        # it matches the sentinel the notebook and the Space tokenizer expect.
        c_text = " " + desc
        c_hidden, c_mask, criterion_tokens = encode_text(c_text, _device)

        t_fwd = time.perf_counter()
        with torch.no_grad():
            pred_norm, attn_dict = _model(
                q_hidden, a_hidden, c_hidden,
                A_mask=a_mask, c_mask=c_mask,
                return_attn=True,
            )
        pred_norm = float(pred_norm.item())
        # The model output is used as a fraction of the criterion's max score.
        score_val = pred_norm * max_s
        logger.info(
            "[%d/%d] score=%.2f/%.0f (pred_norm=%.4f) in %.2fs",
            idx, n_crit, score_val, max_s, pred_norm, time.perf_counter() - t_fwd,
        )

        signals = extract_signals_from_attn(
            attn_dict,
            answer_tokens=answer_tokens,
            criterion_tokens=criterion_tokens,
            question_tokens=question_tokens,
        )
        signals = normalize_signals(signals)

        t_expl = time.perf_counter()
        logger.info("[%d/%d] Generating explanation via HF Inference API…", idx, n_crit)
        explanation = build_explanation(
            criterion_name=name,
            pred_norm=pred_norm,
            score=score_val,
            max_score=max_s,
            signals=signals,
            question=question,
            answer=answer,
            criterion_desc=desc,
        )
        logger.info(
            "[%d/%d] Explanation ready in %.2fs (criterion total %.2fs)",
            idx, n_crit, time.perf_counter() - t_expl, time.perf_counter() - t_crit,
        )

        criterion_results.append({
            "criterion": name,
            "criterion_description": desc,
            "score": round(score_val, 3),
            "max_score": max_s,
            "pred_norm": round(pred_norm, 4),
            "explanation": explanation,
        })
        # Totals accumulate the unrounded values; rounding happens on output.
        total_score += score_val
        max_total += max_s

    percentage = (total_score / max_total * 100.0) if max_total > 0 else 0.0
    overall_explanation = _build_overall_explanation(
        total_score, max_total, criterion_results
    )
    logger.info(
        "Prediction done: total=%.2f/%.0f (%.1f%%) in %.2fs",
        total_score, max_total, percentage, time.perf_counter() - t_total,
    )

    return {
        "question": question,
        "answer": answer,
        "total_score": round(total_score, 3),
        "max_total_score": max_total,
        "percentage": round(percentage, 2),
        "results": criterion_results,
        "overall_explanation": overall_explanation,
    }


def _build_overall_explanation(total, max_total, criterion_results):
    """
    Tiny narrative stitched together from the per-criterion results.

    Args:
        total: summed score across criteria.
        max_total: summed maximum score across criteria.
        criterion_results: per-criterion dicts; only 'criterion' and
            'pred_norm' are read here.

    Returns:
        A one-to-three sentence summary string.
    """
    # Same guard as predict()'s percentage computation: anything that is not
    # a positive total (0, negative, NaN) cannot yield a meaningful percentage.
    if not max_total > 0:
        return "No criteria were scorable."

    pct = total / max_total * 100.0
    if pct >= 85:
        band = "excellent"
    elif pct >= 70:
        band = "strong"
    elif pct >= 50:
        band = "partial"
    else:
        band = "weak"

    strong = [r["criterion"] for r in criterion_results if r["pred_norm"] >= 0.8]
    weak = [r["criterion"] for r in criterion_results if r["pred_norm"] < 0.5]

    parts = [f"Overall {band} performance: {total:.2f} / {max_total:.0f} ({pct:.1f}%)."]
    if strong:
        parts.append("Strongest on: " + ", ".join(strong) + ".")
    if weak:
        parts.append("Needs work on: " + ", ".join(weak) + ".")
    return " ".join(parts)