cang1602004's picture
Update app.py
2df679e verified
raw
history blame
2.01 kB
import gradio as gr
from PIL import Image
import numpy as np
import pickle
import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input
import requests
from io import BytesIO
# load model + label encoder
MODEL_SAVE_PATH = "best_model.h5"
LABEL_ENCODER_PATH = "label_encoder.pkl"
model = tf.keras.models.load_model(MODEL_SAVE_PATH)
with open(LABEL_ENCODER_PATH, "rb") as f:
label_encoder = pickle.load(f)
IMG_SIZE = model.input_shape[1:3]
def load_image_from_url(url):
"""Tải ảnh từ URL và return PIL."""
try:
resp = requests.get(url, timeout=5)
img = Image.open(BytesIO(resp.content)).convert("RGB")
return img
except:
return None
def predict_fn(img, url):
"""img: numpy image (upload), url: string"""
# Ưu tiên dùng URL nếu có
if url and url.strip() != "":
img_pil = load_image_from_url(url)
if img_pil is None:
return "❌ Không tải được ảnh từ URL!", None
else:
# sử dụng ảnh upload
if img is None:
return "❌ Chưa cung cấp ảnh!", None
img_pil = Image.fromarray(img).convert("RGB")
# preprocess
img_resized = img_pil.resize(IMG_SIZE)
arr = np.array(img_resized).astype("float32")
arr = preprocess_input(arr)
arr = np.expand_dims(arr, 0)
preds = model.predict(arr)
idx = int(np.argmax(preds, axis=1)[0])
confidence = float(np.max(preds))
label = label_encoder.inverse_transform([idx])[0]
return f"✅ {label} ", img_pil
# Giao diện Gradio
demo = gr.Interface(
fn=predict_fn,
inputs=[
gr.Image(type="numpy", label="Upload Image"),
gr.Textbox(label="Hoặc dán URL ảnh online")
],
outputs=[
gr.Textbox(label="Prediction"),
gr.Image(label="Preview Image")
],
title="Guava Classifier",
description="Upload ảnh Ổi hoặc nhập URL ảnh để phân loại."
)
demo.launch(inline=True)