import os

# Make sure the HF cache is writable and does not point at /data.
# These must be set before transformers/huggingface_hub are imported,
# since both read the cache location at import time.
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf_home")
# Drop the deprecated TRANSFORMERS_CACHE, which may point at /data
os.environ.pop("TRANSFORMERS_CACHE", None)
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")

import io
import time
import pickle
import tempfile
import logging
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from huggingface_hub import snapshot_download

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

APP_START_TS = time.time()

# Model configuration
MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-3n-E4B-it")
DEVICE_MAP = os.environ.get("DEVICE_MAP", "cpu")  # Force CPU on Hugging Face Spaces
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "256"))

# Cache file used to share model state between Streamlit and FastAPI
MODEL_CACHE_FILE = os.path.join(tempfile.gettempdir(), "agrilens_model_cache.pkl")
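
# Example: override the configuration at deploy time (hypothetical values):
#   MODEL_ID=google/gemma-3n-E2B-it MAX_NEW_TOKENS=128 DEVICE_MAP=cpu python app.py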

def _get_dtype() -> torch.dtype:
    """Pick the dtype for the target hardware."""
    # Force float32: Hugging Face Spaces free tier is CPU-only
    return torch.float32

def _build_prompt(culture: Optional[str], notes: Optional[str]) -> str:
    """Build the analysis prompt."""
    base = (
        "You are an agronomy assistant. Analyze the provided plant leaf image and identify the most likely disease. "
        "Return a concise diagnosis in French with: disease name, short explanation of symptoms, "
        "and 3 actionable treatment recommendations."
    )
    if culture:
        base += f"\nCulture: {culture}"
    if notes:
        base += f"\nNotes: {notes}"
    return base
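
# For example, _build_prompt("tomato", "yellow spots on lower leaves") yields the
# base instruction followed by "\nCulture: tomato\nNotes: yellow spots on lower leaves".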

class SharedModelManager:
    """Model manager shared between Streamlit and FastAPI."""

    def __init__(self):
        self.model = None
        self.processor = None
        self.device_map = DEVICE_MAP
        self.dtype = _get_dtype()
        self._load_attempted = False
        self._loading = False
        self._load_error = None
        self._last_load_attempt = 0
        self._load_timeout = 300  # 5 minute timeout
        logger.info(f"Initializing ModelManager with device_map={self.device_map}, dtype={self.dtype}")
        # Try to recover from a previous state
        self._recover_state()

    def _recover_state(self):
        """Try to recover model state from disk."""
        try:
            state_file = "/tmp/model_state.json"
            if os.path.exists(state_file):
                import json
                with open(state_file, "r") as f:
                    state = json.load(f)
                # Only use the state if it is recent (less than one hour old)
                if time.time() - state.get("timestamp", 0) < 3600:
                    logger.info("Previous state found, attempting recovery...")
                    # The model objects themselves cannot be restored from JSON,
                    # but the earlier attempt can be recorded
                    self._load_attempted = True
                    self._last_load_attempt = state.get("timestamp", 0)
        except Exception as e:
            logger.warning(f"Could not recover state: {e}")

    def _save_state(self):
        """Save the current state to disk."""
        try:
            state_file = "/tmp/model_state.json"
            import json
            state = {
                "timestamp": time.time(),
                "model_loaded": self.model is not None,
                "processor_loaded": self.processor is not None,
                "load_attempted": self._load_attempted,
                "loading": self._loading,
                "error": self._load_error,
            }
            with open(state_file, "w") as f:
                json.dump(state, f)
        except Exception as e:
            logger.warning(f"Could not save state: {e}")

    def check_streamlit_model_cache(self):
        """Check whether the model is advertised in the Streamlit cache file."""
        try:
            # The cache file must exist and be recent (less than one hour old)
            if os.path.exists(MODEL_CACHE_FILE):
                file_age = time.time() - os.path.getmtime(MODEL_CACHE_FILE)
                if file_age < 3600:
                    try:
                        with open(MODEL_CACHE_FILE, "rb") as f:
                            cache_data = pickle.load(f)
                        logger.info(f"Streamlit cache found: {cache_data}")
                        return True
                    except Exception as e:
                        logger.error(f"Error reading the cache: {e}")
                        return False
            return False
        except Exception as e:
            logger.error(f"Error checking the cache: {e}")
            return False
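
    # Sketch of the producer side (an assumption; that code is not in this file):
    # the Streamlit app would write a small metadata dict rather than the model
    # itself, since live model objects do not pickle reliably:
    #   with open(MODEL_CACHE_FILE, "wb") as f:
    #       pickle.dump({"model_id": MODEL_ID, "loaded_at": time.time()}, f)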

    def load_model_directly(self):
        """Robust model loading that tries multiple approaches to avoid permission issues."""
        try:
            self._loading = True
            self._load_attempted = True
            self._last_load_attempt = time.time()
            self._load_error = None
            # Try the approaches in order of preference
            approaches = [
                ("Direct HF Hub loading", self._try_direct_loading),
                ("Cache in /app/cache", self._try_app_cache),
                ("Cache in /tmp/hf_home", self._try_tmp_cache),
                ("Cache in /tmp/model_repo", self._try_tmp_repo),
            ]
            for approach_name, approach_func in approaches:
                try:
                    logger.info(f"Trying: {approach_name}")
                    success = approach_func()
                    if success:
                        self._loading = False
                        self._save_state()
                        logger.info(f"✅ Success with {approach_name}")
                        return True
                except Exception as e:
                    logger.warning(f"❌ {approach_name} failed: {e}")
                    continue
            # All approaches failed
            self._loading = False
            self._load_error = "All loading approaches failed"
            self._save_state()
            return False
        except Exception as e:
            logger.error(f"Critical loading error: {e}")
            self._loading = False
            self._load_error = str(e)
            self._save_state()
            return False

    def _try_direct_loading(self):
        """Load directly from the Hugging Face Hub, forcing cache_dir to stay off /data."""
        try:
            logger.info("Loading directly from the HF Hub...")
            writable_cache = os.environ.get("HF_HOME", "/tmp/hf_home")
            os.makedirs(writable_cache, exist_ok=True)
            # Load the processor with an explicit cache_dir
            self.processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                cache_dir=writable_cache,
                local_files_only=False,
            )
            logger.info("Processor loaded directly")
            # Load the model with an explicit cache_dir
            self.model = AutoModelForImageTextToText.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                cache_dir=writable_cache,
                local_files_only=False,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded directly from the HF Hub")
            return True
        except Exception as e:
            logger.error(f"Direct loading failed: {e}")
            return False

    def _try_app_cache(self):
        """Download a snapshot into /app/cache and load from there."""
        try:
            cache_dir = "/app/cache/huggingface"
            os.makedirs(cache_dir, exist_ok=True)
            logger.info(f"Snapshot into {cache_dir}")
            snapshot_download(
                repo_id=MODEL_ID,
                local_dir=cache_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                token=os.environ.get("HF_TOKEN", None),
            )
            # Load from the local snapshot
            self.processor = AutoProcessor.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
            )
            logger.info("Processor loaded from /app/cache")
            self.model = AutoModelForImageTextToText.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded from /app/cache")
            return True
        except Exception as e:
            logger.error(f"/app cache failed: {e}")
            return False

    def _try_tmp_cache(self):
        """Download a snapshot into /tmp/hf_home and load from there."""
        try:
            cache_dir = "/tmp/hf_home"
            os.makedirs(cache_dir, exist_ok=True)
            logger.info(f"Snapshot into {cache_dir}")
            snapshot_download(
                repo_id=MODEL_ID,
                local_dir=cache_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                token=os.environ.get("HF_TOKEN", None),
            )
            # Load from the local snapshot
            self.processor = AutoProcessor.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
            )
            logger.info("Processor loaded from /tmp/hf_home")
            self.model = AutoModelForImageTextToText.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded from /tmp/hf_home")
            return True
        except Exception as e:
            logger.error(f"/tmp/hf_home cache failed: {e}")
            return False

    def _try_tmp_repo(self):
        """Download a snapshot into /tmp/model_repo (the original approach) and load from there."""
        try:
            repo_dir = "/tmp/model_repo"
            offload_dir = "/tmp/model_offload"
            os.makedirs(repo_dir, exist_ok=True)
            os.makedirs(offload_dir, exist_ok=True)
            logger.info(f"Snapshot into {repo_dir}")
            snapshot_download(
                repo_id=MODEL_ID,
                local_dir=repo_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                token=os.environ.get("HF_TOKEN", None),
            )
            # Load from the local snapshot
            self.processor = AutoProcessor.from_pretrained(
                repo_dir,
                trust_remote_code=True,
                local_files_only=True,
            )
            logger.info("Processor loaded from /tmp/model_repo")
            self.model = AutoModelForImageTextToText.from_pretrained(
                repo_dir,
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
                offload_folder=offload_dir,
                # On CPU, max_memory may only name the "cpu" device; a GPU
                # index like 0 is rejected when no GPU is visible
                max_memory={"cpu": "8GB"} if self.device_map == "cpu" else None,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded from /tmp/model_repo")
            return True
        except Exception as e:
            logger.error(f"/tmp/model_repo cache failed: {e}")
            return False

    def load_model_with_retry(self, max_retries=5, delay=60):
        """Load the model, retrying automatically on failure."""
        for attempt in range(max_retries):
            try:
                logger.info(f"Load attempt {attempt + 1}/{max_retries}")
                success = self.load_model_directly()
                if success:
                    return True
                logger.warning(f"Attempt {attempt + 1} failed, waiting {delay}s...")
                if attempt < max_retries - 1:
                    time.sleep(delay)
            except Exception as e:
                logger.error(f"Error on attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(delay)
        logger.error(f"All {max_retries} attempts failed")
        return False

    def ensure_model_loaded(self):
        """Make sure the model is loaded, loading it on demand if needed."""
        if self.model is not None and self.processor is not None:
            return True
        if not self._load_attempted:
            # Load the model directly (triggered on demand)
            return self.load_model_directly()
        return False

    def get_load_status(self):
        """Return the current loading status."""
        return {
            "loaded": self.model is not None and self.processor is not None,
            "loading": self._loading,
            "error": self._load_error,
            "attempted": self._load_attempted,
        }

    def _complete_partial_load(self):
        """Complete a partial load (processor loaded but model missing)."""
        try:
            logger.info("Trying to complete a partial load...")
            if self.processor and not self.model:
                logger.info("Processor available, loading the model only...")
                try:
                    # Reload the weights from MODEL_ID; processors do not
                    # reliably expose the source path of their checkpoint
                    logger.info(f"Loading model from {MODEL_ID}")
                    self.model = AutoModelForImageTextToText.from_pretrained(
                        MODEL_ID,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True,
                        device_map=self.device_map,
                        torch_dtype=self.dtype,
                        offload_folder="/tmp/model_offload",
                        max_memory={"cpu": "8GB"} if self.device_map == "cpu" else None,
                    )
                    if self.device_map == "cpu":
                        self.model = self.model.to("cpu")
                    logger.info("Model load completed successfully!")
                    self._loading = False
                    self._save_state()
                    return True
                except Exception as e:
                    logger.error(f"Completion failed: {e}")
                    # Fall back to a full reload
                    return self.load_model_directly()
            else:
                logger.info("No partial load to complete")
                return False
        except Exception as e:
            logger.error(f"Error while completing the load: {e}")
            return False

# Global model manager instance
model_manager = SharedModelManager()

app = FastAPI(title="AgriLens AI FastAPI", version="1.0.0")

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Non-blocking warmup at startup, using an approach that survives task cancellation
@app.on_event("startup")  # wiring assumed: run the warmup when the server starts
async def _warmup_background():
    """Start the model load in the background without blocking the server."""
    logger.info("Starting model load in the background...")
    try:
        # Use a daemon thread instead of awaiting, so the load cannot be
        # cancelled along with the startup task
        import threading
        thread = threading.Thread(target=model_manager.load_model_directly, daemon=True)
        thread.start()
        logger.info("Loading thread started")
    except Exception as e:
        logger.error(f"Error starting the loading thread: {e}")

# Alternative: also try to load on the first request if not already loaded
@app.middleware("http")  # wiring assumed: run on every request
async def ensure_model_loaded_middleware(request, call_next):
    """Middleware that makes sure the model is loaded, with automatic recovery."""
    try:
        current_time = time.time()
        # Detect partial loads and trigger automatic recovery (rate limited)
        if (model_manager.processor and not model_manager.model and
                not model_manager._loading and
                not hasattr(model_manager, '_middleware_recovery_triggered')):
            logger.info("🔧 Automatic recovery triggered via middleware")
            model_manager._middleware_recovery_triggered = current_time
            # Start recovery in the background
            import threading
            thread = threading.Thread(target=model_manager._complete_partial_load, daemon=True)
            thread.start()
        # Otherwise check whether the model still needs loading (rate limited)
        elif (not model_manager.model and not model_manager._loading and
                not hasattr(model_manager, '_middleware_load_triggered')):
            logger.info("Model not loaded, attempting to load...")
            model_manager._middleware_load_triggered = current_time
            # Start loading in the background
            import threading
            thread = threading.Thread(target=model_manager.load_model_directly, daemon=True)
            thread.start()
        # Clean up stale triggers (older than 5 minutes)
        if hasattr(model_manager, '_middleware_recovery_triggered'):
            if current_time - model_manager._middleware_recovery_triggered > 300:
                delattr(model_manager, '_middleware_recovery_triggered')
        if hasattr(model_manager, '_middleware_load_triggered'):
            if current_time - model_manager._middleware_load_triggered > 300:
                delattr(model_manager, '_middleware_load_triggered')
    except Exception as e:
        logger.error(f"Error in middleware: {e}")
    response = await call_next(request)
    return response

# Background task that keeps trying to load the model until it succeeds
@app.on_event("startup")  # wiring assumed: start alongside the warmup hook
async def _persistent_model_loader():
    """Persistent model loader that keeps retrying until success."""
    import threading

    def _load_loop():
        """Retry loop for loading the model."""
        max_attempts = 5  # maximum attempts before giving up
        attempt_count = 0
        last_attempt_time = 0
        cooldown = 60  # wait 60s between attempts
        loaded = False
        while attempt_count < max_attempts:
            try:
                current_time = time.time()
                if model_manager.model is not None:
                    # Loaded by another path (warmup thread, middleware); done
                    loaded = True
                    break
                # Attempt only when nothing is loading and the cooldown elapsed
                if (not model_manager._loading and
                        current_time - last_attempt_time > cooldown):
                    logger.info(f"Persistent loader: attempt {attempt_count + 1}/{max_attempts}...")
                    last_attempt_time = current_time
                    attempt_count += 1
                    if model_manager.load_model_directly():
                        logger.info("Persistent loader: model loaded successfully!")
                        loaded = True
                        break
                    logger.warning(f"Persistent loader: failure {attempt_count}/{max_attempts}, retrying in {cooldown}s...")
                    time.sleep(cooldown)
                else:
                    # A load is in progress; check again shortly
                    time.sleep(10)
            except Exception as e:
                logger.error(f"Persistent loader: error: {e}")
                attempt_count += 1
                time.sleep(cooldown)
        if loaded:
            logger.info("Persistent loader: finished successfully")
        else:
            logger.warning("Persistent loader: maximum number of attempts reached, stopping")

    # Run the loop in a daemon thread
    thread = threading.Thread(target=_load_loop, daemon=True)
    thread.start()
    logger.info("Persistent model loader started")

# Automated recovery system
@app.on_event("startup")  # wiring assumed: start the monitor when the server starts
async def _automated_recovery():
    """Automated recovery system that detects and fixes partial loads."""
    import threading

    def _recovery_loop():
        """Continuous monitoring and recovery loop."""
        last_recovery_attempt = 0
        recovery_cooldown = 60  # wait 60s between recovery attempts
        while True:
            try:
                current_time = time.time()
                # Partial load: processor present but model missing
                if (model_manager.processor and not model_manager.model and
                        not model_manager._loading and
                        current_time - last_recovery_attempt > recovery_cooldown):
                    logger.info("🔧 Partial load detected: processor loaded but model missing")
                    logger.info("🚀 Starting automatic recovery...")
                    last_recovery_attempt = current_time
                    # Try to complete the partial load
                    success = model_manager._complete_partial_load()
                    if success:
                        logger.info("✅ Automatic recovery succeeded!")
                        break  # exit the loop on success
                    logger.warning("⚠️ Automatic recovery failed, retrying in 60s...")
                # Stuck loading state: no progress for more than 5 minutes
                elif (model_manager._loading and
                        current_time - model_manager._last_load_attempt > 300):
                    logger.warning("⏰ Timeout detected, resetting the loading state...")
                    model_manager._loading = False
                    model_manager._load_error = "Timeout - loading stalled"
                    model_manager._save_state()
                # Wait before the next check
                time.sleep(15)  # check every 15 seconds
            except Exception as e:
                logger.error(f"Error in the recovery loop: {e}")
                time.sleep(30)

    # Run the monitor in a daemon thread
    thread = threading.Thread(target=_recovery_loop, daemon=True)
    thread.start()
    logger.info("🔧 Automated recovery system started")

# A more defensive startup path using a separate process to avoid CancelledError.
# Note: a child process does not share the loaded model objects with the server
# process, so this can only pre-populate the on-disk cache; it is left
# unregistered here.
async def _robust_startup():
    """Robust startup using a separate process to avoid CancelledError."""
    import multiprocessing

    # Only start if nothing is loading already
    if model_manager._loading:
        logger.info("Robust startup: loading already in progress, skipping")
        return
    try:
        logger.info("Starting model load in the background...")
        # Guard against launching multiple processes
        if hasattr(model_manager, '_startup_process_running'):
            logger.info("Startup process already running, skipping")
            return
        model_manager._startup_process_running = True

        def _startup_load():
            """Load the model in a separate process."""
            try:
                # Configure the environment for this process
                os.environ['HF_HOME'] = '/tmp/hf_home'
                logger.info("Loading process started")
                success = model_manager.load_model_directly()
                if success:
                    logger.info("Process: load succeeded")
                else:
                    logger.warning("Process: load failed")
            except Exception as e:
                logger.error(f"Process: error: {e}")
            finally:
                # Clean up the guard flag
                if hasattr(model_manager, '_startup_process_running'):
                    delattr(model_manager, '_startup_process_running')

        # Start the process
        process = multiprocessing.Process(target=_startup_load, daemon=True)
        process.start()
        logger.info(f"Model loading process started (PID: {process.pid})")
        # Give the process a moment to start
        time.sleep(2)
        # Check that it is still alive
        if not process.is_alive():
            logger.warning("Startup process exited prematurely")
            if hasattr(model_manager, '_startup_process_running'):
                delattr(model_manager, '_startup_process_running')
    except Exception as e:
        logger.error(f"Error starting the process: {e}")
        if hasattr(model_manager, '_startup_process_running'):
            delattr(model_manager, '_startup_process_running')

# Health monitoring with automatic recovery
@app.get("/health")  # route listed in root(); decorator wiring assumed
def health():
    """Report application and model status, triggering automatic recovery."""
    try:
        # Detect partial loads and trigger automatic recovery
        if model_manager.processor and not model_manager.model and not model_manager._loading:
            logger.info("🔧 Automatic recovery triggered via /health")
            # Start recovery in the background
            import threading
            thread = threading.Thread(target=model_manager._complete_partial_load, daemon=True)
            thread.start()
        model_loaded = model_manager.ensure_model_loaded()
        streamlit_cache_available = model_manager.check_streamlit_model_cache()
        load_status = model_manager.get_load_status()
        return {
            "status": "ok" if model_loaded else "cold",
            "uptime_s": int(time.time() - APP_START_TS),
            "cuda": torch.cuda.is_available(),
            "device_map": model_manager.device_map,
            "dtype": str(model_manager.dtype),
            "model_id": MODEL_ID,
            "streamlit_cache_available": streamlit_cache_available,
            "model_loaded": model_loaded,
            "load_status": load_status,
            "auto_recovery": "active",
        }
    except Exception as e:
        logger.error(f"Error in health check: {e}")
        return {
            "status": "error",
            "error": str(e),
            "uptime_s": int(time.time() - APP_START_TS),
        }
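
# Example check (assuming the server runs locally on the default port):
#   curl http://localhost:7860/health
# returns JSON such as {"status": "cold", "model_loaded": false, ...} until the
# background load finishes.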

@app.post("/load")  # route listed in root(); POST method assumed
def load():
    """Force a model load."""
    try:
        success = model_manager.load_model_directly()
        load_status = model_manager.get_load_status()
        if success:
            return {"status": "success", "message": "Model loaded successfully", "load_status": load_status}
        return {
            "status": "error",
            "message": "Model load failed",
            "load_status": load_status,
            "error": model_manager._load_error,
        }
    except Exception as e:
        logger.error(f"Error during forced load: {e}")
        return {"status": "error", "message": f"Error: {str(e)}"}

@app.post("/diagnose")  # root() documents this as "POST /diagnose"
async def diagnose(
    image: UploadFile = File(...),
    culture: Optional[str] = Form(None),
    notes: Optional[str] = Form(None)
):
    """Analyze a plant leaf image."""
    request_start = time.time()
    try:
        # Make sure the model is loaded
        if not model_manager.ensure_model_loaded():
            load_status = model_manager.get_load_status()
            if model_manager._loading:
                raise HTTPException(status_code=503, detail="Model is loading, please retry in a few seconds")
            raise HTTPException(
                status_code=500,
                detail=f"Model unavailable. Status: {load_status}"
            )
        # Read the image (convert to RGB so RGBA/grayscale uploads also work)
        image_data = await image.read()
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Build the prompt
        prompt = _build_prompt(culture, notes)
        # Prepare the model inputs
        inputs = model_manager.processor(
            images=pil_image,
            text=prompt,
            return_tensors="pt"
        )
        # Move the tensors to the right device
        if model_manager.device_map == "cpu":
            inputs = {k: v.to("cpu") for k, v in inputs.items()}
        # Generate the answer
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.7,
                pad_token_id=model_manager.processor.tokenizer.eos_token_id
            )
        # Decode the answer
        response_text = model_manager.processor.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )
        # Keep only the generated part (after the prompt)
        if prompt in response_text:
            diagnosis = response_text.split(prompt)[-1].strip()
        else:
            diagnosis = response_text.strip()
        return {
            "diagnosis": diagnosis,
            "model_id": MODEL_ID,
            "culture": culture,
            "notes": notes,
            # Wall-clock time for this request (not server uptime)
            "processing_time": time.time() - request_start
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error during diagnosis: {e}")
        raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")
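
# Example call (assuming the server runs locally on the default port):
#   curl -X POST http://localhost:7860/diagnose \
#     -F "image=@leaf.jpg" -F "culture=tomato" -F "notes=yellow spots"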

@app.post("/recover")  # path assumed from the function name
def recover():
    """Try to recover from a partial model load."""
    try:
        if model_manager.processor and not model_manager.model:
            logger.info("Recovering a partial load...")
            success = model_manager._complete_partial_load()
            if success:
                return {"status": "success", "message": "Model recovered successfully"}
            return {"status": "error", "message": "Recovery failed"}
        return {"status": "info", "message": "No partial load to recover"}
    except Exception as e:
        logger.error(f"Error during recovery: {e}")
        return {"status": "error", "message": f"Error: {str(e)}"}

@app.get("/status")  # path taken from the "via /status" log message below
def detailed_status():
    """Detailed system status, including automatic recovery information."""
    try:
        current_time = time.time()
        # Time since the last load attempt
        time_since_last_attempt = current_time - model_manager._last_load_attempt if model_manager._last_load_attempt > 0 else 0
        # Derive the interesting states
        partial_load_detected = model_manager.processor is not None and model_manager.model is None
        stuck_loading = model_manager._loading and time_since_last_attempt > 300
        recovery_needed = partial_load_detected or stuck_loading
        status_info = {
            "timestamp": current_time,
            "model_state": {
                "processor_loaded": model_manager.processor is not None,
                "model_loaded": model_manager.model is not None,
                "loading": model_manager._loading,
                "load_attempted": model_manager._load_attempted,
                "time_since_last_attempt": f"{time_since_last_attempt:.1f}s"
            },
            "auto_recovery": {
                "active": True,
                "partial_load_detected": partial_load_detected,
                "stuck_loading_detected": stuck_loading,
                "recovery_needed": recovery_needed,
                "check_interval": "15s"
            },
            "system": {
                "uptime_s": int(current_time - APP_START_TS),
                "device_map": model_manager.device_map,
                "dtype": str(model_manager.dtype),
                "model_id": MODEL_ID
            }
        }
        # If recovery is needed, trigger it automatically
        if recovery_needed:
            logger.info("🔧 Automatic recovery triggered via /status")
            if partial_load_detected:
                import threading
                thread = threading.Thread(target=model_manager._complete_partial_load, daemon=True)
                thread.start()
            elif stuck_loading:
                model_manager._loading = False
                model_manager._load_error = "Timeout - loading stalled"
                model_manager._save_state()
        return status_info
    except Exception as e:
        logger.error(f"Error in detailed_status: {e}")
        return {
            "status": "error",
            "error": str(e),
            "timestamp": time.time()
        }

@app.get("/")  # wiring assumed for the landing route
def root():
    """Landing page with API information."""
    return {
        "message": "AgriLens AI FastAPI",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "load": "/load",
            "diagnose": "/diagnose (POST)",
            "status": "/status",
            "recover": "/recover (POST)"
        },
        "model": MODEL_ID,
        "uptime_s": int(time.time() - APP_START_TS)
    }

# Entry point for Hugging Face Spaces
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces provides this port
    uvicorn.run("app:app", host="0.0.0.0", port=port)
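
# Local run (sketch, assuming this file is saved as app.py):
#   python app.py                  # serves on 0.0.0.0:7860 unless PORT is set
#   uvicorn app:app --port 7860    # equivalent, using the uvicorn CLI directly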