Spaces:
Paused
Paused
import os
import pickle
from typing import Any, Dict, List, Optional

import gdown
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
class MLInferenceService:
    """Service for loading and running ML model inference.

    On first load, optionally downloads model weights from Google Drive,
    then serves multi-label text-classification predictions through a
    Hugging Face ``text-classification`` pipeline.
    """

    def __init__(self, model_dir: str = "./model", gdrive_file_id: Optional[str] = None):
        """
        Args:
            model_dir: Directory holding the model, tokenizer, and mlb.pkl.
            gdrive_file_id: Google Drive file ID for model.safetensors; falls
                back to the GDRIVE_MODEL_ID environment variable when omitted.
        """
        self.model_dir = model_dir
        self.gdrive_file_id = gdrive_file_id or os.getenv("GDRIVE_MODEL_ID")
        self.model = None       # set by load_model()
        self.tokenizer = None   # set by load_model()
        self.clf = None         # text-classification pipeline, set by load_model()
        self.label_names = []   # class names from the fitted MultiLabelBinarizer

    def load_model(self):
        """Load the model, tokenizer, and label names. Idempotent.

        Raises:
            RuntimeError: if a configured Google Drive download fails.
        """
        if self.model is not None:
            return  # already loaded

        # If a Google Drive ID is configured, fetch the weights first.
        if self.gdrive_file_id:
            self._download_from_gdrive()

        # Load model and tokenizer from the (possibly just-populated) directory.
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

        # Recover human-readable label names from the MultiLabelBinarizer.
        # SECURITY NOTE: pickle.load executes arbitrary code from the file --
        # only point model_dir at artifacts from a trusted source.
        with open(os.path.join(self.model_dir, "mlb.pkl"), "rb") as f:
            mlb = pickle.load(f)
        self.label_names = list(mlb.classes_)

        # NOTE(review): return_all_scores is deprecated in newer transformers
        # in favor of top_k=None, but top_k changes the result nesting and
        # would break predict()'s result[0] indexing -- keep the legacy flag.
        self.clf = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
        )

    def _download_from_gdrive(self):
        """Download model.safetensors from Google Drive if not already present.

        Raises:
            RuntimeError: if gdown fails to download the file.
        """
        model_path = os.path.join(self.model_dir, "model.safetensors")
        # Skip download if the file is cached from a previous run.
        if os.path.exists(model_path):
            return

        os.makedirs(self.model_dir, exist_ok=True)
        print("Downloading model from Google Drive...")
        gdrive_url = f"https://drive.google.com/uc?id={self.gdrive_file_id}"
        # gdown.download returns the output path on success and None on
        # failure (e.g. permission/quota errors) -- fail loudly instead of
        # letting from_pretrained() crash later with a confusing error.
        if gdown.download(gdrive_url, model_path, quiet=False) is None:
            raise RuntimeError(
                f"Failed to download model from Google Drive (id={self.gdrive_file_id})"
            )
        print(f"Model downloaded successfully to {model_path}")

    def predict(self, text: str, threshold: float = 0.5) -> List[Dict[str, Any]]:
        """
        Predict labels for the given text.

        Args:
            text: Input text to classify; occurrences of '|||' are mapped to
                the tokenizer's [SEP] token before inference.
            threshold: Minimum score for a label to be returned (default 0.5,
                matching the original hard-coded cutoff).

        Returns:
            List of {'label': name, 'score': float} dicts for every label
            whose score meets the threshold.

        Raises:
            RuntimeError: if load_model() has not been called yet.
        """
        if self.clf is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Training joined multi-part inputs with '|||'; mirror that here by
        # substituting the [SEP] token.
        processed_text = text.replace('|||', '[SEP]')

        result = self.clf(processed_text)

        # result[0] holds one score entry per class, in the same order as the
        # MultiLabelBinarizer classes -- map positional index to label name.
        return [
            {'label': self.label_names[i], 'score': item['score']}
            for i, item in enumerate(result[0])
            if item['score'] >= threshold
        ]
# Global singleton instance, created lazily by get_ml_service().
_ml_service = None


def get_ml_service() -> MLInferenceService:
    """Get or create the global ML service instance.

    The instance is published to the module-level cache only after
    load_model() succeeds; if loading raises, the cache stays empty and the
    next call retries instead of returning a half-initialized service.
    """
    global _ml_service
    if _ml_service is None:
        service = MLInferenceService()
        service.load_model()
        # Publish only after a successful load.
        _ml_service = service
    return _ml_service