slushsense / src /model_inference.py
csiboto's picture
Create src/model_inference.py
92e2e6e verified
"""
Model inference module for SlushSense.
Loads the trained DistilBERT model and runs predictions.
"""
import json
import logging
import torch
from pathlib import Path
from typing import Optional, Dict, Any
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Paths - resolved relative to this file's location, so they do not depend on
# the current working directory.
SRC_DIR = Path(__file__).resolve().parent  # directory containing this file
BASE_DIR = SRC_DIR.parent  # parent of the source directory
MODELS_DIR = BASE_DIR / "models"
# Checkpoint directory handed to CommercialPredictor by _get_predictor().
# NOTE(review): the directory name mentions "deberta" while MODEL_NAME below is
# a DistilBERT checkpoint - confirm which architecture is actually stored here.
MODEL_DIR = MODELS_DIR / "deberta_commercial_top25"
# Model settings
# Tokenizer fetched by name when the checkpoint directory has no tokenizer.json.
MODEL_NAME = "distilbert-base-uncased"
# Token length every input is truncated/padded to before inference.
MAX_LEN = 384
# Device string passed to torch.device().
DEVICE = "cpu"
# Singleton predictor
# _predictor caches the loaded instance; _load_error keeps the message of a
# failed load so that _get_predictor() does not retry on every call.
_predictor: Optional["CommercialPredictor"] = None
_load_error: Optional[str] = None
class CommercialPredictor:
    """Wrapper for the trained single-logit sequence classifier.

    The checkpoint is loaded through the ``transformers`` Auto* classes with
    ``num_labels=1``. ``predict`` passes that single logit through a sigmoid
    and compares the resulting probability against a decision threshold.
    """

    def __init__(
        self,
        model_dir: Path,
        max_len: Optional[int] = None,
        device: Optional[str] = None,
    ):
        """Load the tokenizer, the model weights and the decision threshold.

        Args:
            model_dir: Directory containing the saved model, and optionally
                ``tokenizer.json`` and ``threshold.json``.
            max_len: Token length inputs are truncated/padded to. Defaults to
                the module-level ``MAX_LEN``.
            device: Device string passed to ``torch.device``. Defaults to the
                module-level ``DEVICE``.

        Raises:
            ValueError, TypeError: If ``threshold.json`` exists but does not
                hold a usable numeric threshold.
            Exception: Anything raised by ``transformers`` while importing or
                loading the tokenizer/model is propagated unchanged.
        """
        # Imported here rather than at module level, so a missing
        # `transformers` install surfaces as an exception from the constructor.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        self.device = torch.device(DEVICE if device is None else device)
        self.max_len = MAX_LEN if max_len is None else max_len
        model_dir = Path(model_dir)

        # Prefer the tokenizer saved next to the model; otherwise fall back to
        # the base checkpoint named by MODEL_NAME.
        if (model_dir / "tokenizer.json").exists():
            self.tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        self.model = AutoModelForSequenceClassification.from_pretrained(
            str(model_dir), num_labels=1
        )
        self.model.to(self.device)
        self.model.eval()

        # Validate the threshold once at load time instead of failing later
        # inside every predict() call.
        self.threshold = self._load_threshold(model_dir)

        logger.info(
            "Model loaded from %s, threshold=%s", model_dir, self.threshold
        )

    @staticmethod
    def _load_threshold(model_dir: Path, default: float = 0.5) -> float:
        """Read the decision threshold from ``<model_dir>/threshold.json``.

        Args:
            model_dir: Directory that may contain ``threshold.json``.
            default: Value used when the file, the ``"threshold"`` key, or its
                value (JSON ``null``) is absent.

        Returns:
            The threshold as a ``float``.

        Raises:
            ValueError: If the file is not a JSON object, or the value cannot
                be parsed as a number.
            TypeError: If the value is of a type ``float()`` rejects.
        """
        thr_path = Path(model_dir) / "threshold.json"
        if not thr_path.exists():
            return default
        with open(thr_path, encoding="utf-8") as f:
            payload = json.load(f)
        if not isinstance(payload, dict):
            raise ValueError(
                f"{thr_path} must contain a JSON object, got {type(payload).__name__}"
            )
        raw = payload.get("threshold")
        if raw is None:
            return default
        # Coerce so that values such as "0.4" do not break the numeric
        # comparison in predict().
        return float(raw)

    def predict(self, text: str) -> Dict[str, Any]:
        """Run prediction on text.

        Args:
            text: The text to score.

        Returns:
            A dict with keys ``prediction`` (1 or 0), ``probability`` (sigmoid
            of the logit, rounded to 4 places), ``label`` (``"good"`` or
            ``"average to low"``) and ``confidence`` (distance of the
            probability from the threshold, scaled by the larger side of the
            threshold, rounded to 4 places). All four values are ``None`` when
            ``text`` is empty, whitespace-only, or not a string.
        """
        if not isinstance(text, str) or not text.strip():
            return {"prediction": None, "probability": None, "label": None, "confidence": None}

        # NOTE(review): padding a single text to max_length looks unnecessary
        # (padding normally matters for batching); kept to preserve existing
        # behaviour - confirm before changing.
        inputs = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze(-1)
        prob = torch.sigmoid(logits).cpu().item()

        pred = 1 if prob >= self.threshold else 0
        label = "good" if pred == 1 else "average to low"
        # Normalise by the wider side of the threshold so the value tops out
        # at 1.0 on that side.
        confidence = abs(prob - self.threshold) / max(self.threshold, 1 - self.threshold)
        return {
            "prediction": pred,
            "probability": round(prob, 4),
            "label": label,
            "confidence": round(confidence, 4),
        }
def _get_predictor() -> Optional[CommercialPredictor]:
    """Return the shared predictor, building it on first use.

    A load is attempted only when there is neither a cached instance nor a
    recorded failure. On failure the exception text is stored in
    ``_load_error`` and ``None`` is returned from then on.
    """
    global _predictor, _load_error

    needs_load = _predictor is None and _load_error is None
    if needs_load:
        try:
            _predictor = CommercialPredictor(MODEL_DIR)
        except Exception as exc:
            # Remember the failure so later calls do not retry the load.
            _load_error = str(exc)
            logger.warning(f"Could not load model: {_load_error}")

    # Still None when a load failed now or on an earlier call.
    return _predictor
def model_is_available() -> bool:
    """Report whether the predictor could be loaded (triggers the lazy load)."""
    loaded = _get_predictor()
    return loaded is not None
def get_load_error() -> Optional[str]:
    """Return the stored load-failure message, or ``None`` if there is none.

    The lazy load is triggered first, so the answer reflects an actual load
    attempt rather than the initial module state.
    """
    _ = _get_predictor()  # called only for its side effect on _load_error
    message = _load_error
    return message
def predict_potential(text: str, meta: Optional[dict] = None) -> Dict[str, Any]:
    """
    Public entry point for scoring a manuscript.

    The returned dict has the keys prediction, probability, label, confidence
    and source. ``source`` is "model" when the predictor produced the result
    and "unavailable" when the predictor could not be loaded, in which case
    the other four values are None. ``meta`` is accepted but not used here.
    """
    model = _get_predictor()
    if model is not None:
        scored = model.predict(text)
        scored["source"] = "model"
        return scored

    # No model: same keys in the same order, all scoring fields empty.
    fallback: Dict[str, Any] = dict.fromkeys(
        ("prediction", "probability", "label", "confidence")
    )
    fallback["source"] = "unavailable"
    return fallback
# Names exported by `from ... import *`; the predictor class and the singleton
# state are not part of this list.
__all__ = ["predict_potential", "model_is_available", "get_load_error"]