"""Fashion AI Classifier.

Single-file application exposing:
  * a Gradio UI for interactive clothing-item classification, and
  * a FastAPI JSON endpoint (``/api/classify``) for programmatic access.

A CLIP-style classifier is loaded once at import time, preferring a
fashion-specialised checkpoint and falling back to generic CLIP models.
"""

import base64
import os
from io import BytesIO
from typing import Optional

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn  # NOTE(review): appears unused — kept to avoid breaking external expectations
import torchvision.transforms as transforms  # NOTE(review): appears unused — kept as-is
import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from torchvision.models import resnet50  # NOTE(review): appears unused — kept as-is
from transformers import pipeline

# Candidate labels for zero-shot classification: detailed, precise fashion
# categories covering garments, footwear and accessories.
FASHION_CATEGORIES = [
    "t-shirt", "dress", "jeans", "jacket", "skirt", "sneakers", "handbag",
    "swimsuit", "lingerie", "sweater", "coat", "shorts", "blouse", "hat",
    "top", "sweatpants", "dress pants", "leggings", "boots", "sandals",
    "heels", "backpack", "sunglasses", "blazer", "cardigan", "polo shirt",
    "hoodie", "vest", "jumpsuit", "romper", "crop top", "tank top",
    "long sleeve shirt", "windbreaker", "parka", "trench coat",
    "leather jacket", "denim jacket", "waistcoat", "suit", "tie", "scarf",
    "gloves", "belt", "wallet", "watch", "jewelry",
]

print("🔧 Loading fashion model...")

# transformers.pipeline device convention: 0 = first GPU, -1 = CPU.
_DEVICE = 0 if torch.cuda.is_available() else -1

# Load the best available model with graceful degradation:
# fashion-specialised checkpoint -> CLIP large -> CLIP base.
try:
    # First choice: a model fine-tuned for fashion imagery.
    fashion_pipe = pipeline(
        "image-classification",
        model="nateraw/fashion-clip",
        device=_DEVICE,
    )
    print("✅ Fashion-CLIP model loaded successfully!")
except Exception:  # was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit propagate
    try:
        # Fallback: a large general-purpose CLIP used in zero-shot mode.
        fashion_pipe = pipeline(
            "zero-shot-image-classification",
            model="openai/clip-vit-large-patch14",
            device=_DEVICE,
        )
        print("✅ CLIP Large model loaded successfully!")
    except Exception:
        # Last resort: the smallest CLIP checkpoint.
        fashion_pipe = pipeline(
            "zero-shot-image-classification",
            model="openai/clip-vit-base-patch32",
            device=_DEVICE,
        )
        print("✅ CLIP Base model loaded as fallback!")

# API configuration: comma-separated keys in the API_KEYS environment variable.
# Empty entries are dropped, so an unset/blank variable disables authentication
# (the original produced [""] here and needed a triple-condition guard).
API_KEYS = [k for k in os.environ.get("API_KEYS", "").split(",") if k.strip()]


class ClassificationRequest(BaseModel):
    """JSON payload for /api/classify."""

    # Base64-encoded image, optionally as a full data URI.
    image_data: str
    # Required only when API_KEYS is configured.
    api_key: Optional[str] = None


def preprocess_image(image):
    """Normalise an input image for classification.

    Converts to RGB and downscales so the longest side is at most 512 px,
    preserving the aspect ratio.

    Args:
        image: a ``PIL.Image.Image``.

    Returns:
        The preprocessed PIL image.
    """
    # Convert to RGB if needed (e.g. RGBA/palette images).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Smart resize keeping the aspect ratio.
    width, height = image.size
    max_size = 512
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_size = (int(width * ratio), int(height * ratio))
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image


def load_image_from_url(url):
    """Robustly fetch and preprocess an image from a URL.

    Raises:
        ValueError: if the URL cannot be fetched or does not serve an image.
    """
    try:
        # A browser-like User-Agent avoids trivial bot blocking on some hosts.
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers, timeout=15)
        response.raise_for_status()

        # Reject non-image responses early (HTML error pages, redirects, etc.).
        if not response.headers.get('content-type', '').startswith('image/'):
            raise ValueError("URL does not point to an image")

        image = Image.open(BytesIO(response.content))
        return preprocess_image(image)
    except Exception as e:
        # Chain the original cause for easier debugging.
        raise ValueError(f"❌ Cannot load image from URL: {str(e)}") from e


