import gradio as gr import torch import torch.nn.functional as F from transformers import AutoImageProcessor, AutoModelForImageClassification from PIL import Image import requests from io import BytesIO import numpy as np # 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE MODEL_NAME = "rafalosa/diffusiondb-fashion-mnist" # Modèle spécialisé mode # Alternative: "nateraw/vit-base-patch16-224-fashion-mnist" print("🔄 Chargement du modèle de mode...") try: # Chargeur d'images avec prétraitement correct processor = AutoImageProcessor.from_pretrained( "google/vit-base-patch16-224", # Base standard cache_dir="cache" ) # Modèle fine-tuné sur la mode model = AutoModelForImageClassification.from_pretrained( MODEL_NAME, cache_dir="cache", trust_remote_code=True ) # Configuration device device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() print(f"✅ Modèle chargé sur {device}") print(f"📊 Classe disponibles: {model.config.num_labels}") except Exception as e: print(f"❌ Erreur chargement: {e}") processor = None model = None # 🎯 LABELS COMPRÉHENSIBLES POUR LA MODE FASHION_LABELS = [ "T-shirt", "Pantalon", "Pull", "Robe", "Manteau", "Sandale", "Chemise", "Sneaker", "Sac", "Botte" ] def preprocess_image(image): """Prétraitement correct des images""" # Conversion en RGB if image.mode != 'RGB': image = image.convert('RGB') # Redimensionnement intelligent image = image.resize((224, 224), Image.Resampling.LANCZOS) return image def classify_fashion(image): """Classification spécialisée mode""" try: if image is None: return "❌ Veuillez uploader une image de vêtement" if processor is None or model is None: return "⚠️ Modèle en cours de chargement... Patientez 30s" # 🔥 PRÉTRAITEMENT CORRECT processed_image = preprocess_image(image) # Transformation pour le modèle inputs = processor( images=processed_image, return_tensors="pt", do_resize=True, do_rescale=True, do_normalize=True ) # Transfert sur le bon device inputs = {k: v.to(device) for k, v in inputs.items()} # 🔥 INFÉRENCE AVEC GRADIENTS DÉSACTIVÉS with torch.no_grad(): outputs = model(**inputs) # 🔥 POST-TRAITEMENT CORRECT probabilities = F.softmax(outputs.logits, dim=-1) top_probs, top_indices = torch.topk(probabilities, 5) # Conversion en résultats results = [] for i in range(len(top_indices[0])): # Utilisation de nos labels personnalisés label_idx = top_indices[0][i].item() label_name = FASHION_LABELS[label_idx % len(FASHION_LABELS)] score = top_probs[0][i].item() * 100 results.append({"label": label_name, "score": score}) # 📊 FORMATAGE DES RÉSULTATS output = "## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n" for i, result in enumerate(results): output += f"{i+1}. **{result['label']}** - {result['score']:.1f}%\n" # 📸 Aperçu de l'image traitée output += f"\n---\n" output += f"📏 Image traitée: 224x224 pixels\n" output += f"🔢 Modèle: {MODEL_NAME.split('/')[-1]}\n" output += "\n💡 **Pour de meilleurs résultats:**\n" output += "• Photo claire sur fond uni\n" output += "• Vêtement bien visible\n" output += "• Éviter les angles bizarres\n" return output except Exception as e: return f"❌ Erreur: {str(e)}\n\n🔧 Vérifiez les logs pour plus de détails" # 🖼️ EXEMPLES SPÉCIFIQUES MODE EXAMPLE_URLS = [ "https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400", # T-shirt "https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400", # Robe "https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400", # Chemise "https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400", # Veste ] def load_example(url): """Charge un exemple depuis une URL""" try: response = requests.get(url, timeout=10) return Image.open(BytesIO(response.content)) except: return None # 🎨 INTERFACE AMÉLIORÉE with gr.Blocks( title="Classificateur de Mode Expert", theme=gr.themes.Soft(primary_hue="pink") ) as demo: gr.Markdown(""" # 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS *Powered by Fine-Tuned Vision Transformer* """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📤 UPLOADER") image_input = gr.Image( type="pil", label="Image de vêtement", height=300, sources=["upload", "clipboard"], interactive=True ) with gr.Row(): classify_btn = gr.Button("🚀 Classifier", variant="primary") clear_btn = gr.Button("🧹 Effacer", variant="secondary") gr.Markdown(""" ### 💡 CONSEILS - 📷 Photo claire et nette - 🎯 Vêtement bien centré - 🌟 Fond uni de préférence - ⚡ Attendez 3-5 secondes """) with gr.Column(scale=2): gr.Markdown("### 📊 RÉSULTATS") output_text = gr.Markdown( value="⬅️ Uploader une image ou utilisez les exemples ci-dessous" ) # 🎯 EXEMPLES INTERACTIFS gr.Markdown("### 🖼️ EXEMPLES DE TEST") with gr.Row(): for i, url in enumerate(EXAMPLE_URLS): gr.Examples( examples=[[url]], inputs=image_input, outputs=output_text, fn=classify_fashion, label=f"Exemple {i+1}", cache_examples=False ) # 🎮 INTERACTIONS classify_btn.click( fn=classify_fashion, inputs=[image_input], outputs=output_text, api_name="classify" ) clear_btn.click( fn=lambda: (None, "⬅️ Prêt pour une nouvelle image"), inputs=[], outputs=[image_input, output_text] ) # 🔄 AUTO-CLASSIFICATION image_input.upload( fn=classify_fashion, inputs=[image_input], outputs=output_text ) # ⚙️ CONFIGURATION if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True )