# Spaces status: Runtime error
# File size: 4,465 Bytes
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import os
from pathlib import Path
import tempfile
# Checkpoint used for the fashion classifier.
MODEL_NAME = "google/vit-base-patch16-224"  # reliable and fast model

print("🔄 Chargement du modèle de mode...")
try:
    # Load the preprocessing pipeline and the classification head together;
    # a failure in either leaves both globals set to None below.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    # Prefer the GPU when one is visible to torch.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model.to(device)
    model.eval()  # inference only
    print(f"✅ Modèle chargé sur {device}")
except Exception as load_error:
    # Keep the module importable: callers check for None before using the
    # model. Note that `device` is only bound when loading succeeded.
    print(f"❌ Erreur chargement: {load_error}")
    processor = model = None
# Human-readable fashion labels keyed by predicted class index
# (intended to be adapted to the model).
# NOTE(review): MODEL_NAME looks like a general-purpose image checkpoint;
# confirm that its class indices 0-14 really correspond to these garments.
FASHION_LABELS = dict(enumerate((
    "T-shirt",
    "Pantalon",
    "Pull",
    "Robe",
    "Manteau",
    "Sandale",
    "Chemise",
    "Sneaker",
    "Sac",
    "Botte",
    "Veste",
    "Jupe",
    "Short",
    "Chaussures",
    "Accessoire",
)))
def convert_heic_to_jpeg(image_path):
    """Convert a HEIC file to JPEG on a best-effort basis.

    Args:
        image_path: Path of the image. Any value that is not a string ending
            in ``.heic`` (case-insensitive) is returned untouched.

    Returns:
        The path of the JPEG written next to the source file when the
        conversion succeeds, otherwise ``image_path`` unchanged.
    """
    heic_suffix = '.heic'
    if not (isinstance(image_path, str)
            and image_path.lower().endswith(heic_suffix)):
        return image_path

    # Swap only the trailing extension. A plain str.replace() is
    # case-sensitive and global: "photo.HEIC" would map to itself (and the
    # source file would be overwritten), and a ".heic" inside a directory
    # name would be rewritten as well.
    jpeg_path = image_path[:-len(heic_suffix)] + '.jpeg'

    try:
        # NOTE(review): stock Pillow may not include a HEIC decoder; confirm
        # that a plugin providing one is registered in this environment.
        # The context manager releases the file handle even if saving fails.
        with Image.open(image_path) as img:
            img.convert('RGB').save(jpeg_path, 'JPEG')
    except Exception as exc:
        # Deliberately best-effort: report the failure and hand back the
        # original path. Unlike a bare `except`, this lets KeyboardInterrupt
        # and SystemExit propagate.
        print(f"⚠️ Conversion HEIC impossible: {exc}")
        return image_path
    return jpeg_path
def preprocess_image(image):
    """Normalize an input to a 224x224 RGB image.

    Args:
        image: Either a file path (``str``) or an already opened image
            object. Paths go through ``convert_heic_to_jpeg`` before being
            opened with ``Image.open``.

    Returns:
        The image, converted to RGB when its mode differs, resized to
        224x224 with LANCZOS resampling.

    Raises:
        Exception: Any failure is re-raised as a generic ``Exception`` whose
            message starts with "Erreur prétraitement: ". The original error
            is attached as ``__cause__``.
    """
    try:
        # A string is treated as a file path (possibly HEIC).
        if isinstance(image, str):
            image = Image.open(convert_heic_to_jpeg(image))

        if image.mode != 'RGB':
            image = image.convert('RGB')

        return image.resize((224, 224), Image.Resampling.LANCZOS)
    except Exception as exc:
        # Same exception type and message as before, so existing callers keep
        # working; `from exc` preserves the root cause for debugging.
        raise Exception(f"Erreur prétraitement: {str(exc)}") from exc
def classify_fashion(image):
"""Classification avec gestion robuste des formats"""
try:
if image is None:
return "❌ Veuillez uploader une image de vêtement"
if processor is None or model is None:
return "⚠️ Modèle en cours de chargement... Patientez 30s"
# 📸 Gestion spéciale HEIC et formats complexes
try:
# Si l'image est un chemin temporaire (format HEIC)
if isinstance(image, str) and ('gradio' in image or 'tmp' in image):
if image.lower().endswith('.heic'):
# Conversion HEIC → JPEG
img = Image.open(image)
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
img.convert('RGB').save(tmp.name, 'JPEG', quality=95)
processed_image = Image.open(tmp.name)
os.unlink(tmp.name) # Nettoyage
else:
processed_image = Image.open(image)
else:
# Image normale
processed_image = image
# Conversion en RGB si nécessaire
if processed_image.mode != 'RGB':
processed_image = processed_image.convert('RGB')
except Exception as e:
return f"❌ Format d'image non supporté: {str(e)}\n\n💡 Utilisez JPEG, PNG ou WebP"
# 🔥 PRÉTRAITEMENT CORRECT
processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS)
# Transformation pour le modèle
inputs = processor(images=processed_image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# 🔥 INFÉRENCE
with torch.no_grad():
outputs = model(**inputs)
# 📊 POST-TRAITEMENT
probabilities = F.softmax(outputs.logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 5)
# Conversion en résultats
results = []
for i in range(len(top_indices[0])):
label_idx = top_indices[0][i].item()
label_name = FASHION_LABELS.get(label_idx, f"Catégorie {label_idx}") |