Spaces:
Sleeping
Sleeping
import io
import json
import os
import tempfile
import threading
import time

# Cache locations are set before the ML libraries are imported, as in the
# original ordering of this file.
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'

import colorthief
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from PIL import Image
app = FastAPI(title="Fashion Classification API")

# CORS middleware.
# The API is open to every origin. Credentials are disabled because a wildcard
# origin combined with credentialed requests is not a supported configuration
# (see the FastAPI CORS documentation) and nothing in this file uses cookies
# or authentication.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
# --- MODEL STATE ---
# Module-level state written by load_fashion_model() and read by the handlers.
print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
model = processor = None              # set once loading succeeds
model_loading = model_loaded = False  # progress flags reported by the status handlers
model_error = None                    # message of the last loading failure, if any
def load_fashion_model():
    """Load the fashion model and its processor into the module globals.

    Sets ``model_loading`` while running, publishes ``model`` and
    ``processor`` together once both are loaded, then sets ``model_loaded``.
    The function never raises: on failure the exception message is stored in
    ``model_error`` and the traceback is printed.
    """
    global model, processor, model_loading, model_loaded, model_error
    model_loading = True
    model_error = None  # forget the error of a previous attempt
    try:
        from transformers import AutoModel, AutoProcessor

        model_name = "Marqo/Marqo-FashionSigLIP-Classification"
        cache_dir = "/tmp/cache"
        print("📦 Téléchargement du modèle... (cela peut prendre 5-10 minutes)")

        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repository; consider pinning a `revision`.
        # The model is never moved off the CPU in this file and the request
        # handler feeds it float32 tensors, so the weights are loaded in
        # float32 rather than float16.
        loaded_model = AutoModel.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        loaded_model.eval()
        loaded_processor = AutoProcessor.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )

        # Publish both objects only after both loaded successfully.
        model, processor = loaded_model, loaded_processor
        print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")
        model_loaded = True
    except Exception as e:
        # Top-level boundary: record the failure for the status handlers.
        print(f"❌ Erreur chargement modèle: {e}")
        model_error = str(e)
        import traceback
        traceback.print_exc()
    finally:
        model_loading = False
# Start loading IMMEDIATELY, but on a daemon thread: calling the loader
# synchronously here blocks the import of this module until the download ends,
# so the "loading" state reported by the handlers could never be observed.
# The flag is raised before the thread starts so no request sees a false
# "error" state in between.
model_loading = True
threading.Thread(
    target=load_fashion_model,
    name="fashion-model-loader",
    daemon=True,
).start()
# Fashion categories used as candidate text labels.
# The order matters: the predicted index is looked up in this list.
categories = [
    "t-shirt",
    "dress",
    "jeans",
    "shirt",
    "skirt",
    "sneakers",
    "handbag",
    "jacket",
    "shorts",
    "sweater",
    "coat",
    "high heels",
    "blouse",
    "boots",
    "hat",
]
# NOTE(review): no route decorator is visible on this function in the reviewed
# text - confirm how it is registered on `app`.
def read_root():
    """Return a short liveness payload including the model loading state.

    ``model_status`` is "loaded", "loading" or "error", derived from the
    module-level ``model_loaded`` / ``model_loading`` flags.
    """
    if model_loaded:
        state = "loaded"
    elif model_loading:
        state = "loading"
    else:
        state = "error"

    payload = {
        "message": "Fashion Classification API is running!",
        "status": "OK",
        "model_status": state,
        "model_name": "Marqo-FashionSigLIP-Classification",
    }
    return payload
# NOTE(review): no route decorator is visible on this function in the reviewed
# text; the test page fetches '/health' - confirm the route registration.
def health_check():
    """Report the model loading flags, the last loading error and a timestamp.

    ``status`` is "ready", "loading" or "error", derived from the module-level
    ``model_loaded`` / ``model_loading`` flags; ``timestamp`` is ``time.time()``.
    """
    if model_loaded:
        status = "ready"
    elif model_loading:
        status = "loading"
    else:
        status = "error"

    report = {
        "model_loaded": model_loaded,
        "model_loading": model_loading,
        "model_error": model_error,
        "status": status,
        "model_name": "Marqo-FashionSigLIP-Classification",
        "timestamp": time.time(),
    }
    return report
def _dominant_color_hex(image):
    """Return the dominant colour of *image* as '#rrggbb' ('#000000' on failure)."""
    try:
        # ColorThief accepts a binary file object, so an in-memory JPEG avoids
        # the temporary file (which leaked whenever extraction failed).
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        buffer.seek(0)
        red, green, blue = colorthief.ColorThief(buffer).get_color(quality=1)
        return f"#{red:02x}{green:02x}{blue:02x}"
    except Exception as exc:
        # Colour analysis is best effort: report the failure, keep the fallback.
        print(f"Color analysis failed: {exc}")
        return "#000000"


