MODLI commited on
Commit
bf440a3
·
verified ·
1 Parent(s): 99bd7ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -30
app.py CHANGED
@@ -3,47 +3,76 @@ from transformers import ViTImageProcessor, ViTForImageClassification
3
  from PIL import Image
4
  import torch
5
 
6
- # --- CHANGEMENT CRITIQUE : Charger VOTRE modèle fine-tuné ---
7
- model_name = "MODLI/vit-fashion-classifier" # <--- REMPLACER par votre modèle entraîné
 
 
8
  processor = ViTImageProcessor.from_pretrained(model_name)
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
 
11
- # Fonction de prédiction avec seuil de confiance
12
  def predict(image):
13
- # Pré-traiter l'image exactement comme pendant l'entraînement
14
- inputs = processor(images=image, return_tensors="pt")
15
-
16
- # Prédire
17
- with torch.no_grad():
18
- outputs = model(**inputs)
19
- logits = outputs.logits
20
-
21
- # Appliquer softmax pour obtenir les probabilités
22
- probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
23
- top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions
24
-
25
- # Formater les résultats
26
- predictions = []
27
- for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
28
- pred_label = model.config.id2label[idx.item()]
29
- confidence = prob.item()
30
- # N'afficher que si la confiance est > 5%
31
- if confidence > 0.05 or i == 0: # Toujours afficher la première même si faible
32
- predictions.append((pred_label, f"{confidence:.2%}"))
33
-
34
- return predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Interface Gradio
37
  title = "Fashion Item Classifier"
38
- description = "Upload an image of a clothing item, and I will classify it."
 
 
 
 
 
 
 
 
 
 
 
39
 
 
40
  demo = gr.Interface(
41
  fn=predict,
42
  inputs=gr.Image(type="pil", label="Upload Clothing Item"),
43
- outputs=gr.Label(label="Predictions", num_top_classes=5),
44
  title=title,
45
  description=description,
46
- examples=[["path_to_example_image_1.jpg"], ["path_to_example_image_2.jpg"]], # Ajoutez des exemples
 
47
  )
48
 
49
- demo.launch(debug=True)
 
 
 
3
  from PIL import Image
4
  import torch
5
 
6
+ # --- Chargement du modèle et du processeur ---
7
+ # Modèle de base ViT pré-entraîné sur ImageNet (beaucoup mieux que "beans")
8
+ # C'est une solution temporaire en attendant de fine-tuner sur le dataset mode
9
+ model_name = "google/vit-base-patch16-224"
10
  processor = ViTImageProcessor.from_pretrained(model_name)
11
  model = ViTForImageClassification.from_pretrained(model_name)
12
 
 
13
  def predict(image):
14
+ """Fonction de prédiction avec gestion d'erreurs et seuil de confiance"""
15
+ try:
16
+ # Conversion vers RGB pour éviter les erreurs de canaux
17
+ if image.mode != 'RGB':
18
+ image = image.convert('RGB')
19
+
20
+ # Pré-traitement de l'image
21
+ inputs = processor(images=image, return_tensors="pt")
22
+
23
+ # Prédiction
24
+ with torch.no_grad():
25
+ outputs = model(**inputs)
26
+ logits = outputs.logits
27
+
28
+ # Application de softmax pour obtenir les probabilités
29
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
30
+ top_probs, top_indices = torch.topk(probabilities, 5) # Top 5 predictions
31
+
32
+ # Formatage des résultats
33
+ predictions = []
34
+ for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
35
+ pred_label = model.config.id2label[idx.item()]
36
+ confidence = prob.item()
37
+ # N'afficher que si la confiance est > 10%
38
+ if confidence > 0.1:
39
+ predictions.append(f"{pred_label}: {confidence:.2%}")
40
+
41
+ # Si aucune prédiction n'a une confiance suffisante
42
+ if not predictions:
43
+ return "Je ne suis pas sûr de reconnaître cet item. Essayez avec une image plus claire."
44
+
45
+ return "\n".join(predictions)
46
+
47
+ except Exception as e:
48
+ return f"Une erreur s'est produite lors du traitement: {str(e)}"
49
 
50
+ # Configuration de l'interface Gradio
51
  title = "Fashion Item Classifier"
52
+ description = (
53
+ "Upload an image of a clothing item, and I will classify it. "
54
+ "⚠️ This is a general-purpose model. For better accuracy on fashion items, "
55
+ "a specialized model is needed."
56
+ )
57
+
58
+ # Exemples d'images (ajoutez vos propres exemples plus tard)
59
+ examples = [
60
+ ["shirt_example.jpg"],
61
+ ["shoe_example.jpg"],
62
+ ["dress_example.jpg"]
63
+ ]
64
 
65
+ # Création de l'interface
66
  demo = gr.Interface(
67
  fn=predict,
68
  inputs=gr.Image(type="pil", label="Upload Clothing Item"),
69
+ outputs=gr.Textbox(label="Classification Results"),
70
  title=title,
71
  description=description,
72
+ examples=examples,
73
+ allow_flagging="never"
74
  )
75
 
76
+ # Lancement de l'application
77
+ if __name__ == "__main__":
78
+ demo.launch(debug=True, share=False)