ankz22 commited on
Commit
e7436cd
·
1 Parent(s): e7ed4b7

Add application file

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -5,9 +5,11 @@ import torch
5
  import torch.nn.functional as F
6
  import pandas as pd
7
 
8
- extractor = AutoImageProcessor.from_pretrained("mrm8488/resnet50-finetuned-trashnet")
9
- model = AutoModelForImageClassification.from_pretrained("mrm8488/resnet50-finetuned-trashnet")
 
10
 
 
11
  POUBELLES = {
12
  "cardboard": "papier/carton",
13
  "glass": "verre",
@@ -17,8 +19,9 @@ POUBELLES = {
17
  "trash": "ordures ménagères",
18
  }
19
 
 
20
  def classify_image(image):
21
- inputs = extractor(images=image, return_tensors="pt")
22
  with torch.no_grad():
23
  logits = model(**inputs).logits
24
  probs = F.softmax(logits, dim=-1)
@@ -39,6 +42,7 @@ def classify_image(image):
39
 
40
  return pd.DataFrame(rows)
41
 
 
42
  gr.Interface(
43
  fn=classify_image,
44
  inputs=gr.Image(type="pil"),
 
5
  import torch.nn.functional as F
6
  import pandas as pd
7
 
8
+ # Chargement du processeur et du modèle
9
+ processor = AutoImageProcessor.from_pretrained("tribber93/my-trash-classification")
10
+ model = AutoModelForImageClassification.from_pretrained("tribber93/my-trash-classification")
11
 
12
+ # Dictionnaire de correspondance entre les labels et les types de poubelles
13
  POUBELLES = {
14
  "cardboard": "papier/carton",
15
  "glass": "verre",
 
19
  "trash": "ordures ménagères",
20
  }
21
 
22
+ # Fonction de classification de l'image
23
  def classify_image(image):
24
+ inputs = processor(images=image, return_tensors="pt")
25
  with torch.no_grad():
26
  logits = model(**inputs).logits
27
  probs = F.softmax(logits, dim=-1)
 
42
 
43
  return pd.DataFrame(rows)
44
 
45
+ # Création de l'interface Gradio
46
  gr.Interface(
47
  fn=classify_image,
48
  inputs=gr.Image(type="pil"),