# Hugging Face Space — status: Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import numpy as np | |
# 🔥 Fashion-specialised image-classification checkpoint (Hugging Face Hub).
MODEL_NAME = "rafalosa/diffusiondb-fashion-mnist"  # Fashion-specialised model
# Alternative: "nateraw/vit-base-patch16-224-fashion-mnist"

# Fallback preprocessing config, used only when MODEL_NAME does not ship a
# processor config of its own.
BASE_PROCESSOR_NAME = "google/vit-base-patch16-224"

# Resolve the device before loading so the name is always bound, even when
# loading fails below.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("🔄 Chargement du modèle de mode...")
try:
    # Prefer the checkpoint's own preprocessing config (resize / rescale /
    # normalisation) so inputs match what the model was fine-tuned on; fall
    # back to the standard ViT config if the checkpoint has none.
    try:
        processor = AutoImageProcessor.from_pretrained(
            MODEL_NAME,
            cache_dir="cache"
        )
    except (OSError, ValueError):
        processor = AutoImageProcessor.from_pretrained(
            BASE_PROCESSOR_NAME,
            cache_dir="cache"
        )
    # Fashion fine-tuned classifier.
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # Hub repository. Keep it only if this checkpoint really needs custom
    # modelling code, and consider pinning `revision=` to a reviewed commit.
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_NAME,
        cache_dir="cache",
        trust_remote_code=True
    )
    model.to(device)
    model.eval()  # inference mode
    print(f"✅ Modèle chargé sur {device}")
    print(f"📊 Classe disponibles: {model.config.num_labels}")
except Exception as e:
    # Degrade gracefully: the app still starts with processor/model set to
    # None, and request handlers must check for that before running inference.
    print(f"❌ Erreur chargement: {e}")
    processor = None
    model = None
# 🎯 Human-readable (French) garment labels, indexed by model class id.
# Presumably the Fashion-MNIST class order (0 = T-shirt/top ... 9 = ankle
# boot) — confirm against the checkpoint's model.config.id2label.
FASHION_LABELS = [
    "T-shirt",   # 0
    "Pantalon",  # 1 - trousers
    "Pull",      # 2 - pullover
    "Robe",      # 3 - dress
    "Manteau",   # 4 - coat
    "Sandale",   # 5 - sandal
    "Chemise",   # 6 - shirt
    "Sneaker",   # 7
    "Sac",       # 8 - bag
    "Botte",     # 9 - boot
]
def preprocess_image(image):
    """Return *image* as a 224x224 RGB PIL image.

    Images in any other mode are converted to RGB first, then resized to a
    fixed 224x224 with Lanczos resampling (the aspect ratio is not preserved).
    """
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
    return rgb_image.resize((224, 224), Image.Resampling.LANCZOS)
def _label_for(label_idx):
    """Map a model class index to a human-readable label.

    The French FASHION_LABELS are used only when the model head has exactly
    as many classes as that list (presumably a Fashion-MNIST head — confirm
    against model.config.id2label). Otherwise the model's own id2label mapping
    is used, with a generic fallback, so that an index is never folded onto an
    unrelated fashion label.
    """
    config = model.config
    if config.num_labels == len(FASHION_LABELS):
        return FASHION_LABELS[label_idx]
    id2label = getattr(config, "id2label", None) or {}
    return id2label.get(label_idx, f"Classe {label_idx}")


