MODLI's picture
Update app.py
140fdb2 verified
raw
history blame
6.91 kB
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import numpy as np
# 🔥 MODÈLE SPÉCIALISÉ DANS LA MODE
MODEL_NAME = "rafalosa/diffusiondb-fashion-mnist" # Modèle spécialisé mode
# Alternative: "nateraw/vit-base-patch16-224-fashion-mnist"
print("🔄 Chargement du modèle de mode...")
try:
# Chargeur d'images avec prétraitement correct
processor = AutoImageProcessor.from_pretrained(
"google/vit-base-patch16-224", # Base standard
cache_dir="cache"
)
# Modèle fine-tuné sur la mode
model = AutoModelForImageClassification.from_pretrained(
MODEL_NAME,
cache_dir="cache",
trust_remote_code=True
)
# Configuration device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
print(f"✅ Modèle chargé sur {device}")
print(f"📊 Classe disponibles: {model.config.num_labels}")
except Exception as e:
print(f"❌ Erreur chargement: {e}")
processor = None
model = None
# 🎯 LABELS COMPRÉHENSIBLES POUR LA MODE
FASHION_LABELS = [
"T-shirt", "Pantalon", "Pull", "Robe", "Manteau",
"Sandale", "Chemise", "Sneaker", "Sac", "Botte"
]
def preprocess_image(image):
"""Prétraitement correct des images"""
# Conversion en RGB
if image.mode != 'RGB':
image = image.convert('RGB')
# Redimensionnement intelligent
image = image.resize((224, 224), Image.Resampling.LANCZOS)
return image
def classify_fashion(image):
"""Classification spécialisée mode"""
try:
if image is None:
return "❌ Veuillez uploader une image de vêtement"
if processor is None or model is None:
return "⚠️ Modèle en cours de chargement... Patientez 30s"
# 🔥 PRÉTRAITEMENT CORRECT
processed_image = preprocess_image(image)
# Transformation pour le modèle
inputs = processor(
images=processed_image,
return_tensors="pt",
do_resize=True,
do_rescale=True,
do_normalize=True
)
# Transfert sur le bon device
inputs = {k: v.to(device) for k, v in inputs.items()}
# 🔥 INFÉRENCE AVEC GRADIENTS DÉSACTIVÉS
with torch.no_grad():
outputs = model(**inputs)
# 🔥 POST-TRAITEMENT CORRECT
probabilities = F.softmax(outputs.logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 5)
# Conversion en résultats
results = []
for i in range(len(top_indices[0])):
# Utilisation de nos labels personnalisés
label_idx = top_indices[0][i].item()
label_name = FASHION_LABELS[label_idx % len(FASHION_LABELS)]
score = top_probs[0][i].item() * 100
results.append({"label": label_name, "score": score})
# 📊 FORMATAGE DES RÉSULTATS
output = "## 🎯 RÉSULTATS DE CLASSIFICATION:\n\n"
for i, result in enumerate(results):
output += f"{i+1}. **{result['label']}** - {result['score']:.1f}%\n"
# 📸 Aperçu de l'image traitée
output += f"\n---\n"
output += f"📏 Image traitée: 224x224 pixels\n"
output += f"🔢 Modèle: {MODEL_NAME.split('/')[-1]}\n"
output += "\n💡 **Pour de meilleurs résultats:**\n"
output += "• Photo claire sur fond uni\n"
output += "• Vêtement bien visible\n"
output += "• Éviter les angles bizarres\n"
return output
except Exception as e:
return f"❌ Erreur: {str(e)}\n\n🔧 Vérifiez les logs pour plus de détails"
# 🖼️ EXEMPLES SPÉCIFIQUES MODE
EXAMPLE_URLS = [
"https://images.unsplash.com/photo-1558769132-cb1aea458c5e?w=400", # T-shirt
"https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=400", # Robe
"https://images.unsplash.com/photo-1529111290557-82f6d5c6cf85?w=400", # Chemise
"https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400", # Veste
]
def load_example(url):
"""Charge un exemple depuis une URL"""
try:
response = requests.get(url, timeout=10)
return Image.open(BytesIO(response.content))
except:
return None
# 🎨 INTERFACE AMÉLIORÉE
with gr.Blocks(
title="Classificateur de Mode Expert",
theme=gr.themes.Soft(primary_hue="pink")
) as demo:
gr.Markdown("""
# 👗 CLASSIFICATEUR EXPERT DE VÊTEMENTS
*Powered by Fine-Tuned Vision Transformer*
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📤 UPLOADER")
image_input = gr.Image(
type="pil",
label="Image de vêtement",
height=300,
sources=["upload", "clipboard"],
interactive=True
)
with gr.Row():
classify_btn = gr.Button("🚀 Classifier", variant="primary")
clear_btn = gr.Button("🧹 Effacer", variant="secondary")
gr.Markdown("""
### 💡 CONSEILS
- 📷 Photo claire et nette
- 🎯 Vêtement bien centré
- 🌟 Fond uni de préférence
- ⚡ Attendez 3-5 secondes
""")
with gr.Column(scale=2):
gr.Markdown("### 📊 RÉSULTATS")
output_text = gr.Markdown(
value="⬅️ Uploader une image ou utilisez les exemples ci-dessous"
)
# 🎯 EXEMPLES INTERACTIFS
gr.Markdown("### 🖼️ EXEMPLES DE TEST")
with gr.Row():
for i, url in enumerate(EXAMPLE_URLS):
gr.Examples(
examples=[[url]],
inputs=image_input,
outputs=output_text,
fn=classify_fashion,
label=f"Exemple {i+1}",
cache_examples=False
)
# 🎮 INTERACTIONS
classify_btn.click(
fn=classify_fashion,
inputs=[image_input],
outputs=output_text,
api_name="classify"
)
clear_btn.click(
fn=lambda: (None, "⬅️ Prêt pour une nouvelle image"),
inputs=[],
outputs=[image_input, output_text]
)
# 🔄 AUTO-CLASSIFICATION
image_input.upload(
fn=classify_fashion,
inputs=[image_input],
outputs=output_text
)
# ⚙️ CONFIGURATION
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=True,
show_error=True
)