Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| # Chargement du modèle spécialisé dans la mode | |
| MODEL_NAME = "google/vit-base-patch16-224" # Modèle de base fiable | |
| # Alternative: "nateraw/fashion-clip" si disponible | |
| # Initialisation du modèle | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"🖥️ Utilisation du device: {device}") | |
| try: | |
| # Chargeur d'images | |
| processor = AutoImageProcessor.from_pretrained(MODEL_NAME) | |
| # Modèle de classification | |
| model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) | |
| model.to(device) | |
| model.eval() | |
| print("✅ Modèle chargé avec succès!") | |
| except Exception as e: | |
| print(f"❌ Erreur chargement modèle: {e}") | |
| processor = None | |
| model = None | |
| def classify_clothing(image): | |
| """Classifie une image de vêtement""" | |
| try: | |
| if image is None: | |
| return "❌ Veuillez uploader une image de vêtement" | |
| if processor is None or model is None: | |
| return "⚠️ Modèle en cours de chargement... Réessayez dans 30 secondes" | |
| # Prétraitement de l'image | |
| inputs = processor(images=image, return_tensors="pt") | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| # Classification | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| # Récupération des résultats | |
| probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| top_probs, top_indices = torch.topk(probabilities, 5) | |
| # Conversion en résultats lisibles | |
| results = [] | |
| for i in range(len(top_indices[0])): | |
| label = model.config.id2label[top_indices[0][i].item()] | |
| score = top_probs[0][i].item() * 100 | |
| results.append({"label": label, "score": score}) | |
| # Formatage des résultats | |
| output = "## 🎯 Résultats de Classification:\n\n" | |
| for i, result in enumerate(results): | |
| # Nettoyage des labels | |
| clean_label = result['label'].replace('_', ' ').title() | |
| output += f"{i+1}. **{clean_label}** - {result['score']:.1f}%\n" | |
| output += "\n---\n" | |
| output += "💡 **Conseils pour de meilleurs résultats:**\n" | |
| output += "• Utilisez des images claires sur fond uni\n" | |
| output += "• Cadrez bien le vêtement\n" | |
| output += "• Évitez les images avec plusieurs personnes\n" | |
| return output | |
| except Exception as e: | |
| return f"❌ Erreur lors de la classification: {str(e)}" | |
| def load_example_image(url): | |
| """Charge une image d'exemple depuis une URL""" | |
| try: | |
| response = requests.get(url, timeout=10) | |
| image = Image.open(BytesIO(response.content)) | |
| return image | |
| except: | |
| return None | |
| # Exemples d'images de test | |
| example_images = [ | |
| ["https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400"], # T-shirt | |
| ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400"], # Robe | |
| ["https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400"], # Chemise | |
| ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"], # Veste | |
| ] | |
| # Interface Gradio | |
| with gr.Blocks(title="Classificateur de Vêtements", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(""" | |
| # 👗 Classificateur de Vêtements Intelligent | |
| **Uploader une image de vêtement** pour obtenir sa classification automatique | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 📤 Uploader votre image") | |
| image_input = gr.Image( | |
| type="pil", | |
| label="Image de vêtement", | |
| height=300, | |
| sources=["upload", "webcam", "clipboard"] | |
| ) | |
| gr.Markdown("### 🎯 Actions") | |
| classify_btn = gr.Button("🚀 Classifier", variant="primary") | |
| clear_btn = gr.Button("🧹 Effacer", variant="secondary") | |
| gr.Markdown("### 💡 Conseils") | |
| gr.Markdown(""" | |
| - Images claires et bien éclairées | |
| - Vêtement visible et bien cadré | |
| - Fond simple de préférence | |
| """) | |
| with gr.Column(scale=2): | |
| gr.Markdown("### 📊 Résultats") | |
| output_text = gr.Markdown( | |
| value="⬅️ Uploader une image ou choisissez un exemple ci-dessous" | |
| ) | |
| # Section exemples | |
| gr.Markdown("### 🖼️ Exemples à tester") | |
| gr.Examples( | |
| examples=example_images, | |
| inputs=image_input, | |
| outputs=output_text, | |
| fn=classify_clothing, | |
| label="Cliquez sur une image pour tester", | |
| cache_examples=True | |
| ) | |
| # Événements | |
| classify_btn.click( | |
| fn=classify_clothing, | |
| inputs=[image_input], | |
| outputs=output_text | |
| ) | |
| clear_btn.click( | |
| fn=lambda: (None, "⬅️ Uploader une nouvelle image"), | |
| inputs=[], | |
| outputs=[image_input, output_text] | |
| ) | |
| # Classification automatique au changement | |
| image_input.change( | |
| fn=classify_clothing, | |
| inputs=[image_input], | |
| outputs=output_text | |
| ) | |
| # Configuration | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| debug=True | |
| ) |