cang1602004 commited on
Commit
9e5445d
·
verified ·
1 Parent(s): 4e37ff7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -31
app.py CHANGED
@@ -8,24 +8,29 @@ import requests
8
  from io import BytesIO
9
  from ultralytics import YOLO
10
 
11
- # ==== Load models ====
 
 
12
  MODEL_SAVE_PATH = "guava_model.keras"
13
  LABEL_ENCODER_PATH = "label_encoder.pkl"
 
14
 
15
- # ENB0
16
  model = tf.keras.models.load_model(MODEL_SAVE_PATH)
 
 
17
  with open(LABEL_ENCODER_PATH, "rb") as f:
18
  label_encoder = pickle.load(f)
19
 
20
- # YOLOv8 (model bạn đã train)
21
- YOLO_MODEL_PATH = "yolov8_guava.pt"
22
  yolo_model = YOLO(YOLO_MODEL_PATH)
23
 
24
- IMG_SIZE = model.input_shape[1:3]
 
25
 
26
 
27
  # ======================================================
28
- # Load image (Upload or URL)
29
  # ======================================================
30
  def load_image_from_url(url):
31
  try:
@@ -37,61 +42,56 @@ def load_image_from_url(url):
37
 
38
 
39
  # ======================================================
40
- # Prediction function (ENB0 + YOLOv8)
41
  # ======================================================
42
  def compare_models(img, url):
43
 
44
  # --- Ưu tiên URL ---
45
- if url and url.strip() != "":
46
  img_pil = load_image_from_url(url)
47
  if img_pil is None:
48
- return "❌ Không tải được ảnh từ URL!", None, None, None
49
  else:
50
  if img is None:
51
  return "❌ Chưa cung cấp ảnh!", None, None, None
 
52
  img_pil = Image.fromarray(img).convert("RGB")
53
 
54
- # ======================================================
55
- # EfficientNetB0 Prediction
56
- # ======================================================
57
  img_resized = img_pil.resize(IMG_SIZE)
58
- arr = np.array(img_resized).astype("float32")
59
  arr = preprocess_input(arr)
60
- arr = np.expand_dims(arr, 0)
61
 
62
  preds = model.predict(arr)
63
  idx = int(np.argmax(preds, axis=1)[0])
64
  confidence = float(np.max(preds))
65
- label_enb0 = label_encoder.inverse_transform([idx])[0]
66
- enb0_text = f"{label_enb0} (Conf {confidence:.2f})"
67
 
68
- # ======================================================
69
- # YOLOv8 Prediction
70
- # ======================================================
71
  results = yolo_model(img_pil)
72
  result = results[0]
73
 
74
- # Lấy label + conf cao nhất
75
  if len(result.boxes) > 0:
76
  best = result.boxes[0]
77
- yolo_label_id = int(best.cls[0])
78
- yolo_conf = float(best.conf[0])
79
- yolo_label = yolo_model.model.names[yolo_label_id]
80
- yolo_text = f"{yolo_label} (Conf {yolo_conf:.2f})"
81
-
82
- # Vẽ bounding box
83
- img_yolo = result.plot() # return np array BGR
84
- img_yolo = Image.fromarray(img_yolo[..., ::-1]) # convert to RGB
85
  else:
86
  yolo_text = "Không phát hiện!"
87
  img_yolo = img_pil
88
 
89
- # return 4 outputs
90
  return enb0_text, yolo_text, img_pil, img_yolo
91
 
92
 
93
  # ======================================================
94
- # Gradio UI
95
  # ======================================================
96
  demo = gr.Interface(
97
  fn=compare_models,
@@ -109,4 +109,4 @@ demo = gr.Interface(
109
  description="So sánh kết quả phân loại giữa YOLOv8 và EfficientNetB0."
110
  )
111
 
112
- demo.launch(inline=True)
 
8
  from io import BytesIO
9
  from ultralytics import YOLO
10
 
11
+ # ======================================================
12
+ # Load Models
13
+ # ======================================================
14
  MODEL_SAVE_PATH = "guava_model.keras"
15
  LABEL_ENCODER_PATH = "label_encoder.pkl"
16
+ YOLO_MODEL_PATH = "yolov8_guava.pt"
17
 
18
+ # Load ENB0 model
19
  model = tf.keras.models.load_model(MODEL_SAVE_PATH)
20
+
21
+ # Load label encoder
22
  with open(LABEL_ENCODER_PATH, "rb") as f:
23
  label_encoder = pickle.load(f)
24
 
25
+ # Load YOLOv8
 
26
  yolo_model = YOLO(YOLO_MODEL_PATH)
27
 
28
+ # Lấy đúng input size của model
29
+ IMG_SIZE = (model.input_shape[1], model.input_shape[2])
30
 
31
 
32
  # ======================================================
33
+ # Load image from URL
34
  # ======================================================
35
  def load_image_from_url(url):
36
  try:
 
42
 
43
 
44
  # ======================================================
45
+ # Predict Function
46
  # ======================================================
47
  def compare_models(img, url):
48
 
49
  # --- Ưu tiên URL ---
50
+ if url and url.strip():
51
  img_pil = load_image_from_url(url)
52
  if img_pil is None:
53
+ return "❌ URL không hợp lệ!", None, None, None
54
  else:
55
  if img is None:
56
  return "❌ Chưa cung cấp ảnh!", None, None, None
57
+
58
  img_pil = Image.fromarray(img).convert("RGB")
59
 
60
+ # ===================== ENB0 =======================
 
 
61
  img_resized = img_pil.resize(IMG_SIZE)
62
+ arr = np.array(img_resized, dtype=np.float32)
63
  arr = preprocess_input(arr)
64
+ arr = np.expand_dims(arr, axis=0)
65
 
66
  preds = model.predict(arr)
67
  idx = int(np.argmax(preds, axis=1)[0])
68
  confidence = float(np.max(preds))
69
+ label = label_encoder.inverse_transform([idx])[0]
 
70
 
71
+ enb0_text = f"{label} (Conf {confidence:.2f})"
72
+
73
+ # ===================== YOLO =======================
74
  results = yolo_model(img_pil)
75
  result = results[0]
76
 
 
77
  if len(result.boxes) > 0:
78
  best = result.boxes[0]
79
+ cls_id = int(best.cls[0])
80
+ conf = float(best.conf[0])
81
+ yolo_label = yolo_model.model.names[cls_id]
82
+ yolo_text = f"{yolo_label} (Conf {conf:.2f})"
83
+
84
+ img_yolo = result.plot()
85
+ img_yolo = Image.fromarray(img_yolo[..., ::-1])
 
86
  else:
87
  yolo_text = "Không phát hiện!"
88
  img_yolo = img_pil
89
 
 
90
  return enb0_text, yolo_text, img_pil, img_yolo
91
 
92
 
93
  # ======================================================
94
+ # Gradio UI
95
  # ======================================================
96
  demo = gr.Interface(
97
  fn=compare_models,
 
109
  description="So sánh kết quả phân loại giữa YOLOv8 và EfficientNetB0."
110
  )
111
 
112
+ demo.launch()