Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline | |
| from PIL import Image | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
| import base64 | |
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import uvicorn | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torchvision.models import resnet50 | |
| import torch.nn as nn | |
# Fashion labels passed as candidate labels to the zero-shot classifier.
# Order is kept stable; each entry is a plain-English garment/accessory name.
FASHION_CATEGORIES = [
    "t-shirt",
    "dress",
    "jeans",
    "jacket",
    "skirt",
    "sneakers",
    "handbag",
    "swimsuit",
    "lingerie",
    "sweater",
    "coat",
    "shorts",
    "blouse",
    "hat",
    "top",
    "sweatpants",
    "dress pants",
    "leggings",
    "boots",
    "sandals",
    "heels",
    "backpack",
    "sunglasses",
    "blazer",
    "cardigan",
    "polo shirt",
    "hoodie",
    "vest",
    "jumpsuit",
    "romper",
    "crop top",
    "tank top",
    "long sleeve shirt",
    "windbreaker",
    "parka",
    "trench coat",
    "leather jacket",
    "denim jacket",
    "waistcoat",
    "suit",
    "tie",
    "scarf",
    "gloves",
    "belt",
    "wallet",
    "watch",
    "jewelry",
]
print("🔧 Loading fashion model...")

# Load a fashion-specialized model first, then fall back to general CLIP
# models. Device 0 is the first CUDA GPU; -1 means CPU.
_DEVICE = 0 if torch.cuda.is_available() else -1

# Candidates tried in order: (pipeline task, model id, success message).
# NOTE(review): the first entry loads a CLIP-style checkpoint with the
# "image-classification" task. CLIP checkpoints normally ship without a trained
# classification head, so confirm this model really yields fashion labels
# (otherwise it likely belongs under "zero-shot-image-classification").
_MODEL_CANDIDATES = (
    ("image-classification", "nateraw/fashion-clip",
     "✅ Fashion-CLIP model loaded successfully!"),
    ("zero-shot-image-classification", "openai/clip-vit-large-patch14",
     "✅ CLIP Large model loaded successfully!"),
    ("zero-shot-image-classification", "openai/clip-vit-base-patch32",
     "✅ CLIP Base model loaded as fallback!"),
)


def _load_fashion_pipeline(candidates=_MODEL_CANDIDATES, device=_DEVICE):
    """Return the first pipeline in ``candidates`` that loads successfully.

    Args:
        candidates: Sequence of ``(task, model_id, success_message)`` tuples.
        device: Device index passed to ``transformers.pipeline``.

    Returns:
        The loaded pipeline.

    Raises:
        Exception: whatever the last candidate raised, if every candidate
            fails (matching the original behaviour of the final fallback).
        RuntimeError: if ``candidates`` is empty.
    """
    last_index = len(candidates) - 1
    for index, (task, model_id, success_message) in enumerate(candidates):
        try:
            pipe = pipeline(task, model=model_id, device=device)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate; report why the candidate was skipped.
            if index == last_index:
                raise
            print(f"⚠️ Could not load {model_id} ({task}): {exc}")
            continue
        print(success_message)
        return pipe
    raise RuntimeError("No candidate models configured")


fashion_pipe = _load_fashion_pipeline()
# API configuration


def parse_api_keys(raw):
    """Split a comma-separated API key list into clean keys.

    Whitespace around each key is trimmed and empty entries (from an unset
    variable, doubled or trailing commas) are dropped, so an unset or blank
    value yields an empty list.

    >>> parse_api_keys(" a, b,,")
    ['a', 'b']
    """
    return [key.strip() for key in raw.split(",") if key.strip()]


# Accepted API keys from the API_KEYS environment variable (comma-separated).
# An empty list means no keys are configured.
API_KEYS = parse_api_keys(os.environ.get("API_KEYS", ""))
# Request body model for the classification API.
class ClassificationRequest(BaseModel):
    # Image encoded as a base64 string; a data-URL header ending in ","
    # (e.g. "data:image/png;base64,") is stripped by the endpoint before decoding.
    image_data: str
    # Client API key; only checked when API keys are configured server-side.
    api_key: Optional[str] = None
