MODLI's picture
Update app.py
e26d51d verified
raw
history blame
4.47 kB
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import os
from pathlib import Path
import tempfile
# 🔥 MODEL
# NOTE: this is a general-purpose ImageNet-1k ViT classifier, not a
# fashion-specific model; its outputs are ImageNet class indices.
MODEL_NAME = "google/vit-base-patch16-224"  # reliable and fast model

print("🔄 Chargement du modèle de mode...")

# Pick the device before loading so `device` is always bound, even when
# loading fails below (it used to be assigned only after both
# from_pretrained calls succeeded).
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # inference mode
    print(f"✅ Modèle chargé sur {device}")
except Exception as e:
    # Best-effort startup: keep the app running and signal the failure by
    # leaving `processor` / `model` as None (checked before inference).
    print(f"❌ Erreur chargement: {e}")
    processor = None
    model = None
# 🎯 LABELS (model output index -> human-readable name)
# The checkpoint loaded above (google/vit-base-patch16-224) is an ImageNet-1k
# classifier, so output index i means model.config.id2label[i]. For example,
# index 0 is "tench", a fish, not a clothing category.
# The former hand-written table {0: "T-shirt", 1: "Pantalon", ...} put a wrong
# fashion name on every prediction in 0-14. It also showed "Catégorie N" for
# the real clothing classes, which have higher indices.
# The labels are therefore read from the model config. If loading failed
# (`model is None`), the table is empty; classify_fashion returns before any
# lookup in that case.
_model_config = getattr(model, "config", None)
FASHION_LABELS = {
    int(index): name
    for index, name in (getattr(_model_config, "id2label", None) or {}).items()
}
def convert_heic_to_jpeg(image_path):
    """Convert a ``.heic`` file to JPEG if needed and return the path to use.

    The JPEG is written next to the source file, with the same name and a
    ``.jpeg`` extension. Non-string inputs, and paths whose extension is not
    ``.heic`` (case-insensitive), are returned unchanged.

    Conversion is best-effort: if opening or saving fails, the original
    ``image_path`` is returned so the caller can try to open it directly.

    Args:
        image_path: Path of the uploaded image (str), or any other object.

    Returns:
        The path of the written JPEG (str) on success, otherwise ``image_path``.
    """
    if not (isinstance(image_path, str) and image_path.lower().endswith('.heic')):
        return image_path

    # with_suffix replaces only the final extension, whatever its case.
    # The old str.replace('.heic', '.jpeg') missed '.HEIC', which made the
    # target equal the source and overwrote the uploaded file. It also
    # rewrote any '.heic' found earlier in the path (e.g. in a directory name).
    jpeg_path = str(Path(image_path).with_suffix('.jpeg'))
    try:
        # NOTE(review): Pillow needs a HEIF plugin (e.g. pillow-heif) registered
        # to decode HEIC. None is registered in this file; confirm elsewhere.
        with Image.open(image_path) as img:
            img.convert('RGB').save(jpeg_path, 'JPEG')
    except Exception:
        # Deliberate best-effort fallback. This was a bare `except:`, which
        # also swallowed KeyboardInterrupt / SystemExit.
        return image_path
    return jpeg_path
def preprocess_image(image, size=(224, 224)):
    """Open an image if needed, convert it to RGB, and resize it for the model.

    Args:
        image: A PIL image, or a filesystem path (str). Paths are passed
            through ``convert_heic_to_jpeg`` first, so ``.heic`` files are
            converted when possible.
        size: Target ``(width, height)``. Defaults to the 224x224 used before.

    Returns:
        A new RGB PIL image of the requested size.

    Raises:
        Exception: ``"Erreur prétraitement: ..."`` wrapping any failure. The
            underlying error is chained as ``__cause__``.
    """
    try:
        if isinstance(image, str):
            # Close the file opened from the path. resize() returns a new
            # in-memory image, so the result stays valid after closing.
            with Image.open(convert_heic_to_jpeg(image)) as source:
                return _rgb_resized(source, size)
        return _rgb_resized(image, size)
    except Exception as e:
        raise Exception(f"Erreur prétraitement: {str(e)}") from e


def _rgb_resized(image, size):
    """Return *image* converted to RGB (if needed) and resized to *size*."""
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image.resize(size, Image.Resampling.LANCZOS)
def classify_fashion(image):
"""Classification avec gestion robuste des formats"""
try:
if image is None:
return "❌ Veuillez uploader une image de vêtement"
if processor is None or model is None:
return "⚠️ Modèle en cours de chargement... Patientez 30s"
# 📸 Gestion spéciale HEIC et formats complexes
try:
# Si l'image est un chemin temporaire (format HEIC)
if isinstance(image, str) and ('gradio' in image or 'tmp' in image):
if image.lower().endswith('.heic'):
# Conversion HEIC → JPEG
img = Image.open(image)
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
img.convert('RGB').save(tmp.name, 'JPEG', quality=95)
processed_image = Image.open(tmp.name)
os.unlink(tmp.name) # Nettoyage
else:
processed_image = Image.open(image)
else:
# Image normale
processed_image = image
# Conversion en RGB si nécessaire
if processed_image.mode != 'RGB':
processed_image = processed_image.convert('RGB')
except Exception as e:
return f"❌ Format d'image non supporté: {str(e)}\n\n💡 Utilisez JPEG, PNG ou WebP"
# 🔥 PRÉTRAITEMENT CORRECT
processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS)
# Transformation pour le modèle
inputs = processor(images=processed_image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# 🔥 INFÉRENCE
with torch.no_grad():
outputs = model(**inputs)
# 📊 POST-TRAITEMENT
probabilities = F.softmax(outputs.logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 5)
# Conversion en résultats
results = []
for i in range(len(top_indices[0])):
label_idx = top_indices[0][i].item()
label_name = FASHION_LABELS.get(label_idx, f"Catégorie {label_idx}")