File size: 4,465 Bytes
fd50bed
 
140fdb2
fd50bed
 
 
 
140fdb2
e26d51d
 
 
fd50bed
140fdb2
e26d51d
fd50bed
140fdb2
fd50bed
 
e26d51d
 
140fdb2
 
fd50bed
 
140fdb2
 
 
fd50bed
140fdb2
fd50bed
 
 
e26d51d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140fdb2
 
e26d51d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140fdb2
 
e26d51d
fd50bed
 
 
 
 
140fdb2
 
e26d51d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140fdb2
e26d51d
140fdb2
 
e26d51d
fd50bed
 
e26d51d
fd50bed
 
 
e26d51d
140fdb2
fd50bed
 
140fdb2
fd50bed
 
140fdb2
e26d51d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import os
from pathlib import Path
import tempfile

# Checkpoint used for classification (described by the original author as
# reliable and fast).
# NOTE(review): the name is a generic ViT checkpoint rather than an obviously
# fashion-specific one — confirm it is the intended model.
MODEL_NAME = "google/vit-base-patch16-224"

print("🔄 Chargement du modèle de mode...")

# Resolve the device before attempting the load so that `device` is always
# bound, even when loading fails and the except branch runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    model.to(device)
    model.eval()

    print(f"✅ Modèle chargé sur {device}")

except Exception as e:
    # Best-effort start-up: report the failure and expose it through None
    # values so later code can check before running inference.
    print(f"❌ Erreur chargement: {e}")
    processor = None
    model = None

# Human-readable (French) clothing category names, keyed by class index 0-14.
# NOTE(review): nothing in this file shows that these indices line up with the
# output classes of the checkpoint named in MODEL_NAME — verify against the
# model's own id-to-label mapping.
FASHION_LABELS = dict(
    enumerate(
        [
            "T-shirt",
            "Pantalon",
            "Pull",
            "Robe",
            "Manteau",
            "Sandale",
            "Chemise",
            "Sneaker",
            "Sac",
            "Botte",
            "Veste",
            "Jupe",
            "Short",
            "Chaussures",
            "Accessoire",
        ]
    )
)

def convert_heic_to_jpeg(image_path):
    """Convert a HEIC file to a JPEG written next to it, when applicable.

    Args:
        image_path: Candidate image location. Only ``str`` values ending in
            ``.heic`` (any letter case) are converted; every other value is
            returned untouched.

    Returns:
        The path of the newly written ``.jpeg`` file on success, otherwise
        ``image_path`` unchanged (not a HEIC path, or the conversion failed).
    """
    heic_suffix = ".heic"
    if not (isinstance(image_path, str) and image_path.lower().endswith(heic_suffix)):
        return image_path

    # Swap only the trailing extension. A case-sensitive str.replace would
    # leave "photo.HEIC" unchanged (so the source file would be overwritten
    # with JPEG data) and would also rewrite ".heic" inside directory names.
    jpeg_path = image_path[: -len(heic_suffix)] + ".jpeg"

    try:
        # The context manager releases the underlying file once saved.
        with Image.open(image_path) as img:
            img.convert("RGB").save(jpeg_path, "JPEG")
    except Exception:
        # Deliberate best-effort: on any conversion problem hand the original
        # path back and let the caller try to open it as-is.
        # NOTE(review): opening HEIC presumably requires a HEIF plugin to be
        # registered with Pillow — confirm one is installed.
        return image_path
    return jpeg_path

def preprocess_image(image):
    """Normalise an image to a 224x224 RGB picture.

    Args:
        image: Either a file path (``str``), which is passed through
            ``convert_heic_to_jpeg`` and then opened, or an already loaded
            image object exposing ``mode``, ``convert`` and ``resize``.

    Returns:
        The image converted to RGB (when its mode differs) and resized to
        224x224 with Lanczos resampling.

    Raises:
        Exception: On any failure. The message is prefixed with
            ``"Erreur prétraitement: "`` and the underlying error is attached
            as ``__cause__``.
    """
    try:
        # A string is treated as a file path (possibly HEIC) and opened.
        if isinstance(image, str):
            image = Image.open(convert_heic_to_jpeg(image))

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image.resize((224, 224), Image.Resampling.LANCZOS)

    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise Exception(f"Erreur prétraitement: {e}") from e

def classify_fashion(image):
    """Classification avec gestion robuste des formats"""
    try:
        if image is None:
            return "❌ Veuillez uploader une image de vêtement"
        
        if processor is None or model is None:
            return "⚠️ Modèle en cours de chargement... Patientez 30s"
        
        # 📸 Gestion spéciale HEIC et formats complexes
        try:
            # Si l'image est un chemin temporaire (format HEIC)
            if isinstance(image, str) and ('gradio' in image or 'tmp' in image):
                if image.lower().endswith('.heic'):
                    # Conversion HEIC → JPEG
                    img = Image.open(image)
                    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
                        img.convert('RGB').save(tmp.name, 'JPEG', quality=95)
                        processed_image = Image.open(tmp.name)
                        os.unlink(tmp.name)  # Nettoyage
                else:
                    processed_image = Image.open(image)
            else:
                # Image normale
                processed_image = image
            
            # Conversion en RGB si nécessaire
            if processed_image.mode != 'RGB':
                processed_image = processed_image.convert('RGB')
                
        except Exception as e:
            return f"❌ Format d'image non supporté: {str(e)}\n\n💡 Utilisez JPEG, PNG ou WebP"
        
        # 🔥 PRÉTRAITEMENT CORRECT
        processed_image = processed_image.resize((224, 224), Image.Resampling.LANCZOS)
        
        # Transformation pour le modèle
        inputs = processor(images=processed_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # 🔥 INFÉRENCE
        with torch.no_grad():
            outputs = model(**inputs)
        
        # 📊 POST-TRAITEMENT
        probabilities = F.softmax(outputs.logits, dim=-1)
        top_probs, top_indices = torch.topk(probabilities, 5)
        
        # Conversion en résultats
        results = []
        for i in range(len(top_indices[0])):
            label_idx = top_indices[0][i].item()
            label_name = FASHION_LABELS.get(label_idx, f"Catégorie {label_idx}")