Spaces:
Build error
Build error
| # app_simple.py | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| import io | |
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| import torch | |
| import os # <-- AJOUT IMPORT OS | |
| import logging # <-- AJOUT IMPORT LOGGING | |
| # Configuration du logging pour debuguer | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| app = FastAPI(title="Detection Outfit API") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # --- Chargement du modèle --- | |
| # Utilise un chemin absolu pour être sûr | |
| model_path = "/app/model" | |
| logger.info(f"Tentative de chargement du modèle depuis : {model_path}") | |
| logger.info(f"Contenu du dossier model/ : {os.listdir(model_path) if os.path.exists(model_path) else 'DOSSIER INTROUVABLE'}") | |
| try: | |
| # Vérifie que le fichier de config essentiel existe | |
| config_file = os.path.join(model_path, "preprocessor_config.json") | |
| if not os.path.exists(config_file): | |
| raise RuntimeError(f"Fichier de config introuvable: {config_file}") | |
| processor = ViTImageProcessor.from_pretrained(model_path) | |
| model = ViTForImageClassification.from_pretrained(model_path) | |
| logger.info("Modèle et processeur chargés avec succès!") | |
| except Exception as e: | |
| logger.error(f"ERREUR FATALE lors du chargement du modèle: {e}") | |
| # Il faut arrêter l'application si le modèle ne charge pas | |
| raise e | |
| # --- Définition des labels --- | |
| # ⚠️ REMPLACE ÇA PAR LES VRAIES ÉTIQUETTES DE TON MODÈLE ! ⚠️ | |
| # Ouvre ton fichier /app/model/config.json et trouve la section "id2label" | |
| id2label = { | |
| "0": "T-shirt", | |
| "1": "Pantalon", | |
| "2": "Pull", | |
| "3": "Robe", | |
| "4": "Manteau", | |
| "5": "Sandale", | |
| "6": "Chemise", | |
| "7": "Sneaker", | |
| "8": "Sac", | |
| "9": "Botte" | |
| } | |
| def read_root(): | |
| return {"message": "Bienvenue sur l'API de detection d'outfit!"} | |
| def health_check(): | |
| """Endpoint de santé pour vérifier que l'API et le modèle sont chargés""" | |
| return { | |
| "status": "healthy", | |
| "model_loaded": True, | |
| "model_path": model_path | |
| } | |
| async def classify_image(file: UploadFile = File(...)): | |
| if not file.content_type.startswith('image/'): | |
| raise HTTPException(status_code=400, detail="Le fichier doit être une image.") | |
| try: | |
| contents = await file.read() | |
| image = Image.open(io.BytesIO(contents)).convert('RGB') | |
| inputs = processor(images=image, return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| predicted_class_idx = logits.argmax(-1).item() | |
| predicted_label = id2label[str(predicted_class_idx)] | |
| confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item() | |
| response = { | |
| "predicted_label": predicted_label, | |
| "confidence": round(confidence, 4) | |
| } | |
| return response | |
| except Exception as e: | |
| logger.error(f"Erreur lors de la classification: {e}") | |
| raise HTTPException(status_code=500, detail=f"Erreur lors de la classification: {str(e)}") |