| import os |
| import cv2 |
| import numpy as np |
| import pickle |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| import tensorflow as tf |
| from tensorflow.keras import layers |
| from tensorflow.keras.models import load_model, Model |
| from tensorflow.keras.applications import EfficientNetV2B0 |
| from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess |
| from tensorflow.keras.preprocessing.sequence import pad_sequences |
| from tensorflow.keras.preprocessing.image import img_to_array |
| from tqdm import tqdm |
| import random |
| from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
|
| import tempfile |
| import traceback |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
|
|
| import gradio as gr |
| from PIL import Image |
| import pickle |
|
|
|
|
|
|
| |
| |
| |
|
|
class ChannelAttention(layers.Layer):
    """Channel-wise gating layer.

    The input is summarised with global average pooling and global max
    pooling; both summaries go through the same two-layer MLP, the results
    are added, squashed with a sigmoid, and multiplied back onto the input.

    The MLP widths are hard-coded from the literal 1280 rather than derived
    from ``input_shape``.
    NOTE(review): assumes inputs are (batch, steps, 1280), as the 1D pooling
    layers and the fixed width suggest - confirm against the saved model.

    Args:
        ratio: Reduction factor for the hidden layer of the shared MLP
            (hidden units = ``1280 // ratio``).
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        """Create the pooling, MLP and activation sublayers."""
        self.gap = layers.GlobalAveragePooling1D()
        self.gmp = layers.GlobalMaxPooling1D()
        squeeze = layers.Dense(units=1280 // self.ratio, activation='relu')
        excite = layers.Dense(units=1280)
        # One MLP instance is applied to both pooled summaries in call().
        self.shared_mlp = tf.keras.Sequential([squeeze, excite])
        self.sigmoid = layers.Activation('sigmoid')
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` scaled by the per-channel sigmoid gate."""
        pooled_avg = self.gap(inputs)
        pooled_max = self.gmp(inputs)
        gate_logits = self.shared_mlp(pooled_avg) + self.shared_mlp(pooled_max)
        # Insert an axis at position 1 so the gate broadcasts over that axis.
        gate = tf.expand_dims(self.sigmoid(gate_logits), axis=1)
        return inputs * gate

    def get_config(self):
        """Return the base layer config extended with ``ratio``."""
        return {**super().get_config(), 'ratio': self.ratio}

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from a config produced by ``get_config``."""
        return cls(**config)
|
|
|
|
|
|
class SpatialAttention(layers.Layer):
    """Position-wise gating layer.

    A single-filter ``Conv1D`` (kernel size 3, ``same`` padding, sigmoid
    activation) produces one gate value per position, which is multiplied
    onto the input.

    Args:
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the gating convolution."""
        self.conv = layers.Conv1D(
            1,
            kernel_size=3,
            padding='same',
            activation='sigmoid',
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` multiplied by the convolutional gate."""
        gate = self.conv(inputs)
        return inputs * gate

    def get_config(self):
        """Return the unmodified base layer config (no extra arguments)."""
        base_config = super().get_config()
        return base_config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from a config produced by ``get_config``."""
        return cls(**config)
|
|
|
|
|
|
| |
| |
| |
|
|
def load_caption_model(model_path):
    """Load the saved caption model, registering the custom attention layers.

    Args:
        model_path: Path handed to ``load_model``.

    Returns:
        The loaded model.
    """
    # The saved model references these layers by name, so they must be
    # supplied for deserialisation.
    caption_model = load_model(
        model_path,
        custom_objects={
            'ChannelAttention': ChannelAttention,
            'SpatialAttention': SpatialAttention,
        },
    )
    print("✅ Đã load model thành công!")
    return caption_model
|
|
|
|
def load_tokenizer_and_config(tokenizer_path, config_path):
    """Unpickle the tokenizer and the model configuration.

    Args:
        tokenizer_path: Path to the pickled tokenizer.
        config_path: Path to a pickled mapping containing ``'max_length'``
            and ``'vocab_size'``.

    Returns:
        A ``(tokenizer, max_length, vocab_size)`` tuple.

    Raises:
        OSError: If either file cannot be opened.
        KeyError: If the config lacks ``'max_length'`` or ``'vocab_size'``.
    """
    # NOTE(review): pickle.load executes arbitrary code embedded in the file.
    # Only point this at artifacts from a trusted source.
    unpickled = []
    for path in (tokenizer_path, config_path):
        with open(path, 'rb') as handle:
            unpickled.append(pickle.load(handle))
    tokenizer, config = unpickled
    return tokenizer, config['max_length'], config['vocab_size']
|
|
|
|
| |
| |
| |
|
|
def load_feature_extractor():
    """Build the image feature extractor.

    Instantiates ``EfficientNetV2B0`` with ImageNet weights, no
    classification head and average pooling, then wraps its input and
    output in a ``Model``.

    Returns:
        The wrapping ``Model``.
    """
    backbone = EfficientNetV2B0(
        include_top=False,
        weights='imagenet',
        pooling='avg',
    )
    feature_model = Model(inputs=backbone.input, outputs=backbone.output)
    return feature_model
|
|
|
|
def extract_features_from_image(image_path, extractor):
    """Read an image from disk and run it through the feature extractor.

    Args:
        image_path: Path of the image file to read with OpenCV.
        extractor: Object exposing ``predict(batch, verbose=0)``.

    Returns:
        Whatever ``extractor.predict`` returns for the single-image batch,
        or ``None`` (after printing a message) when OpenCV cannot read the
        file.
    """
    pixels = cv2.imread(image_path)
    if pixels is None:
        print(f"❌ Không đọc được ảnh: {image_path}")
        return None

    # NOTE(review): OpenCV documents BGR channel order for imread and no
    # BGR->RGB conversion happens here. Confirm this matches how the
    # training features were produced before changing it.
    resized = cv2.resize(pixels, (224, 224))
    # Add a leading batch axis so the extractor receives one sample.
    batch = np.expand_dims(img_to_array(resized), axis=0)
    return extractor.predict(efficientnet_preprocess(batch), verbose=0)
|
|
|
|
| |
| |
| |
|
|
def generate_caption(model, tokenizer, image_features, max_length):
    """Greedily decode a caption for one image.

    Starting from the ``startseq`` sentinel, the model is asked for the next
    token up to ``max_length`` times. Decoding stops early when the
    predicted index has no entry in ``tokenizer.index_word`` or maps to
    ``endseq``.

    Args:
        model: Object exposing ``predict([image_features, sequence], verbose=0)``.
        tokenizer: Object exposing ``texts_to_sequences`` and ``index_word``.
        image_features: Image features passed to the model unchanged.
        max_length: Maximum number of decoding steps; also the padded
            sequence length.

    Returns:
        The generated words joined by single spaces, without the
        ``startseq``/``endseq`` sentinels. An empty string is returned when
        no word was generated.
    """
    start_token = 'startseq'
    end_token = 'endseq'

    # Generated words are kept apart from the sentinel. The previous
    # `.replace('startseq ', '')` returned the literal 'startseq' when
    # nothing was generated, and also stripped any later 'startseq '
    # occurring inside the caption.
    words = []
    for _ in range(max_length):
        in_text = ' '.join([start_token] + words)
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([image_features, sequence], verbose=0)
        # Convert to a plain int before the dictionary lookup.
        word = tokenizer.index_word.get(int(np.argmax(yhat)))
        if word is None or word == end_token:
            break
        words.append(word)
    return ' '.join(words)
|
|
|
|
| |
| |
| |
|
|
# Hugging Face Hub repository that hosts the trained artifacts.
MODEL_REPO = "slyviee/img_cap"

# Fetch the three artifacts when the module is imported; the returned values
# are used as local file paths by _startup().
model_path, tokenizer_path, config_path = (
    hf_hub_download(repo_id=MODEL_REPO, filename=artifact)
    for artifact in ("best_model.keras", "tokenizer.pkl", "model_config.pkl")
)

# Runtime state populated by _startup() and read by predict().
model = tokenizer = max_length = vocab_size = extractor = None
ready = False
startup_error = ""
|
|
|
|
def _startup():
    """Load the model, tokenizer/config and feature extractor into globals.

    On success ``ready`` becomes ``True``. If an artifact path does not
    exist, or any loading step raises, ``ready`` is set to ``False`` and
    ``startup_error`` receives a description (including the traceback for
    exceptions). Nothing is raised to the caller for ``Exception``
    subclasses.
    """
    global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
    try:
        # Refuse to continue when any artifact path is absent.
        absent = []
        for candidate in (model_path, tokenizer_path, config_path):
            if not Path(candidate).exists():
                absent.append(candidate)
        if absent:
            startup_error = "Thiếu tệp: " + ", ".join(absent)
            ready = False
            return

        print("🔄 Đang tải model...")
        model = load_caption_model(model_path)
        print("✅ Model đã được tải.")

        print("🔄 Đang tải tokenizer và config...")
        loaded = load_tokenizer_and_config(tokenizer_path, config_path)
        tokenizer, max_length, vocab_size = loaded
        print("✅ Tokenizer và config đã được tải.")

        print("🔄 Đang tải feature extractor...")
        extractor = load_feature_extractor()
        print("✅ Feature extractor đã được tải.")

        ready = True
    except Exception as e:
        # Top-level boundary: record the failure so predict() can report it.
        startup_error = f"Khởi tạo lỗi: {e}\n{traceback.format_exc()}"
        ready = False
|
|
|
|
def predict(pil_image: Image.Image):
    """Generate a caption for an uploaded image (callback for the UI).

    The image is written to a temporary JPEG file because
    ``extract_features_from_image`` takes a file path. The temporary file is
    always removed, including when saving or feature extraction fails.

    Args:
        pil_image: Image supplied by the interface. ``None`` is reported as
            unreadable input.

    Returns:
        The generated caption, or a user-facing message when the app is not
        ready, the image cannot be read, or an exception occurs.
    """
    if not ready:
        return f"Hệ thống chưa sẵn sàng. {startup_error or 'Thiếu model/tokenizer/config.'}"

    # NOTE(review): the UI may pass None when nothing was uploaded. Without
    # this guard that surfaces as an AttributeError traceback.
    if pil_image is None:
        return "Không đọc được ảnh đầu vào."

    try:
        # Reserve a unique file name and close the handle before writing to
        # it by path.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            pil_image.convert("RGB").save(tmp_path, format="JPEG")
            features = extract_features_from_image(tmp_path, extractor)
        finally:
            # delete=False means cleanup is our responsibility on every path.
            os.unlink(tmp_path)

        if features is None:
            return "Không đọc được ảnh đầu vào."
        return generate_caption(model, tokenizer, features, max_length)
    except Exception as e:
        return f"Lỗi trong quá trình dự đoán: {e}\n{traceback.format_exc()}"
|
|
# Text shown under the title of the interface.
DESCRIPTION = "Upload ảnh và nhận caption sinh ra bởi mô hình. "

# Single-image-in, text-out interface wired to predict().
demo = gr.Interface(
    title="Image Captioning — Gradio",
    description=DESCRIPTION,
    fn=predict,
    inputs=gr.Image(type="pil", label="Ảnh vào"),
    outputs=gr.Textbox(label="Caption"),
)


if __name__ == "__main__":
    # Load all artifacts before serving so predict() can report readiness.
    _startup()
    demo.launch()