Spaces:
Sleeping
Sleeping
"""
Model inference module for SlushSense.
Loads the trained DistilBERT model and runs predictions.
"""
| import json | |
| import logging | |
| import torch | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Filesystem layout, resolved from the location of this file.
SRC_DIR = Path(__file__).resolve().parent
BASE_DIR = SRC_DIR.parent
MODELS_DIR = BASE_DIR.joinpath("models")
MODEL_DIR = MODELS_DIR.joinpath("deberta_commercial_top25")

# Inference configuration.
MODEL_NAME = "distilbert-base-uncased"  # tokenizer to fetch when MODEL_DIR has no tokenizer.json
MAX_LEN = 384  # max_length handed to the tokenizer
DEVICE = "cpu"  # torch device string

# Singleton state: the loaded predictor, or the message of the failed load.
_predictor: Optional["CommercialPredictor"] = None
_load_error: Optional[str] = None
class CommercialPredictor:
    """Inference wrapper around the trained sequence classifier.

    Loads a tokenizer, a single-logit classification model and an optional
    decision threshold from a model directory, and turns raw text into a
    binary prediction with a probability and a confidence score.
    """

    def __init__(self, model_dir: Path):
        """Load tokenizer, model and decision threshold from ``model_dir``.

        Args:
            model_dir: Directory containing the saved model. If it also holds
                a ``tokenizer.json`` the tokenizer is loaded from it,
                otherwise the tokenizer named by ``MODEL_NAME`` is used. An
                optional ``threshold.json`` may provide a ``"threshold"`` key.

        Raises:
            TypeError, ValueError: If the stored threshold cannot be converted
                to ``float``.
            Exception: Anything raised by ``transformers`` while loading the
                tokenizer or the model is propagated unchanged.
        """
        # Imported here rather than at module level, so transformers is only
        # needed once a predictor is actually built.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        self.device = torch.device(DEVICE)
        model_dir = Path(model_dir)

        # Load tokenizer
        if (model_dir / "tokenizer.json").exists():
            self.tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Load model (one output logit, passed through a sigmoid in predict())
        self.model = AutoModelForSequenceClassification.from_pretrained(
            str(model_dir), num_labels=1
        )
        self.model.to(self.device)
        self.model.eval()

        # Load threshold; a bad value fails here, at load time, instead of
        # surfacing later as a TypeError inside predict().
        self.threshold = self._read_threshold(model_dir)

        logger.info("Model loaded from %s, threshold=%s", model_dir, self.threshold)

    @staticmethod
    def _read_threshold(model_dir: Path, default: float = 0.5) -> float:
        """Return the decision threshold stored in ``model_dir/threshold.json``.

        Falls back to ``default`` when the file or its ``"threshold"`` key is
        missing. The stored value is coerced with ``float()``, so a numeric
        string is accepted and a non-numeric value raises immediately.
        """
        thr_path = Path(model_dir) / "threshold.json"
        if not thr_path.exists():
            return default
        with open(thr_path, encoding="utf-8") as f:
            payload = json.load(f)
        return float(payload.get("threshold", default))

    def predict(self, text: str) -> Dict[str, Any]:
        """Run prediction on text.

        Args:
            text: Raw input text. Empty or whitespace-only text is not scored.

        Returns:
            Dict with keys ``prediction`` (1 or 0), ``probability`` (sigmoid
            of the model logit, rounded to 4 decimals), ``label`` (``"good"``
            or ``"average to low"``) and ``confidence`` (rounded to 4
            decimals). All four values are ``None`` for empty input.
        """
        if not text or not text.strip():
            return {"prediction": None, "probability": None, "label": None, "confidence": None}

        inputs = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze(-1)
        prob = torch.sigmoid(logits).cpu().item()

        pred = 1 if prob >= self.threshold else 0
        label = "good" if pred == 1 else "average to low"
        # Distance from the threshold, scaled by the larger of the two
        # sides of the threshold (threshold vs. 1 - threshold).
        confidence = abs(prob - self.threshold) / max(self.threshold, 1 - self.threshold)

        return {
            "prediction": pred,
            "probability": round(prob, 4),
            "label": label,
            "confidence": round(confidence, 4),
        }
def _get_predictor() -> Optional[CommercialPredictor]:
    """Return the shared predictor, building it on first use.

    The load is attempted at most once: a successful load is cached in
    ``_predictor`` and a failed one is remembered in ``_load_error``, after
    which every call returns ``None`` without retrying.
    """
    global _predictor, _load_error

    # Only attempt a load when there is neither a cached predictor nor a
    # recorded failure from an earlier attempt.
    if _predictor is None and _load_error is None:
        try:
            _predictor = CommercialPredictor(MODEL_DIR)
        except Exception as exc:
            _load_error = str(exc)
            logger.warning(f"Could not load model: {_load_error}")

    return _predictor
def model_is_available() -> bool:
    """Report whether the predictor could be loaded (triggers the lazy load)."""
    predictor = _get_predictor()
    return predictor is not None
def get_load_error() -> Optional[str]:
    """Return the recorded load-failure message, or ``None`` if there is none."""
    # Called for its side effect: make sure a load has been attempted so
    # that _load_error reflects the outcome.
    _get_predictor()
    return _load_error
def predict_potential(text: str, meta: Optional[dict] = None) -> Dict[str, Any]:
    """
    Public API for scoring manuscripts.

    ``meta`` is accepted but not used by this function.

    Returns dict with: prediction, probability, label, confidence, source.
    ``source`` is "model" when the predictor produced the result and
    "unavailable" (with every other value ``None``) when no model is loaded.
    """
    predictor = _get_predictor()

    if predictor is not None:
        scored = predictor.predict(text)
        scored.update(source="model")
        return scored

    # No model: same keys as a real result, all empty.
    placeholder = dict.fromkeys(("prediction", "probability", "label", "confidence"))
    placeholder["source"] = "unavailable"
    return placeholder
# Public names of this module.
__all__ = [
    "predict_potential",
    "model_is_available",
    "get_load_error",
]