Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| os.environ['HF_HOME'] = '/tmp/cache' | |
| os.environ['TORCH_HOME'] = '/tmp/cache' | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
| from PIL import Image | |
| import torch | |
| import io | |
| import colorthief | |
| import tempfile | |
app = FastAPI(title="Fashion Detection API")

# CORS middleware: accept cross-origin requests from any origin, with any
# method and header, and expose every response header to the browser.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive (browsers refuse a literal "*" with credentials, so the
# middleware presumably echoes the caller's Origin — confirm for the installed
# Starlette version). Prefer an explicit origin list if cookies/auth are used.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Load the Marqo fashionCLIP model ---
print("⚠️ Démarrage du chargement du modèle Marqo fashionCLIP...")
# Populated by load_marqo_model(); None means "not loaded (yet)".
model = processor = None
def load_marqo_model():
    """Load the Marqo fashionCLIP model and processor into the module globals.

    Called from a background thread at startup. On success, ``model`` and
    ``processor`` are assigned together. On failure, the error is printed and
    neither global is touched, so callers never see a model without its
    processor.
    """
    global model, processor
    try:
        from transformers import CLIPProcessor, CLIPModel

        model_name = "Marqo/marqo-fashionCLIP"
        # Half precision only pays off on a GPU. On CPU, fp16 matmuls are
        # typically slow and some ops are unsupported in older PyTorch
        # releases, so keep float32 there.
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32
        loaded_model = CLIPModel.from_pretrained(
            model_name,
            cache_dir="/tmp/cache",
            torch_dtype=dtype,
        )
        if use_cuda:
            # from_pretrained loads onto the CPU; move the weights so the fp16
            # model actually runs on the GPU.
            loaded_model = loaded_model.to("cuda")
        # Same cache directory as the model, so both land in one place.
        loaded_processor = CLIPProcessor.from_pretrained(model_name, cache_dir="/tmp/cache")
        # Publish both only after everything succeeded.
        model, processor = loaded_model, loaded_processor
        print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
    except Exception as e:
        print(f"❌ Erreur chargement modèle Marqo: {e}")
# Registered as a startup hook: without this decorator the handler never runs
# and the model is never loaded.
# NOTE(review): on_event is deprecated in recent FastAPI in favour of a
# lifespan handler; it still works, but migrate when convenient.
@app.on_event("startup")
async def startup_event():
    """Start loading the model in a background daemon thread.

    Loading runs off the event loop so the server keeps answering requests
    while the weights download. The thread is a daemon so it does not keep
    the process alive at shutdown.
    """
    import threading

    loader = threading.Thread(target=load_marqo_model, daemon=True)
    loader.start()
# Fashion categories used as zero-shot labels (kept short and uniform).
categories = [
    "t-shirt",
    "dress",
    "jeans",
    "shirt",
    "skirt",
    "sneakers",
    "handbag",
    "jacket",
    "shorts",
    "sweater",
    "coat",
    "heels",
]
# Registered on the root path: without a route decorator this handler is
# unreachable.
@app.get("/")
def read_root():
    """Return a static payload confirming the API process is up.

    This says nothing about whether the model has finished loading.
    """
    return {"message": "Fashion Detection API is running!", "status": "OK"}
# Registered at /health, the endpoint analyze_image's error message points
# clients to.
@app.get("/health")
def health_check():
    """Report whether the model and processor globals have been populated.

    ``status`` is ``"ready"`` only when both are set. Otherwise it is
    ``"loading"``, which also covers a load that failed, because the loader
    only prints its error.
    """
    model_ready = model is not None
    processor_ready = processor is not None
    return {
        "model_loaded": model_ready,
        "processor_loaded": processor_ready,
        "status": "ready" if model_ready and processor_ready else "loading",
    }