def classify_fashion(image):
    """Classify a garment image and return a Markdown report.

    Args:
        image: PIL image from the Gradio input, or None when nothing has been
            provided.

    Returns:
        A Markdown string with the top (up to 5) predictions and their
        probabilities, or a user-facing warning/error message. Exceptions
        raised during processing are caught, their traceback is printed to
        stderr, and a short error message is returned instead.
    """
    try:
        if image is None:
            return "❌ Veuillez uploader une image de vêtement"
        if processor is None or model is None:
            # Loading happens once, synchronously, at import time; if these
            # are still None the load failed and will not be retried, so do
            # not ask the user to wait.
            return "⚠️ Modèle indisponible (échec du chargement) — consultez les logs"

        # 🔥 Preprocessing: RGB conversion + resize, then the processor's own
        # resize/rescale/normalisation into a PyTorch tensor batch.
        processed_image = preprocess_image(image)
        inputs = processor(
            images=processed_image,
            return_tensors="pt",
            do_resize=True,
            do_rescale=True,
            do_normalize=True
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # 🔥 Inference without gradient tracking.
        with torch.no_grad():
            outputs = model(**inputs)

        # 🔥 Post-processing: class probabilities, then the best k.
        probabilities = F.softmax(outputs.logits, dim=-1)
        # torch.topk raises if k exceeds the number of classes.
        k = min(5, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        results = [
            {"label": _label_for(idx.item()), "score": prob.item() * 100}
            for prob, idx in zip(top_probs[0], top_indices[0])
        ]

        # 📊 Markdown report.
        parts = ["## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n"]
        parts.extend(
            f"{rank}. **{result['label']}** - {result['score']:.1f}%\n"
            for rank, result in enumerate(results, start=1)
        )
        parts.append("\n---\n")
        parts.append("📏 Image traitée: 224x224 pixels\n")
        parts.append(f"🔢 Modèle: {MODEL_NAME.split('/')[-1]}\n")
        parts.append("\n💡 **Pour de meilleurs résultats:**\n")
        parts.append("• Photo claire sur fond uni\n")
        parts.append("• Vêtement bien visible\n")
        parts.append("• Éviter les angles bizarres\n")
        return "".join(parts)
    except Exception as e:
        # UI boundary: never let an exception escape into Gradio, but do log
        # the traceback, since the message below tells the user to check logs.
        import traceback
        traceback.print_exc()
        return f"❌ Erreur: {str(e)}\n\n🔧 Vérifiez les logs pour plus de détails"
# 🖼️ Fashion-specific example images (Unsplash photos, `w=400` query).
# The trailing comments give the intended garment type of each photo.
EXAMPLE_URLS = [
    f"https://images.unsplash.com/photo-{photo_id}?w=400"
    for photo_id in (
        "1558769132-cb1aea458c5e",     # T-shirt
        "1594633312681-425c7b97ccd1",  # dress
        "1529111290557-82f6d5c6cf85",  # shirt
        "1543163521-1bf539c55dd2",     # jacket
    )
]
def load_example(url, timeout=10):
    """Download an image from *url* and return it as a PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: request timeout in seconds (defaults to 10, as before).

    Returns:
        The decoded PIL image, or None if the request fails, the server
        answers with an HTTP error status, or the payload is not a readable
        image. This is a best-effort helper and never raises for those cases.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Without this check an HTML error page (404/500) would be handed to
        # PIL and only fail later as an "unidentified image".
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy: force decoding now so truncated/corrupt data is
        # caught here rather than when the image is first used.
        image.load()
        return image
    except (requests.RequestException, OSError, ValueError,
            Image.DecompressionBombError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit and
        # programming errors are no longer swallowed.
        return None
# 🎨 Gradio user interface.
# Layout is determined by the order in which components are created inside
# the nested `with` contexts, so statement order in this block matters.
with gr.Blocks(
    title="Classificateur de Mode Expert",
    theme=gr.themes.Soft(primary_hue="pink")
) as demo:
    gr.Markdown("""
# 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS
*Powered by Fine-Tuned Vision Transformer*
""")
    with gr.Row():
        # Left column: image input, action buttons and usage tips.
        with gr.Column(scale=1):
            gr.Markdown("### 📤 UPLOADER")
            # type="pil" means handlers receive a PIL image (or None when empty).
            image_input = gr.Image(
                type="pil",
                label="Image de vêtement",
                height=300,
                sources=["upload", "clipboard"],
                interactive=True
            )
            with gr.Row():
                classify_btn = gr.Button("🚀 Classifier", variant="primary")
                clear_btn = gr.Button("🧹 Effacer", variant="secondary")
            gr.Markdown("""
### 💡 CONSEILS
- 📷 Photo claire et nette
- 🎯 Vêtement bien centré
- 🌟 Fond uni de préférence
- ⚡ Attendez 3-5 secondes
""")
        # Right column (twice as wide): Markdown area where results are shown.
        with gr.Column(scale=2):
            gr.Markdown("### 📊 RÉSULTATS")
            output_text = gr.Markdown(
                value="⬅️ Uploader une image ou utilisez les exemples ci-dessous"
            )
    # 🎯 Interactive examples: one gr.Examples widget per URL, labelled
    # "Exemple 1", "Exemple 2", ... and wired to classify_fashion without
    # caching.
    # NOTE(review): examples are passed as raw URL strings (presumably fetched
    # by Gradio itself); load_example is not used anywhere in this block.
    gr.Markdown("### 🖼️ EXEMPLES DE TEST")
    with gr.Row():
        for i, url in enumerate(EXAMPLE_URLS):
            gr.Examples(
                examples=[[url]],
                inputs=image_input,
                outputs=output_text,
                fn=classify_fashion,
                label=f"Exemple {i+1}",
                cache_examples=False
            )
    # 🎮 Event wiring.
    # The "Classifier" button runs classify_fashion; api_name also exposes it
    # under the API endpoint name "classify".
    classify_btn.click(
        fn=classify_fashion,
        inputs=[image_input],
        outputs=output_text,
        api_name="classify"
    )
    # The "Effacer" button clears the image and puts a "ready" message in the
    # results panel.
    clear_btn.click(
        fn=lambda: (None, "⬅️ Prêt pour une nouvelle image"),
        inputs=[],
        outputs=[image_input, output_text]
    )
    # 🔄 Auto-classification: run the same handler as soon as a file is
    # uploaded, without waiting for a button click.
    image_input.upload(
        fn=classify_fashion,
        inputs=[image_input],
        outputs=output_text
    )
# ⚙️ Launch configuration (only when run as a script): listen on all
# interfaces on port 7860, no public share link, with Gradio's debug mode and
# error display enabled.
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_kwargs)