jatin1233232 commited on
Commit
0f215da
Β·
verified Β·
1 Parent(s): 654b153

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tensorflow as tf
5
+ from huggingface_hub import hf_hub_download
6
+ import gradio as gr
7
+
8
+ # β€” Model & labels download (runs once at startup) β€”
9
+ MODEL_PATH = hf_hub_download(
10
+ repo_id="Adriana213/vgg16-fruit-classifier",
11
+ filename="vgg16-fruit-classifier.h5"
12
+ )
13
+ LABELS_PATH = hf_hub_download(
14
+ repo_id="Adriana213/vgg16-fruit-classifier",
15
+ filename="class_labels.json"
16
+ )
17
+
18
+ # β€” Load model & labels β€”
19
+ model = tf.keras.models.load_model(MODEL_PATH)
20
+ with open(LABELS_PATH, "r") as f:
21
+ id2label = json.load(f)
22
+
23
+ # β€” Preprocessing helper β€”
24
+ def preprocess(img: Image.Image) -> np.ndarray:
25
+ img = img.resize((100, 100))
26
+ arr = np.array(img)
27
+ if arr.ndim == 2: # grayscale β†’ RGB
28
+ arr = np.stack([arr]*3, axis=-1)
29
+ arr = arr.astype("float32")
30
+ # VGG16 preprocessing: subtract mean RGB
31
+ arr[..., 0] -= 123.68
32
+ arr[..., 1] -= 116.779
33
+ arr[..., 2] -= 103.939
34
+ return np.expand_dims(arr, 0)
35
+
36
+ # β€” Prediction function β€”
37
+ def classify_fruit(img: Image.Image):
38
+ img = img.convert("RGB")
39
+ x = preprocess(img)
40
+ preds = model.predict(x)
41
+ idx = int(np.argmax(preds, axis=1)[0])
42
+ label = id2label[str(idx)]
43
+ confidence = float(np.max(preds))
44
+ return {label: round(confidence, 4)}
45
+
46
+ # β€” Gradio interface β€”
47
+ demo = gr.Interface(
48
+ fn=classify_fruit,
49
+ inputs=gr.Image(type="pil", label="Upload Fruit Image"),
50
+ outputs=gr.Label(num_top_classes=5, label="Top Prediction"),
51
+ title="πŸ‰ Fruit Classifier (131 types)",
52
+ description="Upload a photo of a fruit and get back its name with confidence."
53
+ )
54
+
55
+ if __name__ == "__main__":
56
+ demo.launch()