import gradio as gr from transformers import ViTImageProcessor, ViTForImageClassification from PIL import Image import torch import os # --- Chargement du modèle et du processeur --- print("Loading model and processor...") model_name = "google/vit-base-patch16-224" processor = ViTImageProcessor.from_pretrained(model_name) model = ViTForImageClassification.from_pretrained(model_name) print("Model loaded successfully!") def predict(image): """Fonction de prédiction avec gestion d'erreurs et seuil de confiance""" try: # Conversion vers RGB pour éviter les erreurs de canaux if image.mode != 'RGB': image = image.convert('RGB') # Pré-traitement de l'image inputs = processor(images=image, return_tensors="pt") # Prédiction with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits # Application de softmax pour obtenir les probabilités probabilities = torch.nn.functional.softmax(logits, dim=-1)[0] top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions # Formatage des résultats predictions = [] for i, (prob, idx) in enumerate(zip(top_probs, top_indices)): pred_label = model.config.id2label[idx.item()] confidence = prob.item() if confidence > 0.1: # Seuil de confiance à 10% predictions.append(f"{pred_label}: {confidence:.2%}") if not predictions: return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire." return "\n".join(predictions) except Exception as e: return f"Une erreur s'est produite lors du traitement: {str(e)}" # Configuration de l'interface Gradio title = "Fashion Item Classifier" description = ( "Upload an image of a clothing item, and I will classify it. " "This is a general-purpose model (ImageNet). For better accuracy on fashion items, " "a specialized model is needed." ) # Création de l'interface demo = gr.Interface( fn=predict, inputs=gr.Image(type="pil", label="Upload Clothing Item"), outputs=gr.Textbox(label="Classification Results"), title=title, description=description, allow_flagging="never", examples=[ ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=400"], # T-shirt example ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"] # Shoe example ] ) # Lancement de l'application - CONFIGURATION SPÉCIFIQUE POUR HUGGING FACE SPACES if __name__ == "__main__": # Cette configuration est cruciale pour Hugging Face Spaces demo.launch( debug=True, server_name="0.0.0.0", # Important pour les conteneurs Docker server_port=int(os.environ.get("PORT", 7860)) Utilise le port de l'environnement Spaces )