| """Shared model loading and inference for the API, Gradio and other clients.""" |
|
|
| import os |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
# Repository root: the directory two levels above this (resolved) file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default local model directory; presumably produced by scripts/train.py
# (the FileNotFoundError raised in SentimentAnalyzer.load points users there).
DEFAULT_MODEL_ID = str(PROJECT_ROOT.joinpath("model", "sentiment_model"))
# Maximum number of tokens kept per input; longer texts are truncated.
MAX_LENGTH = 256
|
|
|
|
| class SentimentAnalyzer: |
| """Load the fine-tuned model once and expose a reusable prediction method.""" |
|
|
| def __init__(self, model_id: str | None = None, device: str | None = None): |
| self.model_id = model_id or os.getenv("MODEL_ID", DEFAULT_MODEL_ID) |
| self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") |
| self.tokenizer = None |
| self.model = None |
|
|
| @property |
| def is_loaded(self) -> bool: |
| return self.model is not None and self.tokenizer is not None |
|
|
| def load(self) -> "SentimentAnalyzer": |
| if self.is_loaded: |
| return self |
|
|
| model_path = Path(self.model_id) |
| if model_path.is_absolute() and not model_path.exists(): |
| raise FileNotFoundError( |
| f"Modele introuvable dans {model_path}. Lancez d'abord scripts/train.py." |
| ) |
|
|
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) |
| self.model = AutoModelForSequenceClassification.from_pretrained( |
| self.model_id |
| ).to(self.device) |
| self.model.eval() |
| return self |
|
|
| def unload(self) -> None: |
| self.model = None |
| self.tokenizer = None |
|
|
| def predict(self, text: str, include_debug: bool = False) -> dict: |
| if not self.is_loaded: |
| raise RuntimeError("Le modele doit etre charge avant la prediction.") |
|
|
| inputs = self.tokenizer( |
| text, |
| return_tensors="pt", |
| truncation=True, |
| max_length=MAX_LENGTH, |
| ).to(self.device) |
|
|
| with torch.no_grad(): |
| logits = self.model(**inputs).logits[0] |
|
|
| probabilities = torch.softmax(logits, dim=-1) |
| predicted_class_id = int(torch.argmax(probabilities)) |
|
|
| result = { |
| "label": self.model.config.id2label[predicted_class_id], |
| "confidence": round(float(probabilities[predicted_class_id]), 4), |
| "probabilities": { |
| self.model.config.id2label[i]: round(float(probability), 4) |
| for i, probability in enumerate(probabilities) |
| }, |
| } |
| if include_debug: |
| result["debug"] = { |
| "tokens": self.tokenizer.convert_ids_to_tokens( |
| inputs["input_ids"][0].detach().cpu().tolist() |
| ), |
| "logits": [round(float(logit), 4) for logit in logits], |
| } |
| return result |
|
|