def preprocess_image(image, max_size=512):
    """Normalize an image before classification.

    Converts the image to RGB. If its longer side exceeds ``max_size`` pixels,
    it is downsized with Lanczos resampling, keeping the aspect ratio. Smaller
    images are returned at their original size (never upscaled).

    Args:
        image: A PIL image.
        max_size: Maximum length in pixels of the longer side (default 512).

    Returns:
        The RGB, size-limited PIL image (may be a new object).
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    width, height = image.size
    longest_side = max(width, height)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp each side to at least 1 px: for very thin images (e.g. 2000x2)
        # the short side would otherwise truncate to 0, which PIL rejects.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    return image
def load_image_from_url(url, timeout=15, max_bytes=20 * 1024 * 1024):
    """Download an image from ``url`` and return it preprocessed.

    The body is streamed and capped at ``max_bytes`` so a huge or endless
    response cannot exhaust memory. The HTTP response is always closed.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Timeout in seconds passed to ``requests.get``.
        max_bytes: Maximum accepted body size in bytes (default 20 MiB).

    Returns:
        The downloaded image after ``preprocess_image``.

    Raises:
        ValueError: on any failure (network/HTTP error, non-image content
            type, oversized body, undecodable image). The message starts with
            "❌ Cannot load image from URL:" and the cause is chained.
    """
    # NOTE(review): the URL is user-supplied and fetched server-side, so it can
    # target internal hosts (SSRF). Restrict hosts/schemes if that matters for
    # this deployment.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    try:
        with requests.get(url, headers=headers, timeout=timeout, stream=True) as response:
            response.raise_for_status()
            # Make sure the server says it is an image; header case may vary.
            content_type = response.headers.get('content-type', '').lower()
            if not content_type.startswith('image/'):
                raise ValueError("URL does not point to an image")
            buffer = BytesIO()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                buffer.write(chunk)
                if buffer.tell() > max_bytes:
                    raise ValueError(f"Image is larger than {max_bytes} bytes")
        buffer.seek(0)
        image = Image.open(buffer)
        return preprocess_image(image)
    except Exception as e:
        raise ValueError(f"❌ Cannot load image from URL: {str(e)}") from e
def analyze_fashion_item(image_input, url_input):
    """Classify the clothing item in an uploaded image or an image URL.

    Args:
        image_input: Uploaded image (PIL image or numpy array), or None.
            Takes priority over ``url_input`` when provided.
        url_input: Image URL string; used only when no image is uploaded.

    Returns:
        A ``(markdown_text, image)`` tuple. On success the text names the top
        prediction plus up to five alternatives, and ``image`` is the
        preprocessed image. Error texts start with "❌". ``image`` is then the
        preprocessed image if loading succeeded, otherwise None.
    """
    try:
        # Choose the input source: an uploaded image wins over the URL field.
        if image_input is not None:
            if isinstance(image_input, np.ndarray):
                image = Image.fromarray(image_input)
            else:
                image = image_input
            image = preprocess_image(image)
        elif url_input and url_input.strip():
            image = load_image_from_url(url_input.strip())
        else:
            return "❌ Please upload an image or enter a URL first", None

        try:
            # The loader can return either a plain image classifier or a
            # zero-shot CLIP pipeline. Only the zero-shot task takes candidate
            # labels (calling it without them raises), so run inference exactly
            # once with the arguments the loaded task expects.
            if getattr(fashion_pipe, "task", "") == "zero-shot-image-classification":
                # Scores from this pipeline are a softmax across the labels;
                # it has no multi-label mode.
                predictions = fashion_pipe(
                    image,
                    candidate_labels=FASHION_CATEGORIES,
                    hypothesis_template="a clear photo of {}",
                )
                min_score = 0.1
            else:
                predictions = fashion_pipe(image)
                min_score = 0.05
        except Exception as model_error:
            print(f"Model error: {model_error}")
            return "❌ Model analysis failed. Please try another image.", image

        # Keep only predictions above the task's threshold, best first.
        confident_predictions = sorted(
            (p for p in predictions if p['score'] > min_score),
            key=lambda p: p['score'],
            reverse=True,
        )
        if not confident_predictions:
            return "❌ No confident prediction. Try a clearer image with better lighting.", image

        best_pred = confident_predictions[0]

        # Format the Markdown result.
        result_text = f"🎯 **Main item**: {best_pred['label'].title()}\n"
        result_text += f"**Confidence**: {best_pred['score']*100:.1f}%\n\n"
        if len(confident_predictions) > 1:
            result_text += "**Other possibilities**:\n"
            # Show at most five alternatives after the main item.
            for i, pred in enumerate(confident_predictions[1:6], 1):
                result_text += f"{i}. {pred['label'].title()} ({pred['score']*100:.1f}%)\n"

        # Advice depends on how confident the top prediction is.
        if best_pred['score'] < 0.7:
            result_text += "\n💡 **Tip**: Low confidence. Try a clearer image with the item centered and good lighting."
        else:
            result_text += f"\n✅ **High confidence detection**: This is very likely a {best_pred['label']}."
        return result_text, image
    except Exception as e:
        error_msg = str(e)
        # URL loading failures carry "URL" in their message.
        if "URL" in error_msg:
            return f"❌ URL Error: {error_msg}", None
        return f"❌ Analysis Error: {error_msg}", None
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# Two-column UI: inputs (upload / URL / analyze button) on the left, results
# (Markdown text + processed image) on the right. Components appear on the
# page in the order they are created inside the `with` contexts below.
with gr.Blocks(
    title="Fashion AI Classifier",
    theme=gr.themes.Soft(),
    # Styles for the elem_classes / HTML classes used below. .success,
    # .warning and .error are not referenced by any component in this block.
    css="""
    .container { max-width: 900px; margin: 0 auto; padding: 20px; }
    .header { text-align: center; margin-bottom: 30px; }
    .input-section { background: #f8f9fa; padding: 20px; border-radius: 10px; }
    .output-section { background: white; padding: 20px; border-radius: 10px; }
    .success { color: green; }
    .warning { color: orange; }
    .error { color: red; }
    """
) as demo:
    with gr.Column(elem_classes="container"):
        # Page header: raw HTML inside Markdown, styled by the .header class.
        gr.Markdown("""
        <div class='header'>
        <h1>👗 Fashion AI Classifier</h1>
        <p>Upload an image or enter URL to analyze clothing items</p>
        </div>
        """)
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column(elem_classes="input-section"):
                # Two separate, simple inputs. analyze_fashion_item uses the
                # upload when present and only falls back to the URL otherwise.
                image_input = gr.Image(
                    type="pil",
                    label="Upload Fashion Item",
                    height=200
                )
                url_input = gr.Textbox(
                    label="Or enter Image URL",
                    placeholder="https://example.com/image.jpg",
                    lines=2
                )
                gr.Markdown("""
                **📝 Tips for better results:**
                - Use clear, well-lit images
                - Center the clothing item
                - Use plain backgrounds when possible
                - Avoid multiple items in one image
                """)
                analyze_btn = gr.Button(
                    "🔍 Analyze Item",
                    variant="primary",
                    size="lg"
                )
            # Right column: analysis text and the preprocessed image.
            with gr.Column(elem_classes="output-section"):
                output_text = gr.Markdown(
                    label="Analysis Results",
                    value="**Results will appear here...**"
                )
                output_image = gr.Image(
                    label="Processed Image",
                    interactive=False,
                    height=300,
                    show_download_button=False
                )
        # Usage instructions shown below the two columns.
        gr.Markdown("""
        ### 📋 Instructions:
        - **Upload** an image **OR** enter a direct image URL
        - Make sure the clothing item is clearly visible
        - Well-lit images work best
        - Avoid busy backgrounds
        - For best results, show one item at a time
        """)
        # Click handler: analyze_fashion_item returns (markdown_text, image),
        # mapped onto the two output components in that order.
        analyze_btn.click(
            fn=analyze_fashion_item,
            inputs=[image_input, url_input],
            outputs=[output_text, output_image]
        )
# =============================================================================
# API ENDPOINTS FOR LOVABLE
# =============================================================================
app = FastAPI()


# NOTE(review): without a route decorator this handler was never registered;
# confirm this path matches what the Lovable client calls.
@app.post("/api/classify")
async def api_classify(request: ClassificationRequest):
    """Classify a base64-encoded image sent as JSON (endpoint used by Lovable).

    The API key is enforced only when at least one non-empty key is
    configured in ``API_KEYS``.

    Returns:
        ``{"success": True, "result": <markdown text>,
        "processed_image": <JPEG data URL>}``.

    Raises:
        HTTPException: 401 for a missing/invalid API key. 400 for missing or
            undecodable image data, or when the analysis reports an error
            (text starting with "❌"). 500 for any other unexpected failure.
    """
    import asyncio
    import secrets

    try:
        # Strip/drop blanks so "a, b", ",a" or a trailing comma neither
        # reject valid keys nor silently disable authentication.
        valid_keys = [key.strip() for key in API_KEYS if key and key.strip()]
        if valid_keys:
            supplied = (request.api_key or "").encode()
            # Constant-time comparison avoids leaking key prefixes via timing.
            if not supplied or not any(
                secrets.compare_digest(supplied, key.encode()) for key in valid_keys
            ):
                raise HTTPException(status_code=401, detail="Invalid API key")

        if not request.image_data:
            raise HTTPException(status_code=400, detail="No image data provided")

        # Accept raw base64 as well as data URLs ("data:image/png;base64,...").
        # Use a local instead of mutating the request model.
        encoded = request.image_data
        if ',' in encoded:
            encoded = encoded.split(',', 1)[1]
        try:
            image = Image.open(BytesIO(base64.b64decode(encoded)))
        except (ValueError, OSError) as exc:
            # binascii.Error (bad base64) is a ValueError; PIL's
            # UnidentifiedImageError is an OSError. Both are client errors.
            raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc

        # Inference is blocking. Run it in a worker thread so the event loop,
        # which also serves the mounted Gradio UI, stays responsive.
        # analyze_fashion_item preprocesses the image itself.
        loop = asyncio.get_running_loop()
        result_text, processed_image = await loop.run_in_executor(
            None, analyze_fashion_item, image, ""
        )
        if result_text.startswith("❌"):
            raise HTTPException(status_code=400, detail=result_text)

        # Return the processed image as a JPEG data URL.
        buffered = BytesIO()
        processed_image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return {
            "success": True,
            "result": result_text,
            "processed_image": f"data:image/jpeg;base64,{img_str}"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"API Error: {str(e)}") from e
# NOTE(review): without a route decorator this handler was never registered;
# confirm the path expected by monitoring/clients.
@app.get("/health")
async def health_check():
    """Health-check endpoint confirming that the API is up.

    ``model_loaded`` is always True: module-level model loading either
    succeeds or raises before the app starts serving.
    """
    return {"status": "healthy", "model_loaded": True}
# Serve the Gradio UI from the root path of the same FastAPI application.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Run directly: listen on all interfaces on port 7860, no auto-reload.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)