import gradio as gr import torch from transformers import AutoImageProcessor, AutoModelForImageClassification from PIL import Image import requests from io import BytesIO # Chargement du modèle spécialisé dans la mode MODEL_NAME = "google/vit-base-patch16-224" # Modèle de base fiable # Alternative: "nateraw/fashion-clip" si disponible # Initialisation du modèle device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🖥️ Utilisation du device: {device}") try: # Chargeur d'images processor = AutoImageProcessor.from_pretrained(MODEL_NAME) # Modèle de classification model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) model.to(device) model.eval() print("✅ Modèle chargé avec succès!") except Exception as e: print(f"❌ Erreur chargement modèle: {e}") processor = None model = None def classify_clothing(image): """Classifie une image de vêtement""" try: if image is None: return "❌ Veuillez uploader une image de vêtement" if processor is None or model is None: return "⚠️ Modèle en cours de chargement... Réessayez dans 30 secondes" # Prétraitement de l'image inputs = processor(images=image, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} # Classification with torch.no_grad(): outputs = model(**inputs) # Récupération des résultats probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) top_probs, top_indices = torch.topk(probabilities, 5) # Conversion en résultats lisibles results = [] for i in range(len(top_indices[0])): label = model.config.id2label[top_indices[0][i].item()] score = top_probs[0][i].item() * 100 results.append({"label": label, "score": score}) # Formatage des résultats output = "## 🎯 Résultats de Classification:\n\n" for i, result in enumerate(results): # Nettoyage des labels clean_label = result['label'].replace('_', ' ').title() output += f"{i+1}. **{clean_label}** - {result['score']:.1f}%\n" output += "\n---\n" output += "💡 **Conseils pour de meilleurs résultats:**\n" output += "• Utilisez des images claires sur fond uni\n" output += "• Cadrez bien le vêtement\n" output += "• Évitez les images avec plusieurs personnes\n" return output except Exception as e: return f"❌ Erreur lors de la classification: {str(e)}" def load_example_image(url): """Charge une image d'exemple depuis une URL""" try: response = requests.get(url, timeout=10) image = Image.open(BytesIO(response.content)) return image except: return None # Exemples d'images de test example_images = [ ["https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400"], # T-shirt ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400"], # Robe ["https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400"], # Chemise ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"], # Veste ] # Interface Gradio with gr.Blocks(title="Classificateur de Vêtements", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 👗 Classificateur de Vêtements Intelligent **Uploader une image de vêtement** pour obtenir sa classification automatique """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📤 Uploader votre image") image_input = gr.Image( type="pil", label="Image de vêtement", height=300, sources=["upload", "webcam", "clipboard"] ) gr.Markdown("### 🎯 Actions") classify_btn = gr.Button("🚀 Classifier", variant="primary") clear_btn = gr.Button("🧹 Effacer", variant="secondary") gr.Markdown("### 💡 Conseils") gr.Markdown(""" - Images claires et bien éclairées - Vêtement visible et bien cadré - Fond simple de préférence """) with gr.Column(scale=2): gr.Markdown("### 📊 Résultats") output_text = gr.Markdown( value="⬅️ Uploader une image ou choisissez un exemple ci-dessous" ) # Section exemples gr.Markdown("### 🖼️ Exemples à tester") gr.Examples( examples=example_images, inputs=image_input, outputs=output_text, fn=classify_clothing, label="Cliquez sur une image pour tester", cache_examples=True ) # Événements classify_btn.click( fn=classify_clothing, inputs=[image_input], outputs=output_text ) clear_btn.click( fn=lambda: (None, "⬅️ Uploader une nouvelle image"), inputs=[], outputs=[image_input, output_text] ) # Classification automatique au changement image_input.change( fn=classify_clothing, inputs=[image_input], outputs=output_text ) # Configuration if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=False, debug=True )