cang1602004 commited on
Commit
2df679e
·
verified ·
1 Parent(s): 04aed9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -40
app.py CHANGED
@@ -1,59 +1,72 @@
1
  import gradio as gr
2
- import tensorflow as tf
3
- import numpy as np
4
  from PIL import Image
 
 
 
5
  from tensorflow.keras.applications.efficientnet import preprocess_input
 
 
6
 
7
- # ============================
8
- # Load TensorFlow SavedModel
9
- # ============================
10
- MODEL_PATH = "exported_model"
11
- IMG_SIZE = (224, 224)
12
- CLASS_NAMES = ["bad", "good", "very_good"]
13
 
14
- print("🔄 Loading SavedModel…")
15
- model = tf.saved_model.load(MODEL_PATH)
16
- infer = model.signatures["serving_default"]
17
- print("✅ Model loaded!")
18
 
 
19
 
20
- # ============================
21
- # Prediction Function
22
- # ============================
23
- def predict_guava_quality(image):
24
- if image is None:
25
- return "❌ Vui lòng tải ảnh!", 0.0
 
 
26
 
27
- img = Image.fromarray(image).convert("RGB")
28
- img = img.resize(IMG_SIZE)
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- arr = np.array(img).astype("float32")
 
 
31
  arr = preprocess_input(arr)
32
- arr = np.expand_dims(arr, axis=0)
33
-
34
- outputs = infer(tf.constant(arr))
35
- preds = list(outputs.values())[0].numpy()[0]
36
 
37
- idx = np.argmax(preds)
38
- confidence = float(preds[idx])
39
- label = CLASS_NAMES[idx]
 
40
 
41
- return f"🍈 {label}", confidence
42
 
43
 
44
- # ============================
45
- # Gradio UI
46
- # ============================
47
  demo = gr.Interface(
48
- fn=predict_guava_quality,
49
- inputs=gr.Image(type="numpy", label="Tải ảnh Ổi"),
 
 
 
50
  outputs=[
51
- gr.Textbox(label="Kết quả dự đoán"),
52
- gr.Number(label="Độ tin cậy (0–1)")
53
  ],
54
- title="Guava Quality Classifier",
55
- description="Phân loại chất lượng quả Ổi: very_good / good / bad"
56
  )
57
 
58
- if __name__ == "__main__":
59
- demo.launch()
 
1
  import gradio as gr
 
 
2
  from PIL import Image
3
+ import numpy as np
4
+ import pickle
5
+ import tensorflow as tf
6
  from tensorflow.keras.applications.efficientnet import preprocess_input
7
+ import requests
8
+ from io import BytesIO
9
 
10
+ # load model + label encoder
11
+ MODEL_SAVE_PATH = "best_model.h5"
12
+ LABEL_ENCODER_PATH = "label_encoder.pkl"
 
 
 
13
 
14
+ model = tf.keras.models.load_model(MODEL_SAVE_PATH)
15
+ with open(LABEL_ENCODER_PATH, "rb") as f:
16
+ label_encoder = pickle.load(f)
 
17
 
18
+ IMG_SIZE = model.input_shape[1:3]
19
 
20
+ def load_image_from_url(url):
21
+ """Tải ảnh từ URL và return PIL."""
22
+ try:
23
+ resp = requests.get(url, timeout=5)
24
+ img = Image.open(BytesIO(resp.content)).convert("RGB")
25
+ return img
26
+ except:
27
+ return None
28
 
29
+ def predict_fn(img, url):
30
+ """img: numpy image (upload), url: string"""
31
+
32
+ # Ưu tiên dùng URL nếu có
33
+ if url and url.strip() != "":
34
+ img_pil = load_image_from_url(url)
35
+ if img_pil is None:
36
+ return "❌ Không tải được ảnh từ URL!", None
37
+ else:
38
+ # sử dụng ảnh upload
39
+ if img is None:
40
+ return "❌ Chưa cung cấp ảnh!", None
41
+ img_pil = Image.fromarray(img).convert("RGB")
42
 
43
+ # preprocess
44
+ img_resized = img_pil.resize(IMG_SIZE)
45
+ arr = np.array(img_resized).astype("float32")
46
  arr = preprocess_input(arr)
47
+ arr = np.expand_dims(arr, 0)
 
 
 
48
 
49
+ preds = model.predict(arr)
50
+ idx = int(np.argmax(preds, axis=1)[0])
51
+ confidence = float(np.max(preds))
52
+ label = label_encoder.inverse_transform([idx])[0]
53
 
54
+ return f" {label} ", img_pil
55
 
56
 
57
+ # Giao diện Gradio
 
 
58
  demo = gr.Interface(
59
+ fn=predict_fn,
60
+ inputs=[
61
+ gr.Image(type="numpy", label="Upload Image"),
62
+ gr.Textbox(label="Hoặc dán URL ảnh online")
63
+ ],
64
  outputs=[
65
+ gr.Textbox(label="Prediction"),
66
+ gr.Image(label="Preview Image")
67
  ],
68
+ title="Guava Classifier",
69
+ description="Upload ảnh Ổi hoặc nhập URL ảnh để phân loại."
70
  )
71
 
72
+ demo.launch(inline=True)