Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,10 +2,8 @@ import os
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import pickle
|
| 5 |
-
import tempfile
|
| 6 |
-
import traceback
|
| 7 |
-
from pathlib import Path
|
| 8 |
from PIL import Image
|
|
|
|
| 9 |
import tensorflow as tf
|
| 10 |
from tensorflow.keras import layers
|
| 11 |
from tensorflow.keras.models import load_model, Model
|
|
@@ -13,24 +11,20 @@ from tensorflow.keras.applications import EfficientNetV2B0
|
|
| 13 |
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
|
| 14 |
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
| 15 |
from tensorflow.keras.preprocessing.image import img_to_array
|
| 16 |
-
from
|
| 17 |
-
import
|
|
|
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
# Khởi tạo tài nguyên toàn cục khi app start
|
| 23 |
-
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.keras")
|
| 24 |
-
tokenizer_path = hf_hub_download(repo_id=MODEL_REPO, filename="tokenizer.pkl")
|
| 25 |
-
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="model_config.pkl")
|
| 26 |
|
| 27 |
-
model = None
|
| 28 |
-
tokenizer = None
|
| 29 |
-
max_length = None
|
| 30 |
-
vocab_size = None
|
| 31 |
-
extractor = None
|
| 32 |
-
ready = False
|
| 33 |
-
startup_error = ""
|
| 34 |
|
| 35 |
# -----------------------------
|
| 36 |
# Custom attention layers
|
|
@@ -64,6 +58,10 @@ class ChannelAttention(layers.Layer):
|
|
| 64 |
config.update({'ratio': self.ratio})
|
| 65 |
return config
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
class SpatialAttention(layers.Layer):
|
|
@@ -81,6 +79,11 @@ class SpatialAttention(layers.Layer):
|
|
| 81 |
def get_config(self):
|
| 82 |
return super(SpatialAttention, self).get_config()
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# -----------------------------
|
| 86 |
# Load model + tokenizer
|
|
@@ -92,6 +95,7 @@ def load_caption_model(model_path):
|
|
| 92 |
'SpatialAttention': SpatialAttention
|
| 93 |
}
|
| 94 |
model = load_model(model_path, custom_objects=custom_objects)
|
|
|
|
| 95 |
return model
|
| 96 |
|
| 97 |
|
|
@@ -115,6 +119,7 @@ def load_feature_extractor():
|
|
| 115 |
def extract_features_from_image(image_path, extractor):
|
| 116 |
image = cv2.imread(image_path)
|
| 117 |
if image is None:
|
|
|
|
| 118 |
return None
|
| 119 |
image = cv2.resize(image, (224, 224))
|
| 120 |
image = img_to_array(image)
|
|
@@ -143,9 +148,24 @@ def generate_caption(model, tokenizer, image_features, max_length):
|
|
| 143 |
|
| 144 |
|
| 145 |
# -----------------------------
|
| 146 |
-
#
|
| 147 |
# -----------------------------
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
def _startup():
|
| 151 |
global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
|
|
@@ -157,15 +177,23 @@ def _startup():
|
|
| 157 |
ready = False
|
| 158 |
return
|
| 159 |
|
|
|
|
| 160 |
model = load_caption_model(model_path)
|
|
|
|
|
|
|
|
|
|
| 161 |
tokenizer, max_length, vocab_size = load_tokenizer_and_config(tokenizer_path, config_path)
|
|
|
|
|
|
|
|
|
|
| 162 |
extractor = load_feature_extractor()
|
|
|
|
|
|
|
| 163 |
ready = True
|
| 164 |
except Exception as e:
|
| 165 |
startup_error = f"Khởi tạo lỗi: {e}\n{traceback.format_exc()}"
|
| 166 |
ready = False
|
| 167 |
|
| 168 |
-
_startup()
|
| 169 |
|
| 170 |
def predict(pil_image: Image.Image):
|
| 171 |
if not ready:
|
|
@@ -185,7 +213,7 @@ def predict(pil_image: Image.Image):
|
|
| 185 |
caption = generate_caption(model, tokenizer, features, max_length)
|
| 186 |
return caption
|
| 187 |
except Exception as e:
|
| 188 |
-
return f"Lỗi
|
| 189 |
|
| 190 |
DESCRIPTION = (
|
| 191 |
"Upload ảnh và nhận caption sinh ra bởi mô hình. "
|
|
@@ -200,5 +228,6 @@ demo = gr.Interface(
|
|
| 200 |
allow_flagging="never",
|
| 201 |
)
|
| 202 |
|
| 203 |
-
if __name__ ==
|
|
|
|
| 204 |
demo.launch()
|
|
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
import pickle
|
|
|
|
|
|
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
import tensorflow as tf
|
| 8 |
from tensorflow.keras import layers
|
| 9 |
from tensorflow.keras.models import load_model, Model
|
|
|
|
| 11 |
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
|
| 12 |
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
| 13 |
from tensorflow.keras.preprocessing.image import img_to_array
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import random
|
| 16 |
+
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
| 17 |
|
| 18 |
+
import tempfile
|
| 19 |
+
import traceback
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
|
| 23 |
+
import gradio as gr
|
| 24 |
+
from PIL import Image
|
| 25 |
+
import pickle
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# -----------------------------
|
| 30 |
# Custom attention layers
|
|
|
|
| 58 |
config.update({'ratio': self.ratio})
|
| 59 |
return config
|
| 60 |
|
| 61 |
+
@classmethod
|
| 62 |
+
def from_config(cls, config):
|
| 63 |
+
return cls(**config)
|
| 64 |
+
|
| 65 |
|
| 66 |
|
| 67 |
class SpatialAttention(layers.Layer):
|
|
|
|
| 79 |
def get_config(self):
|
| 80 |
return super(SpatialAttention, self).get_config()
|
| 81 |
|
| 82 |
+
@classmethod
|
| 83 |
+
def from_config(cls, config):
|
| 84 |
+
return cls(**config)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
|
| 88 |
# -----------------------------
|
| 89 |
# Load model + tokenizer
|
|
|
|
| 95 |
'SpatialAttention': SpatialAttention
|
| 96 |
}
|
| 97 |
model = load_model(model_path, custom_objects=custom_objects)
|
| 98 |
+
print("✅ Đã load model thành công!")
|
| 99 |
return model
|
| 100 |
|
| 101 |
|
|
|
|
| 119 |
def extract_features_from_image(image_path, extractor):
|
| 120 |
image = cv2.imread(image_path)
|
| 121 |
if image is None:
|
| 122 |
+
print(f"❌ Không đọc được ảnh: {image_path}")
|
| 123 |
return None
|
| 124 |
image = cv2.resize(image, (224, 224))
|
| 125 |
image = img_to_array(image)
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
# -----------------------------
|
| 151 |
+
# Chạy test
|
| 152 |
# -----------------------------
|
| 153 |
|
| 154 |
+
MODEL_REPO = "slyviee/img_cap"
|
| 155 |
+
|
| 156 |
+
# Khởi tạo tài nguyên toàn cục khi app start
|
| 157 |
+
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.keras")
|
| 158 |
+
tokenizer_path = hf_hub_download(repo_id=MODEL_REPO, filename="tokenizer.pkl")
|
| 159 |
+
config_path = hf_hub_download(repo_id=MODEL_REPO, filename="model_config.pkl")
|
| 160 |
+
|
| 161 |
+
model = None
|
| 162 |
+
tokenizer = None
|
| 163 |
+
max_length = None
|
| 164 |
+
vocab_size = None
|
| 165 |
+
extractor = None
|
| 166 |
+
ready = False
|
| 167 |
+
startup_error = ""
|
| 168 |
+
|
| 169 |
|
| 170 |
def _startup():
|
| 171 |
global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
|
|
|
|
| 177 |
ready = False
|
| 178 |
return
|
| 179 |
|
| 180 |
+
print("🔄 Đang tải model...")
|
| 181 |
model = load_caption_model(model_path)
|
| 182 |
+
print("✅ Model đã được tải.")
|
| 183 |
+
|
| 184 |
+
print("🔄 Đang tải tokenizer và config...")
|
| 185 |
tokenizer, max_length, vocab_size = load_tokenizer_and_config(tokenizer_path, config_path)
|
| 186 |
+
print("✅ Tokenizer và config đã được tải.")
|
| 187 |
+
|
| 188 |
+
print("🔄 Đang tải feature extractor...")
|
| 189 |
extractor = load_feature_extractor()
|
| 190 |
+
print("✅ Feature extractor đã được tải.")
|
| 191 |
+
|
| 192 |
ready = True
|
| 193 |
except Exception as e:
|
| 194 |
startup_error = f"Khởi tạo lỗi: {e}\n{traceback.format_exc()}"
|
| 195 |
ready = False
|
| 196 |
|
|
|
|
| 197 |
|
| 198 |
def predict(pil_image: Image.Image):
|
| 199 |
if not ready:
|
|
|
|
| 213 |
caption = generate_caption(model, tokenizer, features, max_length)
|
| 214 |
return caption
|
| 215 |
except Exception as e:
|
| 216 |
+
return f"Lỗi trong quá trình dự đoán: {e}\n{traceback.format_exc()}"
|
| 217 |
|
| 218 |
DESCRIPTION = (
|
| 219 |
"Upload ảnh và nhận caption sinh ra bởi mô hình. "
|
|
|
|
| 228 |
allow_flagging="never",
|
| 229 |
)
|
| 230 |
|
| 231 |
+
if __name__ == '__main__':
|
| 232 |
+
_startup()
|
| 233 |
demo.launch()
|