File size: 4,157 Bytes
fd50bed
5a90b4e
fd50bed
5a90b4e
3474c7b
5a90b4e
bf440a3
3474c7b
bf440a3
5a90b4e
 
3474c7b
5a90b4e
 
bf440a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10f3b61
 
 
bf440a3
 
10f3b61
 
bf440a3
10f3b61
 
bf440a3
10f3b61
 
 
 
 
bf440a3
 
10f3b61
bf440a3
10f3b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b06362e
10f3b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a90b4e
10f3b61
bf440a3
3474c7b
 
ae65e60
 
3474c7b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import os

# --- Chargement du modèle et du processeur ---
print("Loading model and processor...")
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")

def predict(image):
    """Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
    try:
        # Conversion vers RGB pour éviter les erreurs de canaux
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Pré-traitement de l'image
        inputs = processor(images=image, return_tensors="pt")
        
        # Prédiction
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        
        # Application de softmax pour obtenir les probabilités
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        top_probs, top_indices = torch.topk(probabilities, 5)  # Top 5 predictions
        
        # Formatage des résultats sous forme de dictionnaire pour l'affichage
        results = {}
        for prob, idx in zip(top_probs, top_indices):
            pred_label = model.config.id2label[idx.item()]
            confidence = prob.item()
            if confidence > 0.01:  # Seuil de confiance à 1%
                results[pred_label] = confidence
        
        if not results:
            return {"Aucune prédiction fiable": 0.0}, "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
        
        # Créer un message de résultat
        top_prediction = list(results.items())[0]
        message = f"🏷️ Prédiction principale: {top_prediction[0]} ({top_prediction[1]:.2%})"
        
        return results, message
        
    except Exception as e:
        return {"Erreur": 0.0}, f"Une erreur s'est produite: {str(e)}"

# Interface Gradio améliorée
with gr.Blocks(title="Fashion Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 👗 Fashion Item Classifier")
    gr.Markdown("Téléchargez une image de vêtement pour le classer automatiquement")
    
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil", 
                label="Image du vêtement",
                height=300,
                sources=["upload", "webcam", "clipboard"]
            )
            upload_btn = gr.Button("🚀 Analyser l'image", variant="primary")
        
        with gr.Column(scale=1):
            label_output = gr.Label(
                label="Résultats de classification",
                num_top_classes=5
            )
            text_output = gr.Textbox(
                label="Conclusion",
                interactive=False
            )
    
    # Exemples
    gr.Examples(
        examples=[
            ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=300"],  # T-shirt
            ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=300"],  # Chaussures
            ["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=300"]   # Robe
        ],
        inputs=image_input,
        label="Exemples d'images à tester"
    )
    
    # Instructions
    gr.Markdown("""
    ### 📋 Instructions
    - Téléchargez une image claire d'un vêtement
    - L'image doit montrer le vêtement de face
    - Fond uni recommandé pour de meilleurs résultats
    - Cliquez sur 'Analyser l'image' pour obtenir la classification
    """)
    
    # Liaison du bouton
    upload_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, text_output]
    )
    
    # Liaison aussi quand on upload une image
    image_input.upload(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, text_output]
    )

# Lancement de l'application
if __name__ == "__main__":
    demo.launch(
        debug=True,
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860))
    )