"""Gradio app for image captioning.

Loads an EfficientNetV2B0 feature extractor plus a trained captioning model
(with custom attention layers) from the Hugging Face Hub, and serves a simple
upload-and-caption UI.
"""

import os
import pickle
import tempfile
import traceback
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from tensorflow.keras import layers
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.applications import EfficientNetV2B0
# NOTE: the matching preprocessing for EfficientNetV2 lives in
# `efficientnet_v2` (the original imported the V1 `efficientnet` version).
# In recent Keras both are pass-throughs, since rescaling is built into the
# model itself, so this is a correctness-of-intent fix rather than a
# behavioral change.
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input as efficientnet_preprocess
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.image import img_to_array

# -----------------------------
# Custom attention layers
# -----------------------------
class ChannelAttention(layers.Layer):
    """Squeeze-and-excitation-style channel attention over 1D feature maps."""

    def __init__(self, ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.gap = layers.GlobalAveragePooling1D()
        self.gmp = layers.GlobalMaxPooling1D()
        # 1280 is hardcoded to match the EfficientNetV2B0 feature width used
        # when the checkpoint was trained.
        self.shared_mlp = tf.keras.Sequential([
            layers.Dense(units=1280 // self.ratio, activation='relu'),
            layers.Dense(units=1280)
        ])
        self.sigmoid = layers.Activation('sigmoid')
        super().build(input_shape)

    def call(self, inputs):
        gap = self.gap(inputs)
        gmp = self.gmp(inputs)
        gap_mlp = self.shared_mlp(gap)
        gmp_mlp = self.shared_mlp(gmp)
        channel_attention = self.sigmoid(gap_mlp + gmp_mlp)
        return inputs * tf.expand_dims(channel_attention, axis=1)

    def get_config(self):
        config = super().get_config()
        config.update({'ratio': self.ratio})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class SpatialAttention(layers.Layer):
    """Per-position attention weights computed with a 1D convolution."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.conv = layers.Conv1D(1, kernel_size=3, padding='same', activation='sigmoid')
        super().build(input_shape)

    def call(self, inputs):
        spatial_attention = self.conv(inputs)
        return inputs * spatial_attention

    def get_config(self):
        return super().get_config()

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# -----------------------------
# Load model + tokenizer
# -----------------------------
def load_caption_model(model_path):
    custom_objects = {
        'ChannelAttention': ChannelAttention,
        'SpatialAttention': SpatialAttention
    }
    model = load_model(model_path, custom_objects=custom_objects)
    print("✅ Model loaded successfully!")
    return model


def load_tokenizer_and_config(tokenizer_path, config_path):
    with open(tokenizer_path, 'rb') as f:
        tokenizer = pickle.load(f)
    with open(config_path, 'rb') as f:
        config = pickle.load(f)
    return tokenizer, config['max_length'], config['vocab_size']


# -----------------------------
# Feature extractor - EfficientNetV2B0
# -----------------------------
def load_feature_extractor():
    base_model = EfficientNetV2B0(include_top=False, weights='imagenet', pooling='avg')
    return Model(inputs=base_model.input, outputs=base_model.output)


def extract_features_from_image(image_path, extractor):
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ Could not read image: {image_path}")
        return None
    # cv2 loads images as BGR; the ImageNet-pretrained extractor expects RGB.
    # (Assumes the training-time features were extracted from RGB images too.)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)
    image = efficientnet_preprocess(image)
    feature = extractor.predict(image, verbose=0)
    return feature
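
# A minimal sanity-check sketch (an addition, not part of the app): with
# pooling='avg', EfficientNetV2B0 emits a (1, 1280) feature vector per image,
# which is why ChannelAttention hardcodes 1280 units above. Call it manually
# if you want to verify the extractor's output shape.
def _check_feature_shape():
    extractor = load_feature_extractor()
    dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)  # blank placeholder image
    features = extractor.predict(efficientnet_preprocess(dummy), verbose=0)
    print(features.shape)  # expected: (1, 1280)
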
# -----------------------------
# Generate caption
# -----------------------------
def generate_caption(model, tokenizer, image_features, max_length):
    # Greedy decoding: at each step, feed the image features plus the tokens
    # generated so far, then append the most probable next word.
    in_text = 'startseq'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([image_features, sequence], verbose=0)
        yhat = np.argmax(yhat)
        word = tokenizer.index_word.get(yhat)
        if word is None or word == 'endseq':
            break
        in_text += ' ' + word
    return in_text.replace('startseq ', '')


# -----------------------------
# App startup + Gradio UI
# -----------------------------
MODEL_REPO = "dunglelele/img_cap"

# Download model artifacts once, at import time, when the app starts.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.keras")
tokenizer_path = hf_hub_download(repo_id=MODEL_REPO, filename="tokenizer.pkl")
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="model_config.pkl")

model = None
tokenizer = None
max_length = None
vocab_size = None
extractor = None
ready = False
startup_error = ""


def _startup():
    global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
    try:
        # Check that all required files exist.
        missing = [p for p in [model_path, tokenizer_path, config_path] if not Path(p).exists()]
        if missing:
            startup_error = "Missing files: " + ", ".join(missing)
            ready = False
            return

        print("🔄 Loading model...")
        model = load_caption_model(model_path)
        print("✅ Model loaded.")

        print("🔄 Loading tokenizer and config...")
        tokenizer, max_length, vocab_size = load_tokenizer_and_config(tokenizer_path, config_path)
        print("✅ Tokenizer and config loaded.")

        print("🔄 Loading feature extractor...")
        extractor = load_feature_extractor()
        print("✅ Feature extractor loaded.")

        ready = True
    except Exception as e:
        startup_error = f"Startup error: {e}\n{traceback.format_exc()}"
        ready = False


def predict(pil_image: Image.Image):
    if not ready:
        return f"The system is not ready yet. {startup_error or 'Missing model/tokenizer/config.'}"
    try:
        # Save to a temporary file so we can reuse extract_features_from_image
        # (which reads from disk via cv2).
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            pil_image.convert("RGB").save(tmp.name, format="JPEG")
            tmp_path = tmp.name
        features = extract_features_from_image(tmp_path, extractor)
        os.unlink(tmp_path)
        if features is None:
            return "Could not read the input image."
        caption = generate_caption(model, tokenizer, features, max_length)
        return caption
    except Exception as e:
        return f"Error during prediction: {e}\n{traceback.format_exc()}"


DESCRIPTION = (
    "Upload an image and get a caption generated by the model."
)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Input image"),
    outputs=gr.Textbox(label="Caption"),
    title="Image Captioning — Gradio",
    description=DESCRIPTION
)

if __name__ == '__main__':
    _startup()
    demo.launch()
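
# Usage sketch (assumption: this file is saved as app.py; the filename is not
# given in the source):
#   python app.py
# then open the local URL Gradio prints. To caption a single image without
# the UI, something like the following works once startup has run; note that
# "example.jpg" is a hypothetical placeholder path:
#   _startup()
#   print(predict(Image.open("example.jpg")))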