Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| from PIL import Image | |
| import torch | |
| import os | |
# --- Load the pretrained ViT checkpoint: image processor + classifier ---
print("Loading model and processor...")
model_name = "google/vit-base-patch16-224"
# Both objects are loaded from the same checkpoint name so the preprocessing
# matches what the classifier was built with.
processor, model = (
    ViTImageProcessor.from_pretrained(model_name),
    ViTForImageClassification.from_pretrained(model_name),
)
print("Model loaded successfully!")
def predict(image, top_k=5, threshold=0.1):
    """Classify an image and return its most likely labels as text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without an image.
        top_k: Number of highest-probability classes to consider (default 5).
            Clamped to the number of classes the model outputs.
        threshold: A class is listed only if its probability is strictly
            greater than this value (default 0.1, i.e. 10%).

    Returns:
        One ``"label: pct"`` line per class above ``threshold``, joined by
        newlines. Otherwise one of three French messages:
        - a prompt when no image was given;
        - a "not sure" message when no class clears the threshold;
        - an error message if processing raised an exception.
    """
    # Gradio passes None for an empty input. Answer it explicitly instead of
    # surfacing an AttributeError text through the generic handler below.
    if image is None:
        return "Veuillez fournir une image."
    try:
        # Convert to RGB so images in other modes (e.g. grayscale, RGBA) do
        # not cause channel errors during preprocessing.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")

        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Softmax over the class dimension, then take the single image in
        # the batch (the processor was given exactly one image).
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        predictions = [
            f"{model.config.id2label[idx.item()]}: {prob.item():.2%}"
            for prob, idx in zip(top_probs, top_indices)
            if prob.item() > threshold
        ]
        if not predictions:
            return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
        return "\n".join(predictions)
    except Exception as e:
        # UI boundary: keep the app responsive, but record the failure in the
        # server logs rather than only in the user-facing message.
        print(f"Prediction failed: {e!r}")
        return f"Une erreur s'est produite lors du traitement: {str(e)}"
# --- Gradio interface configuration ---
description = (
    "Upload an image of a clothing item, and I will classify it. "
    "This is a general-purpose model (ImageNet). For better accuracy on fashion items, "
    "a specialized model is needed."
)
title = "Fashion Item Classifier"

# Build the interface: a PIL image goes in, predict()'s text comes out.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=gr.Image(type="pil", label="Upload Clothing Item"),
    outputs=gr.Textbox(label="Classification Results"),
    examples=[
        ["https://images.unsplash.com/photo-1552374196-c4e7ffc6e126?w=400"],  # T-shirt example
        ["https://images.unsplash.com/photo-1543163521-1bf539c55dd2?w=400"],  # Shoe example
    ],
    allow_flagging="never",
)
# --- Launch the app: settings intended for Hugging Face Spaces ---
if __name__ == "__main__":
    demo.launch(
        debug=True,
        # Listen on all interfaces (important when running inside a Docker container).
        server_name="0.0.0.0",
        # Use the port from the PORT environment variable when set, otherwise 7860.
        server_port=int(os.environ.get("PORT", 7860)),
    )