def _classify(clip_model, clip_processor, image):
    """Return ``(category, confidence)`` for *image* by zero-shot CLIP matching.

    The image is encoded once. Each label is tokenized and encoded on its own,
    so there is no padding across labels. The cosine similarities, scaled by
    the model's ``logit_scale``, are then softmaxed across ``categories``.
    This is the same score ``CLIPModel(...).logits_per_image`` gives for a
    single text, without re-running the vision encoder for every label.
    """
    device = next(clip_model.parameters()).device
    with torch.no_grad():
        pixel_values = clip_processor(images=image, return_tensors="pt")["pixel_values"]
        image_embeds = clip_model.get_image_features(pixel_values=pixel_values.to(device))
        # Normalize in float32 for a stable softmax, even with an fp16 model.
        image_embeds = image_embeds.float()
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

        label_embeds = []
        for category in categories:
            text_inputs = clip_processor(
                text=[category],
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            label_embeds.append(clip_model.get_text_features(**text_inputs).float())
        text_embeds = torch.cat(label_embeds, dim=0)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # Same temperature CLIPModel.forward applies to its logits.
        logit_scale = clip_model.logit_scale.exp().float()
        # One image -> row 0 holds the score against every label.
        logits = ((logit_scale * image_embeds) @ text_embeds.T)[0].cpu()

    probs = torch.nn.functional.softmax(logits, dim=0)
    best = int(probs.argmax())
    return categories[best], float(probs[best])


def _dominant_color_hex(image):
    """Return the dominant colour of *image* as ``#rrggbb``, or ``#000000`` on failure."""
    try:
        # ColorThief accepts a binary file object. An in-memory JPEG replaces
        # the former temp file, which leaked whenever saving or analysis
        # raised before the unlink.
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        buffer.seek(0)
        red, green, blue = colorthief.ColorThief(buffer).get_color(quality=1)
        return f"#{red:02x}{green:02x}{blue:02x}"
    except Exception as color_error:
        print(f"Erreur analyse couleur: {color_error}")
        return "#000000"


def _analyze_contents(clip_model, clip_processor, contents):
    """Decode raw image bytes, classify the garment and extract its colour.

    This work is synchronous and compute-heavy, so ``analyze_image`` runs it
    off the event loop.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    # Shrink (in place, aspect ratio kept) to bound preprocessing/colour cost.
    image.thumbnail((384, 384))
    category_name, confidence_score = _classify(clip_model, clip_processor, image)
    hex_color = _dominant_color_hex(image)
    return {
        "category": category_name,
        "color_hex": hex_color,
        "confidence": round(confidence_score, 4),
    }


# Registered at /analyze, the target of the test form's POST. Without the
# decorator this handler is unreachable.
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded garment image and report its dominant colour.

    Returns ``{"category", "color_hex", "confidence"}`` on success. Returns
    ``{"error": ...}`` when the model is not loaded yet or analysis fails.
    Errors are reported in the body, not through the HTTP status.
    """
    from fastapi.concurrency import run_in_threadpool

    # One consistent snapshot of the globals the loader thread assigns.
    clip_model, clip_processor = model, processor
    if clip_model is None or clip_processor is None:
        return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
    try:
        contents = await file.read()
        # Decoding, inference and colour extraction block. Running them in the
        # threadpool keeps the event loop free for other requests.
        return await run_in_threadpool(_analyze_contents, clip_model, clip_processor, contents)
    except Exception as e:
        return {"error": f"Erreur lors de l'analyse: {str(e)}"}
# Test interface: a minimal HTML upload form for manual checks in a browser.
# Without a route decorator this handler was unreachable. response_class
# makes FastAPI send the string as HTML instead of JSON-encoding it.
# NOTE(review): the "/test" path is an assumption (the original never
# registered this handler) — confirm against the deployed Space.
@app.get("/test", response_class=HTMLResponse)
async def test_ui():
    """Serve a static HTML form that POSTs an image to ``/analyze``.

    The result box holds static placeholder text. Submitting the form
    navigates to the raw JSON returned by ``/analyze``.
    """
    return """
<html>
<head>
<title>Fashion Detection Test</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; }
.container { max-width: 600px; margin: 0 auto; }
form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
input[type="file"] { margin: 10px 0; }
input[type="submit"] {
background: #007bff; color: white; padding: 10px 20px;
border: none; cursor: pointer; border-radius: 5px;
}
.result { margin-top: 20px; padding: 20px; background: #f0f8ff; }
</style>
</head>
<body>
<div class="container">
<h1>🎨 Fashion Detection AI</h1>
<form action="/analyze" method="post" enctype="multipart/form-data">
<h3>Uploader une image de vêtement :</h3>
<input type="file" name="file" accept="image/*" required>
<br><br>
<input type="submit" value="Analyser l'image 👗">
</form>
<div class="result">
<h3>📋 Résultat de l'analyse :</h3>
<p>Attendez l'upload et le traitement de l'image...</p>
</div>
</div>
</body>
</html>
"""