MODLI's picture
Update app.py
fbaf2a1 verified
raw
history blame
8.84 kB
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import os
import tempfile
# 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE
MODEL_NAME = "google/vit-base-patch16-224"
print("🔄 Chargement du modèle de mode...")
try:
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
print(f"✅ Modèle chargé sur {device}")
except Exception as e:
print(f"❌ Erreur chargement: {e}")
processor = None
model = None
# 🎯 MAPPING COMPLET DES CATÉGORIES EN FRANÇAIS
FASHION_LABELS = {
# Vêtements supérieurs
0: "T-shirt", 1: "Pull", 2: "Chemise", 3: "Sweat à capuche", 4: "Veste",
5: "Manteau", 6: "Blouse", 7: "Haut", 8: "Top", 9: "Débardeur",
# Vêtements inférieurs
10: "Pantalon", 11: "Jean", 12: "Short", 13: "Jupe", 14: "Legging",
15: "Pantalon de sport", 16: "Pantalon cargo", 17: "Pantalon chino",
# Robes et ensembles
18: "Robe", 19: "Robe de soirée", 20: "Robe d'été", 21: "Robe cocktail",
22: "Combinaison", 23: "Ensemble", 24: "Tenue",
# Sous-vêtements
25: "Soutien-gorge", 26: "Culotte", 27: "Maillot de bain",
28: "Pyjama", 29: "Nuisette",
# Chaussures
30: "Basket", 31: "Sandale", 32: "Botte", 33: "Talons",
34: "Escarpin", 35: "Chaussure de sport", 36: "Mocassin",
37: "Derby", 38: "Chausson",
# Accessoires
39: "Sac à main", 40: "Sac à dos", 41: "Chapeau", 42: "Casquette",
43: "Écharpe", 44: "Gants", 45: "Ceinture", 46: "Lunettes de soleil",
47: "Bijou", 48: "Montre", 49: "Cravate",
# Sports
50: "Tenue de sport", 51: "Maillot de football", 52: "Short de sport",
53: "Survêtement", 54: "Veste de sport",
# Enfants
55: "Vêtement bébé", 56: "Vêtement enfant",
# Divers
57: "Uniforme", 58: "Costume", 59: "Smoking",
60: "Robe de mariée", 61: "Accessoire mode",
# Matières et textures (si le modèle les détecte)
100: "Coton", 101: "Denim", 102: "Laine", 103: "Soie", 104: "Cuir",
105: "Synthétique", 106: "Jean", 107: "Velours", 108: "Laine polaire",
# Couleurs dominantes (approximatives)
200: "Vêtement noir", 201: "Vêtement blanc", 202: "Vêtement bleu",
203: "Vêtement rouge", 204: "Vêtement vert", 205: "Vêtement jaune",
206: "Vêtement rose", 207: "Vêtement violet", 208: "Vêtement orange",
209: "Vêtement marron", 210: "Vêtement gris", 211: "Vêtement multicolore",
}
# 🎨 CATÉGORIES GÉNÉRIQUES POUR LES NUMÉROS INCONNUS
GENERIC_CATEGORIES = {
range(600, 700): "Vêtement casual",
range(700, 800): "Vêtement formel",
range(800, 900): "Vêtement décontracté",
range(900, 1000): "Article mode",
}
def get_human_readable_label(label_idx):
"""Convertit un numéro de catégorie en nom français"""
# D'abord chercher dans le mapping précis
if label_idx in FASHION_LABELS:
return FASHION_LABELS[label_idx]
# Ensuite chercher dans les catégories génériques
for range_obj, category_name in GENERIC_CATEGORIES.items():
if label_idx in range_obj:
return category_name
# En dernier recours, catégorie générale
if label_idx < 100:
return "Vêtement supérieur"
elif label_idx < 200:
return "Vêtement inférieur"
elif label_idx < 300:
return "Accessoire mode"
elif label_idx < 400:
return "Chaussure"
elif label_idx < 500:
return "Vêtement sport"
else:
return "Article vestimentaire"
def classify_fashion(image):
"""Classification avec noms en français"""
try:
if image is None:
return "❌ Veuillez uploader une image de vêtement"
if processor is None or model is None:
return "⚠️ Modèle en cours de chargement... Patientez 30s"
# 📸 Gestion de l'image
try:
if isinstance(image, str):
processed_image = Image.open(image)
else:
processed_image = image
if processed_image.mode != 'RGB':
processed_image = processed_image.convert('RGB')
except Exception as e:
return f"❌ Format d'image non supporté: {str(e)}"
# 🔥 PRÉTRAITEMENT
processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS)
# Transformation pour le modèle
inputs = processor(images=processed_image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# 🔥 INFÉRENCE
with torch.no_grad():
outputs = model(**inputs)
# 📊 POST-TRAITEMENT
probabilities = F.softmax(outputs.logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 5)
# Conversion en résultats français
results = []
for i in range(len(top_indices[0])):
label_idx = top_indices[0][i].item()
label_name = get_human_readable_label(label_idx)
score = top_probs[0][i].item() * 100
if score > 1.0: # Seuil de 1% pour éviter le bruit
results.append({"label": label_name, "score": score})
# 📋 AFFICHAGE DES RÉSULTATS
if not results:
return "❌ Aucune catégorie vestimentaire détectée avec confiance suffisante"
output = "## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n"
for i, result in enumerate(results):
output += f"{i+1}. **{result['label']}** - {result['score']:.1f}%\n"
# 📊 STATISTIQUES
total_confidence = sum(result['score'] for result in results)
output += f"\n---\n"
output += f"📈 **Confiance totale:** {total_confidence:.1f}%\n"
# 💡 CONSEILS
output += "\n💡 **Pour améliorer les résultats:**\n"
output += "• Prenez la photo sur fond uni\n"
output += "• Assurez-vous d'un bon éclairage\n"
output += "• Cadrez uniquement le vêtement\n"
output += "• Évitez les angles complexes\n"
return output
except Exception as e:
return f"❌ Erreur de traitement: {str(e)}"
# 🖼️ EXEMPLES DE TEST
EXAMPLE_URLS = [
"https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400", # T-shirt
"https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400", # Robe
"https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400", # Chemise
"https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400", # Veste
"https://images.unsplash.com/photo-1582142306909-195724d3a58c?w=400", # Jean
]
# 🎨 INTERFACE AMÉLIORÉE
with gr.Blocks(title="Classificateur de Vêtements Expert", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS
*Reconnaissance intelligente avec labels en français*
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📤 UPLOADER UNE IMAGE")
image_input = gr.Image(
type="filepath",
label="Sélectionnez votre vêtement",
height=300,
sources=["upload"],
)
gr.Markdown("""
### 📋 CONSEILS
✅ JPEG/PNG recommandés
❌ Évitez HEIC (Apple)
📷 Photo nette et bien éclairée
🎯 Cadrage simple du vêtement
""")
classify_btn = gr.Button("🚀 Analyser le vêtement", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### 📊 RÉSULTATS DÉTAILLÉS")
output_text = gr.Markdown(
value="⬅️ Uploader une image pour commencer l'analyse"
)
# 🎯 EXEMPLES
gr.Markdown("### 🖼️ GARDIEN-ROBE DE TEST")
gr.Examples(
examples=EXAMPLE_URLS,
inputs=image_input,
outputs=output_text,
fn=classify_fashion,
label="Cliquez sur un vêtement pour tester"
)
# 🎮 INTERACTION
classify_btn.click(
fn=classify_fashion,
inputs=[image_input],
outputs=output_text
)
# ⚙️ CONFIGURATION
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=True
)