ankz22 commited on
Commit
10add35
·
1 Parent(s): 6245485

add app and packages

Browse files
Files changed (2) hide show
  1. app.py +28 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+
6
+ extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
7
+ model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
8
+
9
+ POUBELLES = {
10
+ "banana": "biodéchets",
11
+ "plastic bottle": "plastique",
12
+ "can": "métal",
13
+ "apple": "biodéchets",
14
+ "paper towel": "papier",
15
+ "glass": "verre",
16
+ }
17
+
18
+ def classify_image(image):
19
+ inputs = extractor(images=image, return_tensors="pt")
20
+ with torch.no_grad():
21
+ logits = model(**inputs).logits
22
+ predicted_class_idx = logits.argmax(-1).item()
23
+ label = model.config.id2label[predicted_class_idx]
24
+
25
+ poubelle = POUBELLES.get(label.lower(), "inconnue")
26
+ return f"{label} → {poubelle}"
27
+
28
+ gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="text").launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ Pillow