Spaces:
Running
Running
| """ | |
| API REST FastAPI pour les prédictions de rendement agricole. | |
| Cette API charge un modèle MLflow (pipeline sklearn) au démarrage et expose des endpoints | |
| pour effectuer des prédictions de rendement (hg/ha) à partir de variables explicatives. | |
| """ | |
| import asyncio | |
| import os | |
| from contextlib import asynccontextmanager | |
| from typing import Dict, List | |
| import httpx | |
| import joblib | |
| import numpy as np | |
| import logfire | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from config import FEATURE_COLUMNS, TARGET_COLUMN, MODEL_REL_PATH, DATA_REL_PATH, DEFAULT_FRONTEND_URL | |
| # Charger les variables d'environnement depuis .env | |
| load_dotenv() | |
| # ======================================================================================================= | |
| # Configuration Logfire (cloud) | |
| # ======================================================================================================= | |
| logfire.configure( | |
| token=os.environ.get("LOGFIRE_TOKEN"), | |
| service_name="crop-yield-api", | |
| send_to_logfire="if-token-present", | |
| ) | |
| # ======================================================================================================= | |
| # Chemins et variables globales | |
| # ======================================================================================================= | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| # URL du frontend à réveiller toutes les 12h (variable d'env prioritaire, sinon valeur de config.py) | |
| FRONTEND_URL = os.environ.get("FRONTEND_URL", DEFAULT_FRONTEND_URL) | |
| # Chemin vers le modèle MLflow (pipeline sklearn complet avec préprocesseur) | |
| MODEL_PATH = os.path.join(BASE_DIR, MODEL_REL_PATH) | |
| # Chemin vers le fichier d'entraînement nettoyé (pour récupérer les cultures et pays disponibles) | |
| DATA_PATH = os.path.join(BASE_DIR, DATA_REL_PATH) | |
| # Variables globales chargées au démarrage | |
| pipeline = None # Pipeline sklearn (préprocesseur + modèle) | |
| available_items: List[str] = [] # Cultures disponibles | |
| available_areas: List[str] = [] # Pays disponibles | |
| available_items_per_area: Dict[str, List[str]] = {} # Cultures disponibles par pays | |
| # ======================================================================================================= | |
| # Chargement des ressources au démarrage | |
| # ======================================================================================================= | |
| def load_pipeline(): | |
| """Charge le pipeline sklearn depuis le fichier model.pkl.""" | |
| try: | |
| model_file = os.path.join(MODEL_PATH, "model.pkl") | |
| model = joblib.load(model_file) | |
| logfire.info("Pipeline chargé avec succès depuis {path}", path=model_file) | |
| return model | |
| except Exception as e: | |
| logfire.error("Erreur lors du chargement du pipeline: {error}", error=str(e)) | |
| return None | |
| def load_training_data(): | |
| """Charge le CSV d'entraînement pour récupérer les listes de cultures et pays.""" | |
| try: | |
| df = pd.read_csv(DATA_PATH, sep=";", usecols=["Item", "Area"]) | |
| missing_columns = {"Item", "Area"} - set(df.columns) | |
| if missing_columns: | |
| raise ValueError( | |
| f"Colonnes manquantes dans le fichier d'entraînement: {sorted(missing_columns)}" | |
| ) | |
| items = sorted(df["Item"].dropna().unique().tolist()) | |
| areas = sorted(df["Area"].dropna().unique().tolist()) | |
| # Construire le mapping cultures par pays | |
| items_area_map = {} | |
| for area in areas: | |
| area_items = sorted(df[df["Area"] == area]["Item"].dropna().unique().tolist()) | |
| items_area_map[area] = area_items | |
| logfire.info("Données d'entraînement chargées: {n_items} cultures, {n_areas} pays", n_items=len(items), n_areas=len(areas)) | |
| return items, areas, items_area_map | |
| except FileNotFoundError: | |
| logfire.error("Fichier de données introuvable: {path}", path=DATA_PATH) | |
| return [], [], {} | |
| except Exception as e: | |
| logfire.error("Erreur lors du chargement des données: {error}", error=str(e)) | |
| return [], [], {} | |
| async def lifespan(app: FastAPI): | |
| """Chargement des ressources au démarrage de l'API.""" | |
| global pipeline, available_items, available_areas, available_items_per_area | |
| pipeline = load_pipeline() | |
| available_items, available_areas, available_items_per_area = load_training_data() | |
| logfire.info("API initialisée et prête") | |
| # Lancement de la tâche keep-alive en arrière-plan | |
| task = asyncio.create_task(keep_alive_task()) | |
| yield | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| logfire.info("API arrêtée") | |
| async def keep_alive_task(): | |
| """ | |
| Tâche de keep-alive : ping le frontend toutes les 12h | |
| afin de le maintenir actif sur les plateformes d'hébergement avec mise en veille. | |
| """ | |
| INTERVAL = 12 * 3600 # 12 heures en secondes | |
| await asyncio.sleep(60) # Délai initial pour laisser le temps au démarrage complet | |
| while True: | |
| if FRONTEND_URL: | |
| async with httpx.AsyncClient() as client: | |
| try: | |
| resp = await client.get(FRONTEND_URL, timeout=15) | |
| logfire.info("Keep-alive frontend: réponse {status}", status=resp.status_code) | |
| except Exception as e: | |
| logfire.warning("Keep-alive frontend échoué: {error}", error=str(e)) | |
| else: | |
| logfire.warning("Keep-alive: FRONTEND_URL non défini, aucun ping envoyé.") | |
| await asyncio.sleep(INTERVAL) | |
| # ======================================================================================================= | |
| # Initialisation de l'application FastAPI + instrumentation Logfire | |
| # ======================================================================================================= | |
| app = FastAPI( | |
| title="API de Prédiction de Rendement Agricole", | |
| description="API pour prédire le rendement agricole (hg/ha) en fonction de la culture, du pays, de l'année et de variables climatiques.", | |
| version="1.0.0", | |
| lifespan=lifespan, | |
| ) | |
| logfire.instrument_fastapi(app) | |
| # ======================================================================================================= | |
| # Modèles Pydantic pour la validation des entrées / sorties | |
| # ======================================================================================================= | |
| class PredictInput(BaseModel): | |
| """Données d'entrée pour une prédiction de rendement avec culture spécifiée.""" | |
| Area: str = Field(..., description="Pays (ex: 'France', 'Albania')") | |
| Item: str = Field(..., description="Culture (ex: 'Wheat', 'Maize')") | |
| Year: int = Field(..., ge=1990, le=2040, description="Année") | |
| average_rain_fall_mm_per_year: float = Field(..., ge=40, le=4000, description="Précipitations moyennes annuelles (mm)") | |
| pesticides_tonnes: float = Field(..., ge=0, le=400000, description="Quantité de pesticides utilisés (tonnes)") | |
| avg_temp: float = Field(..., gt=0, le=35, description="Température moyenne (°C)") | |
| class PredictionOutput(BaseModel): | |
| """Résultat d'une prédiction de rendement.""" | |
| Area: str = Field(..., description="Pays") | |
| Item: str = Field(..., description="Culture") | |
| Year: int = Field(..., description="Année") | |
| predicted_yield: float = Field(..., description="Rendement prédit (hg/ha)") | |
| class RecommendInput(BaseModel): | |
| """Données d'entrée pour la recommandation de cultures (sans Item).""" | |
| Area: str = Field(..., description="Pays (ex: 'France', 'Albania')") | |
| Year: int = Field(..., ge=1990, le=2040, description="Année") | |
| average_rain_fall_mm_per_year: float = Field(..., ge=40, le=4000, description="Précipitations moyennes annuelles (mm)") | |
| pesticides_tonnes: float = Field(..., ge=0, le=400000, description="Quantité de pesticides utilisés (tonnes)") | |
| avg_temp: float = Field(..., gt=0, le=35, description="Température moyenne (°C)") | |
| class RecommendOutput(BaseModel): | |
| """Résultat de la recommandation : prédictions pour toutes les cultures, triées par rendement décroissant.""" | |
| area: str | |
| year: int | |
| recommendations: List[PredictionOutput] | |
| status: str = "success" | |
| # ======================================================================================================= | |
| # Endpoints | |
| # ======================================================================================================= | |
| def health_check(): | |
| """Vérification de l'état de santé de l'API.""" | |
| logfire.info("Health check") | |
| return { | |
| "status": "ok", | |
| "model_loaded": pipeline is not None, | |
| "available_items": len(available_items), | |
| "available_areas": len(available_areas), | |
| } | |
| def get_columns(): | |
| """Retourne la liste des colonnes (features) attendues par le modèle.""" | |
| logfire.info("Colonnes demandées") | |
| return { | |
| "columns": FEATURE_COLUMNS, | |
| "target": TARGET_COLUMN, | |
| "available_items": available_items, | |
| "available_areas": available_areas, | |
| "available_items_per_area": available_items_per_area, | |
| } | |
| def predict(input_data: PredictInput): | |
| """ | |
| Prédiction de rendement pour une culture et un ensemble de variables explicatives. | |
| """ | |
| logfire.info( | |
| "Requête /predict reçue: {area} / {item} / {year}", | |
| area=input_data.Area, | |
| item=input_data.Item, | |
| year=input_data.Year, | |
| ) | |
| if pipeline is None: | |
| logfire.error("Pipeline non chargé") | |
| raise HTTPException(status_code=503, detail="Modèle non chargé. Réessayez plus tard.") | |
| # Validation métier : vérifier que le pays est connu | |
| if input_data.Area not in available_areas: | |
| raise HTTPException( | |
| status_code=422, | |
| detail=f"Pays '{input_data.Area}' inconnu. Consultez GET /columns pour la liste des pays.", | |
| ) | |
| # Validation métier : vérifier que la culture est produite dans le pays sélectionné | |
| if input_data.Item not in available_items_per_area.get(input_data.Area, []): | |
| available_for_area = available_items_per_area.get(input_data.Area, []) | |
| raise HTTPException( | |
| status_code=422, | |
| detail=f"Culture '{input_data.Item}' non produite dans '{input_data.Area}'. Cultures disponibles pour ce pays : {available_for_area}", | |
| ) | |
| try: | |
| df = pd.DataFrame([input_data.model_dump()]) | |
| # Appliquer la transformation log1p sur pesticides_tonnes (cohérence avec l'EDA / entraînement) | |
| df["pesticides_tonnes"] = np.log1p(df["pesticides_tonnes"]) | |
| prediction = pipeline.predict(df)[0] | |
| logfire.info("Prédiction effectuée: {yield_pred:.2f} hg/ha", yield_pred=float(prediction)) | |
| return PredictionOutput( | |
| Area=input_data.Area, | |
| Item=input_data.Item, | |
| Year=input_data.Year, | |
| predicted_yield=round(float(prediction), 2), | |
| ) | |
| except Exception as e: | |
| logfire.error("Erreur lors de la prédiction: {error}", error=str(e)) | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| def recommend(input_data: RecommendInput): | |
| """ | |
| Recommandation de cultures : prédit le rendement pour les cultures disponibles | |
| dans le pays sélectionné et les renvoie classées par rendement décroissant. | |
| """ | |
| logfire.info( | |
| "Requête /recommend reçue: {area} / {year}", | |
| area=input_data.Area, | |
| year=input_data.Year, | |
| ) | |
| if pipeline is None: | |
| logfire.error("Pipeline non chargé") | |
| raise HTTPException(status_code=503, detail="Modèle non chargé. Réessayez plus tard.") | |
| if input_data.Area not in available_areas: | |
| raise HTTPException( | |
| status_code=422, | |
| detail=f"Pays '{input_data.Area}' inconnu. Consultez GET /columns pour la liste des pays.", | |
| ) | |
| area_items = available_items_per_area.get(input_data.Area, []) | |
| if not area_items: | |
| raise HTTPException( | |
| status_code=422, | |
| detail=f"Aucune culture disponible pour le pays '{input_data.Area}'.", | |
| ) | |
| try: | |
| # Construire un DataFrame avec une ligne par culture disponible dans le pays | |
| base_data = input_data.model_dump() | |
| rows = [] | |
| for item in area_items: | |
| row = {**base_data, "Item": item} | |
| rows.append(row) | |
| df = pd.DataFrame(rows) | |
| # Appliquer la transformation log1p sur pesticides_tonnes (cohérence avec l'EDA / entraînement) | |
| df["pesticides_tonnes"] = np.log1p(df["pesticides_tonnes"]) | |
| predictions = pipeline.predict(df) | |
| # Associer chaque culture à sa prédiction et trier par rendement décroissant | |
| results = [ | |
| PredictionOutput(Area=input_data.Area, Item=item, Year=input_data.Year, predicted_yield=round(float(pred), 2)) | |
| for item, pred in zip(area_items, predictions) | |
| ] | |
| results.sort(key=lambda x: x.predicted_yield, reverse=True) | |
| logfire.info( | |
| "Recommandation effectuée: meilleure culture = {best} ({yield_pred:.2f} hg/ha)", | |
| best=results[0].Item, | |
| yield_pred=results[0].predicted_yield, | |
| ) | |
| return RecommendOutput(area=input_data.Area, year=input_data.Year, recommendations=results) | |
| except Exception as e: | |
| logfire.error("Erreur lors de la recommandation: {error}", error=str(e)) | |
| raise HTTPException(status_code=400, detail=str(e)) | |