ankz22 commited on
Commit
0797cfd
·
1 Parent(s): 10add35

Add application file

Browse files
Files changed (1) hide show
  1. app.py +38 -12
app.py CHANGED
@@ -2,27 +2,53 @@ import gradio as gr
2
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
  from PIL import Image
4
  import torch
 
 
5
 
6
- extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
7
- model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
8
 
9
  POUBELLES = {
10
- "banana": "biodéchets",
11
- "plastic bottle": "plastique",
12
- "can": "métal",
13
- "apple": "biodéchets",
14
- "paper towel": "papier",
15
  "glass": "verre",
 
 
 
 
 
 
 
 
16
  }
17
 
18
  def classify_image(image):
19
  inputs = extractor(images=image, return_tensors="pt")
20
  with torch.no_grad():
21
  logits = model(**inputs).logits
22
- predicted_class_idx = logits.argmax(-1).item()
23
- label = model.config.id2label[predicted_class_idx]
24
 
25
- poubelle = POUBELLES.get(label.lower(), "inconnue")
26
- return f"{label} → {poubelle}"
 
27
 
28
- gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="text").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
  from PIL import Image
4
  import torch
5
+ import torch.nn.functional as F
6
+ import pandas as pd
7
 
8
+ extractor = AutoFeatureExtractor.from_pretrained("nateraw/resnet50-trash-classifier")
9
+ model = AutoModelForImageClassification.from_pretrained("nateraw/resnet50-trash-classifier")
10
 
11
  POUBELLES = {
12
+ "plastic": "plastique",
 
 
 
 
13
  "glass": "verre",
14
+ "metal": "métal",
15
+ "paper": "papier",
16
+ "cardboard": "papier/carton",
17
+ "trash": "ordures ménagères",
18
+ "compost": "biodéchets",
19
+ "battery": "déchet dangereux",
20
+ "clothes": "textile",
21
+ # Ajoute d'autres si nécessaire
22
  }
23
 
24
  def classify_image(image):
25
  inputs = extractor(images=image, return_tensors="pt")
26
  with torch.no_grad():
27
  logits = model(**inputs).logits
28
+ probs = F.softmax(logits, dim=-1)
 
29
 
30
+ top_probs, top_idxs = torch.topk(probs, 3)
31
+ top_probs = top_probs.squeeze().tolist()
32
+ top_idxs = top_idxs.squeeze().tolist()
33
 
34
+ rows = []
35
+ for idx, prob in zip(top_idxs, top_probs):
36
+ label = model.config.id2label[idx]
37
+ poubelle = POUBELLES.get(label.lower(), "inconnue")
38
+ rows.append({
39
+ "Objet": label,
40
+ "Poubelle": poubelle,
41
+ "Confiance (%)": round(prob * 100, 2),
42
+ })
43
+
44
+ return pd.DataFrame(rows)
45
+
46
+ interface = gr.Interface(
47
+ fn=classify_image,
48
+ inputs=gr.Image(type="pil"),
49
+ outputs=gr.Dataframe(),
50
+ title="🗑️ Classifieur de Déchets",
51
+ description="Dépose une image de déchet pour savoir dans quelle poubelle le trier."
52
+ )
53
+
54
+ interface.launch()