Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import numpy as np | |
| import os | |
| from pathlib import Path | |
| import tempfile | |
# Checkpoint used for classification.
# NOTE(review): the original comment described this as a fashion-specialised
# model, but "google/vit-base-patch16-224" appears to be the general-purpose
# ImageNet ViT checkpoint — confirm this is the intended model.
MODEL_NAME = "google/vit-base-patch16-224"  # Reliable and fast model

# Resolve the device before attempting the load so that `device` is always
# defined, even when loading the processor or the model fails below.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("🔄 Chargement du modèle de mode...")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # switch the module to evaluation mode for inference
    print(f"✅ Modèle chargé sur {device}")
except Exception as e:
    # Best-effort start-up: keep the module importable and signal the failure
    # through the None sentinels, which classify_fashion checks before use.
    print(f"❌ Erreur chargement: {e}")
    processor = None
    model = None
# Human-readable fashion labels keyed by integer class index.
# NOTE(review): classify_fashion looks these up with the classifier's predicted
# class indices; confirm that indices 0-14 of the loaded checkpoint really
# correspond to these garments.
FASHION_LABELS = dict(enumerate((
    "T-shirt", "Pantalon", "Pull", "Robe", "Manteau",
    "Sandale", "Chemise", "Sneaker", "Sac", "Botte",
    "Veste", "Jupe", "Short", "Chaussures", "Accessoire",
)))
def convert_heic_to_jpeg(image_path):
    """Convert a HEIC file to JPEG on a best-effort basis.

    Args:
        image_path: Path of the image. Anything that is not a string ending in
            ``.heic`` (case-insensitive) is returned untouched.

    Returns:
        The path of the JPEG written next to the source file when the
        conversion succeeds, otherwise ``image_path`` unchanged.
    """
    heic_suffix = '.heic'
    if not (isinstance(image_path, str) and image_path.lower().endswith(heic_suffix)):
        return image_path

    # Swap only the trailing extension. str.replace('.heic', ...) is
    # case-sensitive, so "photo.HEIC" used to map onto itself and the source
    # file was overwritten; it also rewrote any earlier ".heic" in the path.
    jpeg_path = image_path[:-len(heic_suffix)] + '.jpeg'

    try:
        # NOTE(review): Pillow needs a HEIF plugin registered to decode HEIC;
        # none is registered in the visible code, so this may always fall back.
        with Image.open(image_path) as img:
            img.convert('RGB').save(jpeg_path, 'JPEG')
    except Exception as e:
        # Best effort: report the failure and let the caller try the original
        # path. Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        print(f"⚠️ Conversion HEIC impossible: {e}")
        return image_path
    return jpeg_path
def preprocess_image(image):
    """Normalise an image to a 224x224 RGB image.

    Args:
        image: Either a file path (``str``), which first goes through
            ``convert_heic_to_jpeg`` and is then opened with PIL, or an object
            exposing the PIL image interface (``mode``, ``convert``,
            ``resize``).

    Returns:
        The image converted to RGB (when needed) and resized to 224x224 with
        LANCZOS resampling.

    Raises:
        Exception: on any failure, with an "Erreur prétraitement" message and
            the original error chained as ``__cause__``.
    """
    def _normalise(img):
        # Force RGB so every caller receives the same mode, then resize.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img.resize((224, 224), Image.Resampling.LANCZOS)

    try:
        if isinstance(image, str):
            # Open through a context manager so the file handle is released
            # deterministically; resize() hands back a separate copy, so the
            # result stays usable after the source image is closed.
            with Image.open(convert_heic_to_jpeg(image)) as opened:
                return _normalise(opened)
        return _normalise(image)
    except Exception as e:
        # Keep the generic Exception type callers may already catch, but chain
        # the root cause so the real failure is visible in tracebacks.
        raise Exception(f"Erreur prétraitement: {str(e)}") from e
| def classify_fashion(image): | |
| """Classification avec gestion robuste des formats""" | |
| try: | |
| if image is None: | |
| return "❌ Veuillez uploader une image de vêtement" | |
| if processor is None or model is None: | |
| return "⚠️ Modèle en cours de chargement... Patientez 30s" | |
| # 📸 Gestion spéciale HEIC et formats complexes | |
| try: | |
| # Si l'image est un chemin temporaire (format HEIC) | |
| if isinstance(image, str) and ('gradio' in image or 'tmp' in image): | |
| if image.lower().endswith('.heic'): | |
| # Conversion HEIC → JPEG | |
| img = Image.open(image) | |
| with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp: | |
| img.convert('RGB').save(tmp.name, 'JPEG', quality=95) | |
| processed_image = Image.open(tmp.name) | |
| os.unlink(tmp.name) # Nettoyage | |
| else: | |
| processed_image = Image.open(image) | |
| else: | |
| # Image normale | |
| processed_image = image | |
| # Conversion en RGB si nécessaire | |
| if processed_image.mode != 'RGB': | |
| processed_image = processed_image.convert('RGB') | |
| except Exception as e: | |
| return f"❌ Format d'image non supporté: {str(e)}\n\n💡 Utilisez JPEG, PNG ou WebP" | |
| # 🔥 PRÉTRAITEMENT CORRECT | |
| processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS) | |
| # Transformation pour le modèle | |
| inputs = processor(images=processed_image, return_tensors="pt") | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| # 🔥 INFÉRENCE | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| # 📊 POST-TRAITEMENT | |
| probabilities = F.softmax(outputs.logits, dim=-1) | |
| top_probs, top_indices = torch.topk(probabilities, 5) | |
| # Conversion en résultats | |
| results = [] | |
| for i in range(len(top_indices[0])): | |
| label_idx = top_indices[0][i].item() | |
| label_name = FASHION_LABELS.get(label_idx, f"Catégorie {label_idx}") |