Spaces:
Sleeping
Sleeping
File size: 2,942 Bytes
fd50bed 5a90b4e fd50bed 5a90b4e 3474c7b 5a90b4e bf440a3 3474c7b bf440a3 5a90b4e 3474c7b 5a90b4e bf440a3 3474c7b bf440a3 5a90b4e bf440a3 5a90b4e bf440a3 3474c7b bf440a3 5a90b4e bf440a3 5a90b4e 3474c7b 5a90b4e 3474c7b bf440a3 3474c7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import os
# --- Chargement du modèle et du processeur ---
print("Loading model and processor...")
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")
def predict(image):
"""Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
try:
# Conversion vers RGB pour éviter les erreurs de canaux
if image.mode != 'RGB':
image = image.convert('RGB')
# Pré-traitement de l'image
inputs = processor(images=image, return_tensors="pt")
# Prédiction
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Application de softmax pour obtenir les probabilités
probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions
# Formatage des résultats
predictions = []
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
pred_label = model.config.id2label[idx.item()]
confidence = prob.item()
if confidence > 0.1: # Seuil de confiance à 10%
predictions.append(f"{pred_label}: {confidence:.2%}")
if not predictions:
return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
return "\n".join(predictions)
except Exception as e:
return f"Une erreur s'est produite lors du traitement: {str(e)}"
# Configuration de l'interface Gradio
title = "Fashion Item Classifier"
description = (
"Upload an image of a clothing item, and I will classify it. "
"This is a general-purpose model (ImageNet). For better accuracy on fashion items, "
"a specialized model is needed."
)
# Création de l'interface
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="Upload Clothing Item"),
outputs=gr.Textbox(label="Classification Results"),
title=title,
description=description,
allow_flagging="never",
examples=[
["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=400"], # T-shirt example
["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"] # Shoe example
]
)
# Lancement de l'application - CONFIGURATION SPÉCIFIQUE POUR HUGGING FACE SPACES
if __name__ == "__main__":
# Cette configuration est cruciale pour Hugging Face Spaces
demo.launch(
debug=True,
server_name="0.0.0.0", # Important pour les conteneurs Docker
server_port=int(os.environ.get("PORT", 7860)) Utilise le port de l'environnement Spaces
) |