# lovebird25/ml_service.py — initial commit 75146bf (author: Paul)
import pickle
import os
import gdown
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from typing import List, Dict, Any
class MLInferenceService:
    """Service that downloads, loads, and serves a multi-label text classifier.

    The model directory is expected to contain a Hugging Face
    ``AutoModelForSequenceClassification`` checkpoint plus an ``mlb.pkl``
    (pickled ``MultiLabelBinarizer``) whose ``classes_`` give the label names.
    """

    def __init__(self, model_dir: str = "./model", gdrive_file_id: "str | None" = None):
        """
        Args:
            model_dir: Directory holding the model files (created if needed).
            gdrive_file_id: Google Drive file ID for ``model.safetensors``;
                falls back to the ``GDRIVE_MODEL_ID`` environment variable.
        """
        self.model_dir = model_dir
        self.gdrive_file_id = gdrive_file_id or os.getenv("GDRIVE_MODEL_ID")
        self.model = None       # AutoModelForSequenceClassification once loaded
        self.tokenizer = None   # AutoTokenizer once loaded
        self.clf = None         # transformers text-classification pipeline
        self.label_names: List[str] = []  # label names in MultiLabelBinarizer order

    def load_model(self):
        """Load the model, tokenizer, and label names; no-op if already loaded."""
        if self.model is not None:
            return
        # Fetch the model weights from Google Drive first when an ID is configured.
        if self.gdrive_file_id:
            self._download_from_gdrive()
        # Load model and tokenizer from the local directory.
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        # NOTE(review): pickle.load is unsafe on untrusted files — mlb.pkl is
        # assumed to ship from our own training run alongside the model weights.
        with open(os.path.join(self.model_dir, "mlb.pkl"), "rb") as f:
            mlb = pickle.load(f)
        self.label_names = list(mlb.classes_)
        # return_all_scores=True keeps scores in label-index order, which the
        # positional mapping in predict() relies on. (The newer top_k=None
        # replacement sorts results by score and would break that mapping —
        # do not swap it in without also changing predict().)
        self.clf = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
        )

    def _download_from_gdrive(self):
        """Download model.safetensors from Google Drive unless already present.

        Raises:
            RuntimeError: If gdown reports a failed download.
        """
        model_path = os.path.join(self.model_dir, "model.safetensors")
        # Skip download if the weights file already exists locally.
        if os.path.exists(model_path):
            return
        os.makedirs(self.model_dir, exist_ok=True)
        print("Downloading model from Google Drive...")
        gdrive_url = f"https://drive.google.com/uc?id={self.gdrive_file_id}"
        # gdown.download returns the output path on success and None on failure;
        # fail loudly here instead of letting from_pretrained() crash later.
        if gdown.download(gdrive_url, model_path, quiet=False) is None:
            raise RuntimeError(f"Failed to download model from {gdrive_url}")
        print(f"Model downloaded successfully to {model_path}")

    def predict(self, text: str, threshold: float = 0.5) -> List[Dict[str, Any]]:
        """
        Predict labels for the given text.

        Args:
            text: Input text to classify; '|||' separators are converted
                to the '[SEP]' token before inference.
            threshold: Minimum score for a label to be included
                (default 0.5, matching the original behavior).

        Returns:
            List of dictionaries with 'label' and 'score' keys, in
            label-index order.

        Raises:
            RuntimeError: If load_model() has not been called yet.
        """
        if self.clf is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        # Process text: replace ||| with [SEP]
        processed_text = text.replace('|||', '[SEP]')
        # The pipeline returns one list of per-label score dicts per input;
        # scores arrive in label-index order (see load_model), so the index
        # maps directly into label_names.
        result = self.clf(processed_text)
        return [
            {'label': self.label_names[i], 'score': item['score']}
            for i, item in enumerate(result[0])
            if item['score'] >= threshold
        ]
# Global singleton instance (created lazily by get_ml_service).
_ml_service = None


def get_ml_service() -> MLInferenceService:
    """Get or create the global ML service instance.

    Returns:
        The shared, fully-loaded MLInferenceService.

    The singleton is only published after load_model() succeeds; if loading
    raises, _ml_service stays None so the next call retries instead of
    returning a broken, half-initialized instance (the previous behavior).
    """
    global _ml_service
    if _ml_service is None:
        service = MLInferenceService()
        service.load_model()
        _ml_service = service
    return _ml_service