Spaces:
Sleeping
Sleeping
File size: 4,157 Bytes
fd50bed 5a90b4e fd50bed 5a90b4e 3474c7b 5a90b4e bf440a3 3474c7b bf440a3 5a90b4e 3474c7b 5a90b4e bf440a3 10f3b61 bf440a3 10f3b61 bf440a3 10f3b61 bf440a3 10f3b61 bf440a3 10f3b61 bf440a3 10f3b61 b06362e 10f3b61 5a90b4e 10f3b61 bf440a3 3474c7b ae65e60 3474c7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import os
# --- Chargement du modèle et du processeur ---
print("Loading model and processor...")
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")
def predict(image):
"""Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
try:
# Conversion vers RGB pour éviter les erreurs de canaux
if image.mode != 'RGB':
image = image.convert('RGB')
# Pré-traitement de l'image
inputs = processor(images=image, return_tensors="pt")
# Prédiction
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Application de softmax pour obtenir les probabilités
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions
# Formatage des résultats sous forme de dictionnaire pour l'affichage
results = {}
for prob, idx in zip(top_probs, top_indices):
pred_label = model.config.id2label[idx.item()]
confidence = prob.item()
if confidence > 0.01: # Seuil de confiance à 1%
results[pred_label] = confidence
if not results:
return {"Aucune prédiction fiable": 0.0}, "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
# Créer un message de résultat
top_prediction = list(results.items())[0]
message = f"🏷️ Prédiction principale: {top_prediction[0]} ({top_prediction[1]:.2%})"
return results, message
except Exception as e:
return {"Erreur": 0.0}, f"Une erreur s'est produite: {str(e)}"
# Interface Gradio améliorée
with gr.Blocks(title="Fashion Classifier", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 👗 Fashion Item Classifier")
gr.Markdown("Téléchargez une image de vêtement pour le classer automatiquement")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="Image du vêtement",
height=300,
sources=["upload", "webcam", "clipboard"]
)
upload_btn = gr.Button("🚀 Analyser l'image", variant="primary")
with gr.Column(scale=1):
label_output = gr.Label(
label="Résultats de classification",
num_top_classes=5
)
text_output = gr.Textbox(
label="Conclusion",
interactive=False
)
# Exemples
gr.Examples(
examples=[
["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=300"], # T-shirt
["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=300"], # Chaussures
["https://images.unsplash.com/photo-1594633312681-425c7b97ccd1?w=300"] # Robe
],
inputs=image_input,
label="Exemples d'images à tester"
)
# Instructions
gr.Markdown("""
### 📋 Instructions
- Téléchargez une image claire d'un vêtement
- L'image doit montrer le vêtement de face
- Fond uni recommandé pour de meilleurs résultats
- Cliquez sur 'Analyser l'image' pour obtenir la classification
""")
# Liaison du bouton
upload_btn.click(
fn=predict,
inputs=image_input,
outputs=[label_output, text_output]
)
# Liaison aussi quand on upload une image
image_input.upload(
fn=predict,
inputs=image_input,
outputs=[label_output, text_output]
)
# Lancement de l'application
if __name__ == "__main__":
demo.launch(
debug=True,
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860))
) |