""" Model inference module for SlushSense. Loads the trained DistilBERT model and runs predictions. """ import json import logging import torch from pathlib import Path from typing import Optional, Dict, Any logger = logging.getLogger(__name__) # Paths - relative to this file's location SRC_DIR = Path(__file__).resolve().parent BASE_DIR = SRC_DIR.parent MODELS_DIR = BASE_DIR / "models" MODEL_DIR = MODELS_DIR / "deberta_commercial_top25" # Model settings MODEL_NAME = "distilbert-base-uncased" MAX_LEN = 384 DEVICE = "cpu" # Singleton predictor _predictor: Optional["CommercialPredictor"] = None _load_error: Optional[str] = None class CommercialPredictor: """Wrapper for the trained DistilBERT classifier.""" def __init__(self, model_dir: Path): from transformers import AutoTokenizer, AutoModelForSequenceClassification self.device = torch.device(DEVICE) model_dir = Path(model_dir) # Load tokenizer if (model_dir / "tokenizer.json").exists(): self.tokenizer = AutoTokenizer.from_pretrained(str(model_dir)) else: self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Load model self.model = AutoModelForSequenceClassification.from_pretrained( str(model_dir), num_labels=1 ) self.model.to(self.device) self.model.eval() # Load threshold self.threshold = 0.5 thr_path = model_dir / "threshold.json" if thr_path.exists(): with open(thr_path) as f: self.threshold = json.load(f).get("threshold", 0.5) logger.info(f"Model loaded from {model_dir}, threshold={self.threshold}") def predict(self, text: str) -> Dict[str, Any]: """Run prediction on text.""" if not text or not text.strip(): return {"prediction": None, "probability": None, "label": None, "confidence": None} inputs = self.tokenizer( text, truncation=True, padding="max_length", max_length=MAX_LEN, return_tensors="pt", ) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): logits = self.model(**inputs).logits.squeeze(-1) prob = torch.sigmoid(logits).cpu().item() pred = 1 if prob >= self.threshold else 0 label = "good" if pred == 1 else "average to low" confidence = abs(prob - self.threshold) / max(self.threshold, 1 - self.threshold) return { "prediction": pred, "probability": round(prob, 4), "label": label, "confidence": round(confidence, 4), } def _get_predictor() -> Optional[CommercialPredictor]: """Lazy-load the predictor singleton.""" global _predictor, _load_error if _predictor is not None: return _predictor if _load_error is not None: return None try: _predictor = CommercialPredictor(MODEL_DIR) return _predictor except Exception as e: _load_error = str(e) logger.warning(f"Could not load model: {_load_error}") return None def model_is_available() -> bool: """Check if model loaded successfully.""" return _get_predictor() is not None def get_load_error() -> Optional[str]: """Get error message if model failed to load.""" _get_predictor() return _load_error def predict_potential(text: str, meta: Optional[dict] = None) -> Dict[str, Any]: """ Public API for scoring manuscripts. Returns dict with: prediction, probability, label, confidence, source """ predictor = _get_predictor() if predictor is None: return { "prediction": None, "probability": None, "label": None, "confidence": None, "source": "unavailable", } result = predictor.predict(text) result["source"] = "model" return result __all__ = ["predict_potential", "model_is_available", "get_load_error"]