# NOTE(review): no route decorator is visible on this function in the reviewed
# text; the test page posts its form to '/analyze' - confirm the registration.
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded clothing image and extract its dominant colour.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``category`` (one of ``categories``), ``confidence``
        (sigmoid score rounded to 4 decimals), ``color_hex`` and ``model``.

    Raises:
        HTTPException: 423 while the model is still loading; 500 if the model
            failed to load, is unavailable, or inference fails; 400 if the
            upload cannot be decoded as an image.
    """
    # Check that the model is loaded.
    if not model_loaded:
        if model_loading:
            raise HTTPException(status_code=423, detail="Model still loading. Please wait 5-10 minutes and check /health")
        raise HTTPException(status_code=500, detail=f"Model failed to load: {model_error}")
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail="Model not available")

    # Read and prepare the image; a bad upload is a client error, not a 500.
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {exc}") from exc
    image = image.resize((384, 384))

    try:
        # SigLIP expects text padded to a fixed length, hence "max_length".
        inputs = processor(
            text=categories,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        )
        reference = next(model.parameters())
        device, dtype = reference.device, reference.dtype
        # Float tensors must match the model dtype; integer tensors
        # (e.g. token ids) are only moved to the device.
        inputs = {
            key: (
                value.to(device=device, dtype=dtype)
                if torch.is_floating_point(value)
                else value.to(device)
            )
            for key, value in inputs.items()
        }
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.sigmoid(outputs.logits_per_image)[0].float().cpu().numpy()
        predicted_idx = int(np.argmax(probs))
        category_name = categories[predicted_idx]
        confidence_score = float(probs[predicted_idx])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}") from e

    return {
        "category": category_name,
        "confidence": round(confidence_score, 4),
        "color_hex": _dominant_color_hex(image),
        "model": "Marqo-FashionSigLIP-Classification",
    }
# Test page with model-loading status.
# NOTE(review): no route decorator is visible on this function in the reviewed
# text - confirm how it is registered (HTMLResponse is imported by the file).
async def test_ui():
    """Return the HTML of a small manual test page.

    The page polls '/health' every 5 seconds while the model is loading,
    enables the submit button once the model is ready, and posts the chosen
    file to '/analyze'. The server error message is inserted as a text node
    (not as HTML), and a failed status request is reported instead of leaving
    the page stuck on its initial message.
    """
    # Plain string (not an f-string): nothing is interpolated, so the CSS and
    # JavaScript braces do not need to be doubled.
    return """
<html>
<head>
<title>FashionSigLIP Detection</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; }
.container { max-width: 600px; margin: 0 auto; }
form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
.status { padding: 15px; margin: 10px 0; border-radius: 5px; }
.loading { background: #fff3cd; color: #856404; }
.ready { background: #d4edda; color: #155724; }
.error { background: #f8d7da; color: #721c24; }
</style>
<script>
function showError(message) {
    const statusDiv = document.getElementById('model-status');
    statusDiv.innerHTML = '❌ <b>Erreur de chargement:</b><br>';
    // Insert the message as text so it cannot be interpreted as markup.
    statusDiv.appendChild(document.createTextNode(message));
    statusDiv.className = 'status error';
    document.getElementById('submit-btn').disabled = true;
}
function checkStatus() {
    fetch('/health')
        .then(response => response.json())
        .then(data => {
            const statusDiv = document.getElementById('model-status');
            const submitBtn = document.getElementById('submit-btn');
            if (data.model_loaded) {
                statusDiv.innerHTML = '✅ <b>Modèle chargé et prêt !</b>';
                statusDiv.className = 'status ready';
                submitBtn.disabled = false;
            } else if (data.model_loading) {
                statusDiv.innerHTML = '⏳ <b>Chargement du modèle en cours...</b><br>Cela peut prendre 5-10 minutes';
                statusDiv.className = 'status loading';
                submitBtn.disabled = true;
                setTimeout(checkStatus, 5000); // Re-check in 5 seconds
            } else {
                showError(data.model_error || 'Unknown error');
            }
        })
        .catch(error => showError(String(error)));
}
// Check the status when the page loads
window.onload = checkStatus;
</script>
</head>
<body>
<div class="container">
<h1>👗 FashionSigLIP Detector</h1>
<div id="model-status" class="status loading">
Vérification du statut du modèle...
</div>
<form action="/analyze" method="post" enctype="multipart/form-data">
<h3>Uploader une image de vêtement :</h3>
<input type="file" name="file" accept="image/*" required>
<br><br>
<input type="submit" id="submit-btn" value="Analyser" disabled>
</form>
<div style="margin-top: 20px;">
<button onclick="checkStatus()">Actualiser le statut</button>
<button onclick="location.reload()">Rafraîchir la page</button>
</div>
</div>
</body>
</html>
"""