cang1602004 commited on
Commit
b50b81d
·
verified ·
1 Parent(s): 24ec107

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -63
app.py CHANGED
@@ -6,33 +6,19 @@ import tensorflow as tf
6
  from tensorflow.keras.applications.efficientnet import preprocess_input
7
  import requests
8
  from io import BytesIO
9
- from ultralytics import YOLO
10
 
11
- # ======================================================
12
- # Load Models
13
- # ======================================================
14
  MODEL_SAVE_PATH = "guava_model.keras"
15
  LABEL_ENCODER_PATH = "label_encoder.pkl"
16
- YOLO_MODEL_PATH = "yolov8_guava.pt"
17
 
18
- # Load ENB0 model
19
  model = tf.keras.models.load_model(MODEL_SAVE_PATH)
20
-
21
- # Load label encoder
22
  with open(LABEL_ENCODER_PATH, "rb") as f:
23
  label_encoder = pickle.load(f)
24
 
25
- # Load YOLOv8
26
- yolo_model = YOLO(YOLO_MODEL_PATH)
27
-
28
- # Lấy đúng input size của model
29
- IMG_SIZE = (model.input_shape[1], model.input_shape[2])
30
-
31
 
32
- # ======================================================
33
- # Load image from URL
34
- # ======================================================
35
  def load_image_from_url(url):
 
36
  try:
37
  resp = requests.get(url, timeout=5)
38
  img = Image.open(BytesIO(resp.content)).convert("RGB")
@@ -40,73 +26,47 @@ def load_image_from_url(url):
40
  except:
41
  return None
42
 
43
-
44
- # ======================================================
45
- # Predict Function
46
- # ======================================================
47
- def compare_models(img, url):
48
-
49
- # --- Ưu tiên URL ---
50
- if url and url.strip():
51
  img_pil = load_image_from_url(url)
52
  if img_pil is None:
53
- return "❌ URL không hợp lệ!", None, None, None
54
  else:
 
55
  if img is None:
56
- return "❌ Chưa cung cấp ảnh!", None, None, None
57
-
58
  img_pil = Image.fromarray(img).convert("RGB")
59
 
60
- # ===================== ENB0 =======================
61
  img_resized = img_pil.resize(IMG_SIZE)
62
- arr = np.array(img_resized, dtype=np.float32)
63
  arr = preprocess_input(arr)
64
- arr = np.expand_dims(arr, axis=0)
65
 
66
  preds = model.predict(arr)
67
  idx = int(np.argmax(preds, axis=1)[0])
68
  confidence = float(np.max(preds))
69
  label = label_encoder.inverse_transform([idx])[0]
70
 
71
- enb0_text = f"{label} (Conf {confidence:.2f})"
72
-
73
- # ===================== YOLO =======================
74
- results = yolo_model(img_pil)
75
- result = results[0]
76
-
77
- if len(result.boxes) > 0:
78
- best = result.boxes[0]
79
- cls_id = int(best.cls[0])
80
- conf = float(best.conf[0])
81
- yolo_label = yolo_model.model.names[cls_id]
82
- yolo_text = f"{yolo_label} (Conf {conf:.2f})"
83
-
84
- img_yolo = result.plot()
85
- img_yolo = Image.fromarray(img_yolo[..., ::-1])
86
- else:
87
- yolo_text = "Không phát hiện!"
88
- img_yolo = img_pil
89
-
90
- return enb0_text, yolo_text, img_pil, img_yolo
91
 
92
 
93
- # ======================================================
94
- # Gradio UI
95
- # ======================================================
96
  demo = gr.Interface(
97
- fn=compare_models,
98
  inputs=[
99
  gr.Image(type="numpy", label="Upload Image"),
100
- gr.Textbox(label="Hoặc nhập URL ảnh")
101
  ],
102
  outputs=[
103
- gr.Textbox(label="ENB0 Prediction"),
104
- gr.Textbox(label="YOLOv8 Prediction"),
105
- gr.Image(label="Original Image"),
106
- gr.Image(label="YOLOv8 Detection Image")
107
  ],
108
- title="Guava Classifier — YOLOv8 vs EfficientNetB0",
109
- description="So sánh kết quả phân loại giữa YOLOv8 EfficientNetB0."
110
  )
111
 
112
- demo.launch()
 
6
  from tensorflow.keras.applications.efficientnet import preprocess_input
7
  import requests
8
  from io import BytesIO
 
9
 
10
+ # load model + label encoder
 
 
11
  MODEL_SAVE_PATH = "guava_model.keras"
12
  LABEL_ENCODER_PATH = "label_encoder.pkl"
 
13
 
 
14
  model = tf.keras.models.load_model(MODEL_SAVE_PATH)
 
 
15
  with open(LABEL_ENCODER_PATH, "rb") as f:
16
  label_encoder = pickle.load(f)
17
 
18
+ IMG_SIZE = model.input_shape[1:3]
 
 
 
 
 
19
 
 
 
 
20
  def load_image_from_url(url):
21
+ """Tải ảnh từ URL và return PIL."""
22
  try:
23
  resp = requests.get(url, timeout=5)
24
  img = Image.open(BytesIO(resp.content)).convert("RGB")
 
26
  except:
27
  return None
28
 
29
+ def predict_fn(img, url):
30
+ """img: numpy image (upload), url: string"""
31
+
32
+ # Ưu tiên dùng URL nếu có
33
+ if url and url.strip() != "":
 
 
 
34
  img_pil = load_image_from_url(url)
35
  if img_pil is None:
36
+ return "❌ Không tải được ảnh từ URL!", None
37
  else:
38
+ # sử dụng ảnh upload
39
  if img is None:
40
+ return "❌ Chưa cung cấp ảnh!", None
 
41
  img_pil = Image.fromarray(img).convert("RGB")
42
 
43
+ # preprocess
44
  img_resized = img_pil.resize(IMG_SIZE)
45
+ arr = np.array(img_resized).astype("float32")
46
  arr = preprocess_input(arr)
47
+ arr = np.expand_dims(arr, 0)
48
 
49
  preds = model.predict(arr)
50
  idx = int(np.argmax(preds, axis=1)[0])
51
  confidence = float(np.max(preds))
52
  label = label_encoder.inverse_transform([idx])[0]
53
 
54
+ return f"{label} ", img_pil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
+ # Giao diện Gradio
 
 
58
  demo = gr.Interface(
59
+ fn=predict_fn,
60
  inputs=[
61
  gr.Image(type="numpy", label="Upload Image"),
62
+ gr.Textbox(label="Hoặc dán URL ảnh online")
63
  ],
64
  outputs=[
65
+ gr.Textbox(label="Prediction"),
66
+ gr.Image(label="Preview Image")
 
 
67
  ],
68
+ title="Guava Classifier",
69
+ description="Upload ảnh Ổi hoặc nhập URL ảnh để phân loại."
70
  )
71
 
72
+ demo.launch(inline=True) sửa lại code này để tôi làm so sánh kết quả giữa yoloV8 và ENB0 deloy lên hunging face