File size: 2,942 Bytes
fd50bed
5a90b4e
fd50bed
5a90b4e
3474c7b
5a90b4e
bf440a3
3474c7b
bf440a3
5a90b4e
 
3474c7b
5a90b4e
 
bf440a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3474c7b
bf440a3
 
 
 
 
 
 
 
 
5a90b4e
bf440a3
5a90b4e
bf440a3
 
3474c7b
bf440a3
 
 
 
5a90b4e
 
 
bf440a3
5a90b4e
 
3474c7b
 
 
 
 
5a90b4e
 
3474c7b
bf440a3
3474c7b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import os

# --- Chargement du modèle et du processeur ---
print("Loading model and processor...")
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")

def predict(image):
    """Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
    try:
        # Conversion vers RGB pour éviter les erreurs de canaux
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Pré-traitement de l'image
        inputs = processor(images=image, return_tensors="pt")
        
        # Prédiction
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        
        # Application de softmax pour obtenir les probabilités
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        top_probs, top_indices = torch.topk(probabilities, 5)  # Top 5 predictions
        
        # Formatage des résultats
        predictions = []
        for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
            pred_label = model.config.id2label[idx.item()]
            confidence = prob.item()
            if confidence > 0.1:  # Seuil de confiance à 10%
                predictions.append(f"{pred_label}: {confidence:.2%}")
        
        if not predictions:
            return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
        
        return "\n".join(predictions)
        
    except Exception as e:
        return f"Une erreur s'est produite lors du traitement: {str(e)}"

# Configuration de l'interface Gradio
title = "Fashion Item Classifier"
description = (
    "Upload an image of a clothing item, and I will classify it. "
    "This is a general-purpose model (ImageNet). For better accuracy on fashion items, "
    "a specialized model is needed."
)

# Création de l'interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Clothing Item"),
    outputs=gr.Textbox(label="Classification Results"),
    title=title,
    description=description,
    allow_flagging="never",
    examples=[
        ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=400"],  # T-shirt example
        ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"]   # Shoe example
    ]
)

# Lancement de l'application - CONFIGURATION SPÉCIFIQUE POUR HUGGING FACE SPACES
if __name__ == "__main__":
    # Cette configuration est cruciale pour Hugging Face Spaces
    demo.launch(
        debug=True,
        server_name="0.0.0.0",  # Important pour les conteneurs Docker
        server_port=int(os.environ.get("PORT", 7860))  Utilise le port de l'environnement Spaces
    )