import os import json import time os.environ['HF_HOME'] = '/tmp/cache' os.environ['TORCH_HOME'] = '/tmp/cache' from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from PIL import Image import torch import io import colorthief import tempfile import numpy as np app = FastAPI(title="Fashion Classification API") # Middleware CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], expose_headers=["*"] ) # --- ÉTAT DU MODÈLE --- print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...") model = None processor = None model_loading = False model_loaded = False model_error = None def load_fashion_model(): global model, processor, model_loading, model_loaded, model_error model_loading = True try: from transformers import AutoModel, AutoProcessor model_name = "Marqo/Marqo-FashionSigLIP-Classification" print("📦 Téléchargement du modèle... (cela peut prendre 5-10 minutes)") # Charger le modèle SigLIP model = AutoModel.from_pretrained( model_name, cache_dir="/tmp/cache", torch_dtype=torch.float16, trust_remote_code=True ) processor = AutoProcessor.from_pretrained( model_name, trust_remote_code=True ) print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !") model_loaded = True model_loading = False except Exception as e: print(f"❌ Erreur chargement modèle: {e}") model_error = str(e) model_loading = False import traceback traceback.print_exc() # Démarrer le chargement IMMÉDIATEMENT load_fashion_model() # Catégories de mode categories = [ "t-shirt", "dress", "jeans", "shirt", "skirt", "sneakers", "handbag", "jacket", "shorts", "sweater", "coat", "high heels", "blouse", "boots", "hat" ] @app.get("/") def read_root(): return { "message": "Fashion Classification API is running!", "status": "OK", "model_status": "loaded" if model_loaded else "loading" if model_loading else "error", "model_name": "Marqo-FashionSigLIP-Classification" } @app.get("/health") def health_check(): return { "model_loaded": model_loaded, "model_loading": model_loading, "model_error": model_error, "status": "ready" if model_loaded else "loading" if model_loading else "error", "model_name": "Marqo-FashionSigLIP-Classification", "timestamp": time.time() } @app.post("/analyze") async def analyze_image(file: UploadFile = File(...)): # Vérifier si le modèle est chargé if not model_loaded: if model_loading: raise HTTPException(status_code=423, detail="Model still loading. Please wait 5-10 minutes and check /health") else: raise HTTPException(status_code=500, detail=f"Model failed to load: {model_error}") if model is None or processor is None: raise HTTPException(status_code=500, detail="Model not available") try: # Lire et préparer l'image contents = await file.read() image = Image.open(io.BytesIO(contents)).convert("RGB") image = image.resize((384, 384)) # Traitement avec SigLIP inputs = processor( text=categories, images=image, return_tensors="pt", padding=True, truncation=True, max_length=64, ) device = next(model.parameters()).device inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = model(**inputs) logits_per_image = outputs.logits_per_image probs = torch.sigmoid(logits_per_image) probs = probs.cpu().numpy()[0] predicted_idx = np.argmax(probs) category_name = categories[predicted_idx] confidence_score = float(probs[predicted_idx]) # Analyse couleur try: with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp: image.save(tmp, format='JPEG') tmp_path = tmp.name color_thief = colorthief.ColorThief(tmp_path) dominant_color = color_thief.get_color(quality=1) hex_color = '#%02x%02x%02x' % dominant_color os.unlink(tmp_path) except Exception: hex_color = "#000000" return { "category": category_name, "confidence": round(confidence_score, 4), "color_hex": hex_color, "model": "Marqo-FashionSigLIP-Classification" } except Exception as e: raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}") # Interface de test avec statut de chargement @app.get("/test-ui", response_class=HTMLResponse) async def test_ui(): return f"""