Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import numpy as np | |
| import os | |
| import tempfile | |
| # 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE | |
| MODEL_NAME = "google/vit-base-patch16-224" | |
| print("🔄 Chargement du modèle de mode...") | |
| try: | |
| processor = AutoImageProcessor.from_pretrained(MODEL_NAME) | |
| model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model.to(device) | |
| model.eval() | |
| print(f"✅ Modèle chargé sur {device}") | |
| except Exception as e: | |
| print(f"❌ Erreur chargement: {e}") | |
| processor = None | |
| model = None | |
| # 🎯 MAPPING COMPLET DES CATÉGORIES EN FRANÇAIS | |
| FASHION_LABELS = { | |
| # Vêtements supérieurs | |
| 0: "T-shirt", 1: "Pull", 2: "Chemise", 3: "Sweat à capuche", 4: "Veste", | |
| 5: "Manteau", 6: "Blouse", 7: "Haut", 8: "Top", 9: "Débardeur", | |
| # Vêtements inférieurs | |
| 10: "Pantalon", 11: "Jean", 12: "Short", 13: "Jupe", 14: "Legging", | |
| 15: "Pantalon de sport", 16: "Pantalon cargo", 17: "Pantalon chino", | |
| # Robes et ensembles | |
| 18: "Robe", 19: "Robe de soirée", 20: "Robe d'été", 21: "Robe cocktail", | |
| 22: "Combinaison", 23: "Ensemble", 24: "Tenue", | |
| # Sous-vêtements | |
| 25: "Soutien-gorge", 26: "Culotte", 27: "Maillot de bain", | |
| 28: "Pyjama", 29: "Nuisette", | |
| # Chaussures | |
| 30: "Basket", 31: "Sandale", 32: "Botte", 33: "Talons", | |
| 34: "Escarpin", 35: "Chaussure de sport", 36: "Mocassin", | |
| 37: "Derby", 38: "Chausson", | |
| # Accessoires | |
| 39: "Sac à main", 40: "Sac à dos", 41: "Chapeau", 42: "Casquette", | |
| 43: "Écharpe", 44: "Gants", 45: "Ceinture", 46: "Lunettes de soleil", | |
| 47: "Bijou", 48: "Montre", 49: "Cravate", | |
| # Sports | |
| 50: "Tenue de sport", 51: "Maillot de football", 52: "Short de sport", | |
| 53: "Survêtement", 54: "Veste de sport", | |
| # Enfants | |
| 55: "Vêtement bébé", 56: "Vêtement enfant", | |
| # Divers | |
| 57: "Uniforme", 58: "Costume", 59: "Smoking", | |
| 60: "Robe de mariée", 61: "Accessoire mode", | |
| # Matières et textures (si le modèle les détecte) | |
| 100: "Coton", 101: "Denim", 102: "Laine", 103: "Soie", 104: "Cuir", | |
| 105: "Synthétique", 106: "Jean", 107: "Velours", 108: "Laine polaire", | |
| # Couleurs dominantes (approximatives) | |
| 200: "Vêtement noir", 201: "Vêtement blanc", 202: "Vêtement bleu", | |
| 203: "Vêtement rouge", 204: "Vêtement vert", 205: "Vêtement jaune", | |
| 206: "Vêtement rose", 207: "Vêtement violet", 208: "Vêtement orange", | |
| 209: "Vêtement marron", 210: "Vêtement gris", 211: "Vêtement multicolore", | |
| } | |
| # 🎨 CATÉGORIES GÉNÉRIQUES POUR LES NUMÉROS INCONNUS | |
| GENERIC_CATEGORIES = { | |
| range(600, 700): "Vêtement casual", | |
| range(700, 800): "Vêtement formel", | |
| range(800, 900): "Vêtement décontracté", | |
| range(900, 1000): "Article mode", | |
| } | |
| def get_human_readable_label(label_idx): | |
| """Convertit un numéro de catégorie en nom français""" | |
| # D'abord chercher dans le mapping précis | |
| if label_idx in FASHION_LABELS: | |
| return FASHION_LABELS[label_idx] | |
| # Ensuite chercher dans les catégories génériques | |
| for range_obj, category_name in GENERIC_CATEGORIES.items(): | |
| if label_idx in range_obj: | |
| return category_name | |
| # En dernier recours, catégorie générale | |
| if label_idx < 100: | |
| return "Vêtement supérieur" | |
| elif label_idx < 200: | |
| return "Vêtement inférieur" | |
| elif label_idx < 300: | |
| return "Accessoire mode" | |
| elif label_idx < 400: | |
| return "Chaussure" | |
| elif label_idx < 500: | |
| return "Vêtement sport" | |
| else: | |
| return "Article vestimentaire" | |
| def classify_fashion(image): | |
| """Classification avec noms en français""" | |
| try: | |
| if image is None: | |
| return "❌ Veuillez uploader une image de vêtement" | |
| if processor is None or model is None: | |
| return "⚠️ Modèle en cours de chargement... Patientez 30s" | |
| # 📸 Gestion de l'image | |
| try: | |
| if isinstance(image, str): | |
| processed_image = Image.open(image) | |
| else: | |
| processed_image = image | |
| if processed_image.mode != 'RGB': | |
| processed_image = processed_image.convert('RGB') | |
| except Exception as e: | |
| return f"❌ Format d'image non supporté: {str(e)}" | |
| # 🔥 PRÉTRAITEMENT | |
| processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS) | |
| # Transformation pour le modèle | |
| inputs = processor(images=processed_image, return_tensors="pt") | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| # 🔥 INFÉRENCE | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| # 📊 POST-TRAITEMENT | |
| probabilities = F.softmax(outputs.logits, dim=-1) | |
| top_probs, top_indices = torch.topk(probabilities, 5) | |
| # Conversion en résultats français | |
| results = [] | |
| for i in range(len(top_indices[0])): | |
| label_idx = top_indices[0][i].item() | |
| label_name = get_human_readable_label(label_idx) | |
| score = top_probs[0][i].item() * 100 | |
| if score > 1.0: # Seuil de 1% pour éviter le bruit | |
| results.append({"label": label_name, "score": score}) | |
| # 📋 AFFICHAGE DES RÉSULTATS | |
| if not results: | |
| return "❌ Aucune catégorie vestimentaire détectée avec confiance suffisante" | |
| output = "## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n" | |
| for i, result in enumerate(results): | |
| output += f"{i+1}. **{result['label']}** - {result['score']:.1f}%\n" | |
| # 📊 STATISTIQUES | |
| total_confidence = sum(result['score'] for result in results) | |
| output += f"\n---\n" | |
| output += f"📈 **Confiance totale:** {total_confidence:.1f}%\n" | |
| # 💡 CONSEILS | |
| output += "\n💡 **Pour améliorer les résultats:**\n" | |
| output += "• Prenez la photo sur fond uni\n" | |
| output += "• Assurez-vous d'un bon éclairage\n" | |
| output += "• Cadrez uniquement le vêtement\n" | |
| output += "• Évitez les angles complexes\n" | |
| return output | |
| except Exception as e: | |
| return f"❌ Erreur de traitement: {str(e)}" | |
| # 🖼️ EXEMPLES DE TEST | |
| EXAMPLE_URLS = [ | |
| "https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400", # T-shirt | |
| "https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400", # Robe | |
| "https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400", # Chemise | |
| "https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400", # Veste | |
| "https://images.unsplash.com/photo-1582142306909-195724d3a58c?w=400", # Jean | |
| ] | |
| # 🎨 INTERFACE AMÉLIORÉE | |
| with gr.Blocks(title="Classificateur de Vêtements Expert", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(""" | |
| # 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS | |
| *Reconnaissance intelligente avec labels en français* | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 📤 UPLOADER UNE IMAGE") | |
| image_input = gr.Image( | |
| type="filepath", | |
| label="Sélectionnez votre vêtement", | |
| height=300, | |
| sources=["upload"], | |
| ) | |
| gr.Markdown(""" | |
| ### 📋 CONSEILS | |
| ✅ JPEG/PNG recommandés | |
| ❌ Évitez HEIC (Apple) | |
| 📷 Photo nette et bien éclairée | |
| 🎯 Cadrage simple du vêtement | |
| """) | |
| classify_btn = gr.Button("🚀 Analyser le vêtement", variant="primary") | |
| with gr.Column(scale=2): | |
| gr.Markdown("### 📊 RÉSULTATS DÉTAILLÉS") | |
| output_text = gr.Markdown( | |
| value="⬅️ Uploader une image pour commencer l'analyse" | |
| ) | |
| # 🎯 EXEMPLES | |
| gr.Markdown("### 🖼️ GARDIEN-ROBE DE TEST") | |
| gr.Examples( | |
| examples=EXAMPLE_URLS, | |
| inputs=image_input, | |
| outputs=output_text, | |
| fn=classify_fashion, | |
| label="Cliquez sur un vêtement pour tester" | |
| ) | |
| # 🎮 INTERACTION | |
| classify_btn.click( | |
| fn=classify_fashion, | |
| inputs=[image_input], | |
| outputs=output_text | |
| ) | |
| # ⚙️ CONFIGURATION | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| debug=True | |
| ) |