# app_simple.py from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from PIL import Image import io from transformers import ViTImageProcessor, ViTForImageClassification import torch import os # <-- AJOUT IMPORT OS import logging # <-- AJOUT IMPORT LOGGING # Configuration du logging pour debuguer logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI(title="Detection Outfit API") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # --- Chargement du modèle --- # Utilise un chemin absolu pour être sûr model_path = "/app/model" logger.info(f"Tentative de chargement du modèle depuis : {model_path}") logger.info(f"Contenu du dossier model/ : {os.listdir(model_path) if os.path.exists(model_path) else 'DOSSIER INTROUVABLE'}") try: # Vérifie que le fichier de config essentiel existe config_file = os.path.join(model_path, "preprocessor_config.json") if not os.path.exists(config_file): raise RuntimeError(f"Fichier de config introuvable: {config_file}") processor = ViTImageProcessor.from_pretrained(model_path) model = ViTForImageClassification.from_pretrained(model_path) logger.info("Modèle et processeur chargés avec succès!") except Exception as e: logger.error(f"ERREUR FATALE lors du chargement du modèle: {e}") # Il faut arrêter l'application si le modèle ne charge pas raise e # --- Définition des labels --- # ⚠️ REMPLACE ÇA PAR LES VRAIES ÉTIQUETTES DE TON MODÈLE ! ⚠️ # Ouvre ton fichier /app/model/config.json et trouve la section "id2label" id2label = { "0": "T-shirt", "1": "Pantalon", "2": "Pull", "3": "Robe", "4": "Manteau", "5": "Sandale", "6": "Chemise", "7": "Sneaker", "8": "Sac", "9": "Botte" } @app.get("/") def read_root(): return {"message": "Bienvenue sur l'API de detection d'outfit!"} @app.get("/health") def health_check(): """Endpoint de santé pour vérifier que l'API et le modèle sont chargés""" return { "status": "healthy", "model_loaded": True, "model_path": model_path } @app.post("/classify") async def classify_image(file: UploadFile = File(...)): if not file.content_type.startswith('image/'): raise HTTPException(status_code=400, detail="Le fichier doit être une image.") try: contents = await file.read() image = Image.open(io.BytesIO(contents)).convert('RGB') inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits predicted_class_idx = logits.argmax(-1).item() predicted_label = id2label[str(predicted_class_idx)] confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item() response = { "predicted_label": predicted_label, "confidence": round(confidence, 4) } return response except Exception as e: logger.error(f"Erreur lors de la classification: {e}") raise HTTPException(status_code=500, detail=f"Erreur lors de la classification: {str(e)}")