File size: 2,010 Bytes
1f46785
 
2df679e
 
 
3f94404
2df679e
 
1f46785
b50b81d
8b4550f
2df679e
83fbb3a
2df679e
 
 
d24f899
b50b81d
64bd29c
2df679e
b50b81d
2df679e
 
 
 
 
 
3f94404
b50b81d
 
 
 
 
2df679e
 
b50b81d
2df679e
b50b81d
2df679e
b50b81d
2df679e
3f94404
b50b81d
2df679e
b50b81d
3f94404
b50b81d
1f46785
2df679e
 
 
9e5445d
64bd29c
b50b81d
d24f899
1f46785
b50b81d
1f46785
b50b81d
2df679e
 
b50b81d
2df679e
1f46785
b50b81d
 
1f46785
b50b81d
 
1f46785
 
1a3a883
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from PIL import Image
import numpy as np
import pickle
import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input
import requests
from io import BytesIO

# Artifacts loaded once at import time: the saved Keras classifier and the
# pickled encoder used to turn predicted class indices back into labels.
MODEL_SAVE_PATH = "guava_model.keras"
LABEL_ENCODER_PATH = "label_encoder.pkl"

model = tf.keras.models.load_model(MODEL_SAVE_PATH)
# The two axes following the batch axis of the model's input; predict_fn feeds
# (batch, H, W, C) arrays built from PIL images, so this is (height, width).
IMG_SIZE = model.input_shape[1:3]

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# acceptable only while label_encoder.pkl is a trusted local artifact.
with open(LABEL_ENCODER_PATH, "rb") as f:
    label_encoder = pickle.load(f)

def load_image_from_url(url, timeout=5):
    """Download an image from *url* and return it as an RGB PIL image.

    Args:
        url: address of the image to fetch.
        timeout: seconds passed to ``requests.get`` (default 5).

    Returns:
        The decoded image converted to RGB, or ``None`` if the request fails,
        the server answers with an HTTP error status, or the payload cannot be
        decoded as an image.
    """
    try:
        resp = requests.get(url, timeout=timeout)
        # Treat 4xx/5xx answers as failures instead of trying to decode an
        # error page as an image.
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGB")
    except (
        requests.RequestException,
        OSError,  # includes PIL.UnidentifiedImageError / truncated images
        ValueError,
        Image.DecompressionBombError,
    ):
        # Best-effort: the caller reports None as "could not load". The catch
        # is narrow so KeyboardInterrupt, SystemExit and genuine programming
        # errors still propagate.
        return None

def predict_fn(img, url):
    """Classify a guava image supplied either by upload or by URL.

    Args:
        img: uploaded image as a numpy array (from ``gr.Image(type="numpy")``),
            or ``None`` when nothing was uploaded.
        url: optional image URL text; when non-blank it takes priority over
            *img*.

    Returns:
        A ``(message, image)`` pair for the two Gradio outputs: the predicted
        label (or an error message) and the PIL image that was classified
        (``None`` on error).
    """
    # Prefer the URL when one was entered; strip it once and reuse it so the
    # loader never receives surrounding whitespace.
    url = (url or "").strip()
    if url:
        img_pil = load_image_from_url(url)
        if img_pil is None:
            return "❌ Không tải được ảnh từ URL!", None
    elif img is None:
        return "❌ Chưa cung cấp ảnh!", None
    else:
        img_pil = Image.fromarray(img).convert("RGB")

    # IMG_SIZE comes from input_shape[1:3], i.e. (height, width), while PIL's
    # resize() expects (width, height); swap so non-square models work.
    height, width = IMG_SIZE
    img_resized = img_pil.resize((width, height))
    arr = preprocess_input(np.asarray(img_resized, dtype="float32"))
    batch = np.expand_dims(arr, 0)  # add the batch axis: (1, H, W, C)

    # verbose=0 stops Keras from printing a progress bar on every request.
    preds = model.predict(batch, verbose=0)
    idx = int(np.argmax(preds, axis=1)[0])
    label = label_encoder.inverse_transform([idx])[0]

    return f"✅ {label} ", img_pil


# Gradio UI: an upload widget plus an optional URL box feed predict_fn, which
# returns the prediction text and the image that was actually classified.
demo = gr.Interface(
    predict_fn,
    [
        gr.Image(label="Upload Image", type="numpy"),
        gr.Textbox(label="Hoặc dán URL ảnh online"),
    ],
    [
        gr.Textbox(label="Prediction"),
        gr.Image(label="Preview Image"),
    ],
    description="Upload ảnh Ổi hoặc nhập URL ảnh để phân loại.",
    title="Guava Classifier",
)

demo.launch(inline=True)