""" API ultra-légère pour la prédiction de tweets de catastrophe. Charge le modèle localement via transformers. Idéal pour Hugging Face Spaces (qui offre ~16Go de RAM). """ import os import re from contextlib import asynccontextmanager from functools import lru_cache from typing import Dict, List, Optional import emoji from fastapi import FastAPI, HTTPException from pydantic import BaseModel try: from transformers import pipeline except ImportError: pipeline = None # --- CONFIGURATION --- try: from pathlib import Path from dotenv import load_dotenv _env_path = Path(__file__).parent / ".env" load_dotenv(dotenv_path=_env_path) except ImportError: pass HF_MODEL_ID = os.getenv("HF_MODEL_ID", "Oscarkaf/disaster-tweets-bert") @asynccontextmanager async def lifespan(app: FastAPI): """Précharge le modèle au démarrage de l'API pour éviter la latence à la première requête.""" print("Démarrage de l'API, tentative de préchargement du modèle...") get_classifier() yield app = FastAPI( title="Disaster Tweet BERT API (Local HF Space)", description="API chargeant le modèle localement avec transformers.", version="3.0.0", lifespan=lifespan, ) # --- MODEL LOADING --- _classifier_pipeline = None _model_load_error = None def get_classifier(): """Charge paresseusement le modèle avec transformers.""" global _classifier_pipeline, _model_load_error if _classifier_pipeline is not None: return _classifier_pipeline if pipeline is None: _model_load_error = "La bibliothèque transformers n'est pas installée." return None try: from transformers import AutoTokenizer # Le tokenizer sauvegardé avec le modèle peut être corrompu ou mal configuré. # On force l'utilisation du tokenizer officiel de BERTweet avec la normalisation activée. print("Loading official tokenizer: vinai/bertweet-base...") tokenizer = AutoTokenizer.from_pretrained( "vinai/bertweet-base", normalization=True ) # On charge VOTRE modèle (cerveau) mais avec le TOKENIZER officiel (lunettes) print(f"Loading your model {HF_MODEL_ID} locally...") _classifier_pipeline = pipeline( "text-classification", model=HF_MODEL_ID, tokenizer=tokenizer ) return _classifier_pipeline except Exception as e: _model_load_error = str(e) print(f"Error loading model: {_model_load_error}") return None # --- TRANSLATION UTILS (LRU CACHE) --- try: from deep_translator import GoogleTranslator except ImportError: GoogleTranslator = None @lru_cache(maxsize=128) def translate_text(text: str) -> Dict[str, str]: res = {"translated_text": text, "detected_lang": "auto", "is_translated": False} if not text or not GoogleTranslator: return res try: translated = GoogleTranslator(source="auto", target="en").translate(text) if translated and translated.strip().lower() != text.strip().lower(): res["translated_text"] = translated.strip() res["is_translated"] = True except Exception as e: print(f"Translation exception: {e}") return res # --- TEXT CLEANING --- def clean_text_advanced(text: str) -> str: text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE) text = re.sub(r"\@\w+", "[USER]", text) text = emoji.demojize(text) text = text.replace(":", " ").replace("_", " ") text = re.sub(r"\s+", " ", text).strip() return text # --- SCHEMAS --- class TweetInput(BaseModel): text: str location: Optional[str] = None keyword: Optional[str] = None class PredictionOutput(BaseModel): is_disaster: bool confidence: float clean_text: str model_name: str impact_words: Dict[str, float] detected_lang: str translated_text: str # --- CORE LOGIC --- def extract_disaster_confidence_from_pipeline(preds) -> 
# --- SCHEMAS ---
class TweetInput(BaseModel):
    text: str
    location: Optional[str] = None
    keyword: Optional[str] = None


class PredictionOutput(BaseModel):
    is_disaster: bool
    confidence: float
    clean_text: str
    model_name: str
    impact_words: Dict[str, float]
    detected_lang: str
    translated_text: str


# --- CORE LOGIC ---
def extract_disaster_confidence_from_pipeline(preds) -> float:
    """Extract the disaster probability from the pipeline's results."""
    # The text-classification pipeline usually returns [{'label': '...', 'score': 0.9}];
    # a multi-label model can return something more complex, so standardize here.
    positive_labels = {"LABEL_1", "POSITIVE", "DISASTER", "1"}
    negative_labels = {"LABEL_0", "NEGATIVE", "NOT_DISASTER", "0"}
    # Single-element case
    if isinstance(preds, list) and len(preds) > 0 and isinstance(preds[0], dict):
        label = str(preds[0].get("label", "")).upper()
        score = float(preds[0].get("score", 0.0))
        if label in positive_labels:
            return score
        elif label in negative_labels:
            return 1.0 - score
        else:
            # Unrecognized label: return the raw score
            return score
    return 0.5


def query_model(text: str) -> tuple:
    """Run the local model on one text; returns (confidence, error_message)."""
    classifier = get_classifier()
    if not classifier:
        return None, _model_load_error or "Model not initialized."
    try:
        preds = classifier(text)
        return extract_disaster_confidence_from_pipeline(preds), None
    except Exception as exc:
        return None, str(exc)


def query_model_batch(texts: List[str]) -> List[float]:
    """Predict over a batch of texts."""
    if not texts:
        return []
    classifier = get_classifier()
    if not classifier:
        return [0.0] * len(texts)
    try:
        preds_list = classifier(texts)
        # preds_list is a list of {'label': ..., 'score': ...} dicts
        return [extract_disaster_confidence_from_pipeline([p]) for p in preds_list]
    except Exception:
        # Fall back to sequential prediction
        results = []
        for t in texts:
            conf, _err = query_model(t)
            results.append(conf if conf is not None else 0.0)
        return results


def heuristic_prediction(text: str) -> float:
    """Keyword-based fallback used when the model is unavailable."""
    disaster_terms = {
        "earthquake", "flood", "wildfire", "fire", "hurricane", "evacuation",
        "disaster", "collapsed", "injured", "dead", "tsunami", "explosion",
        "rescue", "storm", "collision", "crash", "emergency", "alert",
    }
    words = set(re.findall(r"\w+", text.lower()))
    matches = words.intersection(disaster_terms)
    if matches:
        return min(0.9, 0.4 + (len(matches) * 0.15))
    return 0.15


def explain_prediction(text: str, base_confidence: float) -> Dict[str, float]:
    """Leave-one-out word importance: re-score the text with each word removed."""
    words = text.split()
    if not words:
        return {}
    # Only ablate the first 10 words, for speed. Each variation keeps the rest
    # of the text intact so the only difference from the original is the single
    # removed word (otherwise the impact scores would also measure truncation).
    words_to_test = words[:10]
    variations = []
    for i in range(len(words_to_test)):
        ablated = " ".join(words[:i] + words[i + 1:])
        variations.append(ablated if ablated.strip() else "[EMPTY]")
    ablated_confidences = query_model_batch(variations)
    impacts = {}
    for i, word in enumerate(words_to_test):
        conf = (
            ablated_confidences[i] if i < len(ablated_confidences) else base_confidence
        )
        impacts[word] = round(base_confidence - conf, 4)
    return dict(sorted(impacts.items(), key=lambda x: abs(x[1]), reverse=True))
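
# Illustrative only (scores invented): for a cleaned tweet like
# "forest fire near la ronge sask canada", explain_prediction might return
# {"fire": 0.31, "forest": 0.07, "near": -0.01, ...}. A positive impact means
# removing the word lowers the disaster confidence, i.e. the word pushes the
# prediction toward "disaster".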
# --- ENDPOINTS ---
@app.get("/")
def home():
    return {
        "message": "BERT API (Local Transformers) v3.0.0 running. See /docs to test."
    }


@app.get("/health")
def health():
    classifier = get_classifier()
    status = "ok" if classifier is not None else "error"
    return {
        "status": status,
        "mode": "local_transformers",
        "model_loaded": classifier is not None,
        "model_name": "BERTweet (Local via Transformers)",
        "model_error": _model_load_error,
        "hf_model_id": HF_MODEL_ID,
    }


@app.post("/predict", response_model=PredictionOutput)
def predict_tweet(tweet: TweetInput):
    # 1. Translation
    trans_res = translate_text(tweet.text)
    work_text = trans_res["translated_text"]

    # 2. Cleaning
    cleaned_text = clean_text_advanced(work_text)

    # 3. Empty text -> 400 error
    if not cleaned_text:
        raise HTTPException(status_code=400, detail="Tweet text cannot be empty.")

    # 4. Prediction via the local pipeline, or heuristic fallback
    confidence, error = query_model(cleaned_text)
    if error:
        confidence = heuristic_prediction(cleaned_text)
        model_used = f"Heuristic Fallback ({error[:80]})"
        impact_words: Dict[str, float] = {}
    else:
        model_used = "BERTweet (Local via Transformers)"
        # 5. Explainability (word importance)
        impact_words = explain_prediction(cleaned_text, confidence)

    return PredictionOutput(
        is_disaster=confidence >= 0.5,
        confidence=confidence,
        clean_text=cleaned_text,
        model_name=model_used,
        impact_words=impact_words,
        detected_lang=trans_res["detected_lang"],
        translated_text=work_text,
    )
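

# Local development entry point: a minimal sketch assuming uvicorn is installed.
# On Hugging Face Spaces the platform usually launches the server itself; 7860
# is the port Spaces conventionally expects.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))

# Example request once the server is running (the example tweet is arbitrary):
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Forest fire near La Ronge Sask. Canada"}'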