def analyze_fashion_item(image_input, url_input):
    """Classify a clothing item from an uploaded image or an image URL.

    Args:
        image_input: uploaded image (PIL image or numpy array), or ``None``.
        url_input: image URL string; used only when ``image_input`` is ``None``.

    Returns:
        ``(markdown_result, processed_image)`` — ``processed_image`` is
        ``None`` when the input itself was invalid.
    """
    try:
        # Decide which input source to use; the upload takes precedence.
        if image_input is not None:
            if isinstance(image_input, np.ndarray):
                image = Image.fromarray(image_input)
            else:
                image = image_input
            image = preprocess_image(image)
        elif url_input and url_input.strip():
            image = load_image_from_url(url_input.strip())
        else:
            return "❌ Please upload an image or enter a URL first", None

        try:
            # BUG FIX: the original unconditionally called fashion_pipe(image)
            # first, which raises for zero-shot pipelines (they require
            # candidate_labels), so the CLIP fallback models could never
            # classify anything. Branch on the pipeline task instead.
            if getattr(fashion_pipe, "task", "") == "zero-shot-image-classification":
                # Zero-shot CLIP path: score our own candidate labels.
                # (The original also passed multi_label=True, which is not a
                # parameter of this pipeline — removed.)
                predictions = fashion_pipe(
                    image,
                    candidate_labels=FASHION_CATEGORIES,
                    hypothesis_template="a clear photo of {}",
                )
                confident_predictions = [p for p in predictions if p['score'] > 0.1]
            else:
                # Fixed-label classifier (fashion-clip): rank its own labels.
                predictions = fashion_pipe(image)
                predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)
                confident_predictions = [p for p in predictions if p['score'] > 0.05]
        except Exception as model_error:
            print(f"Model error: {model_error}")
            return "❌ Model analysis failed. Please try another image.", image

        if not confident_predictions:
            return "❌ No confident prediction. Try a clearer image with better lighting.", image

        # Sort by descending score; the best prediction comes first.
        confident_predictions.sort(key=lambda x: x['score'], reverse=True)
        best_pred = confident_predictions[0]

        # Format the results as markdown.
        result_text = f"🎯 **Main item**: {best_pred['label'].title()}\n"
        result_text += f"**Confidence**: {best_pred['score']*100:.1f}%\n\n"

        if len(confident_predictions) > 1:
            result_text += "**Other possibilities**:\n"
            # Show at most the next five candidates.
            for i, pred in enumerate(confident_predictions[1:6], 1):
                result_text += f"{i}. {pred['label'].title()} ({pred['score']*100:.1f}%)\n"

        # Confidence-based advice for the user.
        if best_pred['score'] < 0.7:
            result_text += f"\n💡 **Tip**: Low confidence. Try a clearer image with the item centered and good lighting."
        else:
            result_text += f"\n✅ **High confidence detection**: This is very likely a {best_pred['label']}."

        return result_text, image

    except Exception as e:
        error_msg = str(e)
        if "URL" in error_msg:
            return f"❌ URL Error: {error_msg}", None
        else:
            return f"❌ Analysis Error: {error_msg}", None


# =============================================================================
# GRADIO INTERFACE
# =============================================================================
with gr.Blocks(
    title="Fashion AI Classifier",
    theme=gr.themes.Soft(),
    css="""
    .container { max-width: 900px; margin: 0 auto; padding: 20px; }
    .header { text-align: center; margin-bottom: 30px; }
    .input-section { background: #f8f9fa; padding: 20px; border-radius: 10px; }
    .output-section { background: white; padding: 20px; border-radius: 10px; }
    .success { color: green; }
    .warning { color: orange; }
    .error { color: red; }
    """,
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("""
        # 👗 Fashion AI Classifier

        Upload an image or enter URL to analyze clothing items
        """)

        with gr.Row():
            with gr.Column(elem_classes="input-section"):
                # Separate but simple inputs: upload OR URL.
                image_input = gr.Image(
                    type="pil",
                    label="Upload Fashion Item",
                    height=200,
                )
                url_input = gr.Textbox(
                    label="Or enter Image URL",
                    placeholder="https://example.com/image.jpg",
                    lines=2,
                )
                gr.Markdown("""
                **📝 Tips for better results:**
                - Use clear, well-lit images
                - Center the clothing item
                - Use plain backgrounds when possible
                - Avoid multiple items in one image
                """)
                analyze_btn = gr.Button(
                    "🔍 Analyze Item",
                    variant="primary",
                    size="lg",
                )

            with gr.Column(elem_classes="output-section"):
                output_text = gr.Markdown(
                    label="Analysis Results",
                    value="**Results will appear here...**",
                )
                output_image = gr.Image(
                    label="Processed Image",
                    interactive=False,
                    height=300,
                    show_download_button=False,
                )

        # Usage instructions.
        gr.Markdown("""
        ### 📋 Instructions:
        - **Upload** an image **OR** enter a direct image URL
        - Make sure the clothing item is clearly visible
        - Well-lit images work best
        - Avoid busy backgrounds
        - For best results, show one item at a time
        """)

    # Wire the button click to the analysis function.
    analyze_btn.click(
        fn=analyze_fashion_item,
        inputs=[image_input, url_input],
        outputs=[output_text, output_image],
    )


# =============================================================================
# API ENDPOINTS (for the Lovable integration)
# =============================================================================
app = FastAPI()


@app.post("/api/classify")
async def api_classify(request: ClassificationRequest):
    """Classify a base64-encoded image; enforces API keys when configured."""
    try:
        # Authentication is active only when at least one key is configured.
        if API_KEYS:
            if not request.api_key or request.api_key not in API_KEYS:
                raise HTTPException(status_code=401, detail="Invalid API key")

        if not request.image_data:
            raise HTTPException(status_code=400, detail="No image data provided")

        # Accept both raw base64 and data-URI ("data:image/...;base64,...")
        # payloads. Use a local variable instead of mutating the request model.
        image_data = request.image_data
        if ',' in image_data:
            image_data = image_data.split(',')[1]

        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))
        image = preprocess_image(image)

        # Reuse the UI analysis path; the URL input is intentionally empty.
        result_text, processed_image = analyze_fashion_item(image, "")

        if result_text.startswith("❌"):
            raise HTTPException(status_code=400, detail=result_text)

        # Return the preprocessed image as a data URI alongside the result.
        buffered = BytesIO()
        processed_image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {
            "success": True,
            "result": result_text,
            "processed_image": f"data:image/jpeg;base64,{img_str}",
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"API Error: {str(e)}")


@app.get("/api/health")
async def health_check():
    """Liveness probe confirming the API process is up."""
    return {"status": "healthy", "model_loaded": True}


# Mount the Gradio UI at the web root of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        reload=False,
    )