peterverebics commited on
Commit
0412ad6
·
1 Parent(s): 9a20a67
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +46 -0
  3. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from PIL import Image
4
+ import numpy as np
5
+ import requests
6
+ import os
7
+
8
+ # Ensure model folder exists
9
+ os.makedirs("model", exist_ok=True)
10
+
11
+ # Download the model from Hugging Face if not already present
12
+ model_path = "model/mobnet_model.keras"
13
+ if not os.path.exists(model_path):
14
+ url = "https://huggingface.co/ahmzakif/TrashNet-Classification/resolve/main/model/mobnet_model.keras"
15
+ r = requests.get(url)
16
+ with open(model_path, "wb") as f:
17
+ f.write(r.content)
18
+
19
+ # Load Keras model
20
+ model = tf.keras.models.load_model(model_path)
21
+
22
+ # TrashNet classes
23
+ classes = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
24
+
25
+ # Image preprocessing
26
+ def predict(image: Image.Image):
27
+ image = image.convert("RGB").resize((224, 224))
28
+ x = np.array(image, dtype=np.float32) / 255.0
29
+ x = np.expand_dims(x, axis=0)
30
+
31
+ preds = model.predict(x)[0]
32
+ scores = {classes[i]: float(preds[i]) for i in range(len(classes))}
33
+ top_class = max(scores, key=scores.get)
34
+
35
+ return {"prediction": top_class, "scores": scores}
36
+
37
+ # Gradio interface
38
+ iface = gr.Interface(
39
+ fn=predict,
40
+ inputs=gr.Image(type="pil"),
41
+ outputs="json",
42
+ title="TrashNet Classification API",
43
+ description="Upload an image of trash to get its classification."
44
+ )
45
+
46
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ tqdm==4.66.5
3
+ imutils==0.5.4
4
+ numpy==1.26.4
5
+ pandas==2.0.3
6
+ pillow==10.4.0
7
+ matplotlib==3.7.3
8
+ seaborn==0.11.0
9
+ albumentations==1.4.1
10
+ opencv-python==4.10.0.84
11
+ tensorflow==2.15.1
12
+ keras==2.15.1
13
+ scikit-learn==1.2.2
14
+ wandb==0.19.1