cang1602004 commited on
Commit
64bd29c
·
verified ·
1 Parent(s): a72c926

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -19
app.py CHANGED
@@ -6,19 +6,28 @@ import tensorflow as tf
6
  from tensorflow.keras.applications.efficientnet import preprocess_input
7
  import requests
8
  from io import BytesIO
 
9
 
10
- # load model + label encoder
11
  MODEL_SAVE_PATH = "guava_model.keras"
12
  LABEL_ENCODER_PATH = "label_encoder.pkl"
13
 
 
14
  model = tf.keras.models.load_model(MODEL_SAVE_PATH)
15
  with open(LABEL_ENCODER_PATH, "rb") as f:
16
  label_encoder = pickle.load(f)
17
 
 
 
 
 
18
  IMG_SIZE = model.input_shape[1:3]
19
 
 
 
 
 
20
  def load_image_from_url(url):
21
- """Tải ảnh từ URL và return PIL."""
22
  try:
23
  resp = requests.get(url, timeout=5)
24
  img = Image.open(BytesIO(resp.content)).convert("RGB")
@@ -26,21 +35,25 @@ def load_image_from_url(url):
26
  except:
27
  return None
28
 
29
- def predict_fn(img, url):
30
- """img: numpy image (upload), url: string"""
31
-
32
- # Ưu tiên dùng URL nếu có
 
 
 
33
  if url and url.strip() != "":
34
  img_pil = load_image_from_url(url)
35
  if img_pil is None:
36
- return "❌ Không tải được ảnh từ URL!", None
37
  else:
38
- # sử dụng ảnh upload
39
  if img is None:
40
- return "❌ Chưa cung cấp ảnh!", None
41
  img_pil = Image.fromarray(img).convert("RGB")
42
 
43
- # preprocess
 
 
44
  img_resized = img_pil.resize(IMG_SIZE)
45
  arr = np.array(img_resized).astype("float32")
46
  arr = preprocess_input(arr)
@@ -49,24 +62,51 @@ def predict_fn(img, url):
49
  preds = model.predict(arr)
50
  idx = int(np.argmax(preds, axis=1)[0])
51
  confidence = float(np.max(preds))
52
- label = label_encoder.inverse_transform([idx])[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- return f"✅ {label} ", img_pil
 
55
 
56
 
57
- # Giao diện Gradio
 
 
58
  demo = gr.Interface(
59
- fn=predict_fn,
60
  inputs=[
61
  gr.Image(type="numpy", label="Upload Image"),
62
- gr.Textbox(label="Hoặc dán URL ảnh online")
63
  ],
64
  outputs=[
65
- gr.Textbox(label="Prediction"),
66
- gr.Image(label="Preview Image")
 
 
67
  ],
68
- title="Guava Classifier",
69
- description="Upload ảnh Ổi hoặc nhập URL ảnh để phân loại."
70
  )
71
 
72
  demo.launch(inline=True)
 
6
  from tensorflow.keras.applications.efficientnet import preprocess_input
7
  import requests
8
  from io import BytesIO
9
+ from ultralytics import YOLO
10
 
11
+ # ==== Load models ====
12
  MODEL_SAVE_PATH = "guava_model.keras"
13
  LABEL_ENCODER_PATH = "label_encoder.pkl"
14
 
15
+ # ENB0
16
  model = tf.keras.models.load_model(MODEL_SAVE_PATH)
17
  with open(LABEL_ENCODER_PATH, "rb") as f:
18
  label_encoder = pickle.load(f)
19
 
20
+ # YOLOv8 (model bạn đã train)
21
+ YOLO_MODEL_PATH = "yolov8_guava.pt"
22
+ yolo_model = YOLO(YOLO_MODEL_PATH)
23
+
24
  IMG_SIZE = model.input_shape[1:3]
25
 
26
+
27
+ # ======================================================
28
+ # Load image (Upload or URL)
29
+ # ======================================================
30
  def load_image_from_url(url):
 
31
  try:
32
  resp = requests.get(url, timeout=5)
33
  img = Image.open(BytesIO(resp.content)).convert("RGB")
 
35
  except:
36
  return None
37
 
38
+
39
+ # ======================================================
40
+ # Prediction function (ENB0 + YOLOv8)
41
+ # ======================================================
42
+ def compare_models(img, url):
43
+
44
+ # --- Ưu tiên URL ---
45
  if url and url.strip() != "":
46
  img_pil = load_image_from_url(url)
47
  if img_pil is None:
48
+ return "❌ Không tải được ảnh từ URL!", None, None, None
49
  else:
 
50
  if img is None:
51
+ return "❌ Chưa cung cấp ảnh!", None, None, None
52
  img_pil = Image.fromarray(img).convert("RGB")
53
 
54
+ # ======================================================
55
+ # EfficientNetB0 Prediction
56
+ # ======================================================
57
  img_resized = img_pil.resize(IMG_SIZE)
58
  arr = np.array(img_resized).astype("float32")
59
  arr = preprocess_input(arr)
 
62
  preds = model.predict(arr)
63
  idx = int(np.argmax(preds, axis=1)[0])
64
  confidence = float(np.max(preds))
65
+ label_enb0 = label_encoder.inverse_transform([idx])[0]
66
+ enb0_text = f"{label_enb0} (Conf {confidence:.2f})"
67
+
68
+ # ======================================================
69
+ # YOLOv8 Prediction
70
+ # ======================================================
71
+ results = yolo_model(img_pil)
72
+ result = results[0]
73
+
74
+ # Lấy label + conf cao nhất
75
+ if len(result.boxes) > 0:
76
+ best = result.boxes[0]
77
+ yolo_label_id = int(best.cls[0])
78
+ yolo_conf = float(best.conf[0])
79
+ yolo_label = yolo_model.model.names[yolo_label_id]
80
+ yolo_text = f"{yolo_label} (Conf {yolo_conf:.2f})"
81
+
82
+ # Vẽ bounding box
83
+ img_yolo = result.plot() # return np array BGR
84
+ img_yolo = Image.fromarray(img_yolo[..., ::-1]) # convert to RGB
85
+ else:
86
+ yolo_text = "Không phát hiện!"
87
+ img_yolo = img_pil
88
 
89
+ # return 4 outputs
90
+ return enb0_text, yolo_text, img_pil, img_yolo
91
 
92
 
93
+ # ======================================================
94
+ # Gradio UI
95
+ # ======================================================
96
  demo = gr.Interface(
97
+ fn=compare_models,
98
  inputs=[
99
  gr.Image(type="numpy", label="Upload Image"),
100
+ gr.Textbox(label="Hoặc nhập URL ảnh")
101
  ],
102
  outputs=[
103
+ gr.Textbox(label="ENB0 Prediction"),
104
+ gr.Textbox(label="YOLOv8 Prediction"),
105
+ gr.Image(label="Original Image"),
106
+ gr.Image(label="YOLOv8 Detection Image")
107
  ],
108
+ title="Guava Classifier — YOLOv8 vs EfficientNetB0",
109
+ description="So sánh kết quả phân loại giữa YOLOv8 EfficientNetB0."
110
  )
111
 
112
  demo.launch(inline=True)