MODLI's picture
Update app_simple.py
d30aeb9 verified
# app_simple.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
from transformers import ViTImageProcessor, ViTForImageClassification
import torch
import os # <-- AJOUT IMPORT OS
import logging # <-- AJOUT IMPORT LOGGING
# Configuration du logging pour debuguer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Detection Outfit API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- Chargement du modèle ---
# Utilise un chemin absolu pour être sûr
model_path = "/app/model"
logger.info(f"Tentative de chargement du modèle depuis : {model_path}")
logger.info(f"Contenu du dossier model/ : {os.listdir(model_path) if os.path.exists(model_path) else 'DOSSIER INTROUVABLE'}")
try:
# Vérifie que le fichier de config essentiel existe
config_file = os.path.join(model_path, "preprocessor_config.json")
if not os.path.exists(config_file):
raise RuntimeError(f"Fichier de config introuvable: {config_file}")
processor = ViTImageProcessor.from_pretrained(model_path)
model = ViTForImageClassification.from_pretrained(model_path)
logger.info("Modèle et processeur chargés avec succès!")
except Exception as e:
logger.error(f"ERREUR FATALE lors du chargement du modèle: {e}")
# Il faut arrêter l'application si le modèle ne charge pas
raise e
# --- Définition des labels ---
# ⚠️ REMPLACE ÇA PAR LES VRAIES ÉTIQUETTES DE TON MODÈLE ! ⚠️
# Ouvre ton fichier /app/model/config.json et trouve la section "id2label"
id2label = {
"0": "T-shirt",
"1": "Pantalon",
"2": "Pull",
"3": "Robe",
"4": "Manteau",
"5": "Sandale",
"6": "Chemise",
"7": "Sneaker",
"8": "Sac",
"9": "Botte"
}
@app.get("/")
def read_root():
return {"message": "Bienvenue sur l'API de detection d'outfit!"}
@app.get("/health")
def health_check():
"""Endpoint de santé pour vérifier que l'API et le modèle sont chargés"""
return {
"status": "healthy",
"model_loaded": True,
"model_path": model_path
}
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
if not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="Le fichier doit être une image.")
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert('RGB')
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
predicted_label = id2label[str(predicted_class_idx)]
confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
response = {
"predicted_label": predicted_label,
"confidence": round(confidence, 4)
}
return response
except Exception as e:
logger.error(f"Erreur lors de la classification: {e}")
raise HTTPException(status_code=500, detail=f"Erreur lors de la classification: {str(e)}")