slyviee commited on
Commit
0584e28
·
verified ·
1 Parent(s): 6399bee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -16,6 +16,22 @@ from tensorflow.keras.preprocessing.image import img_to_array
16
  from huggingface_hub import hf_hub_download
17
  import gradio as gr
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # -----------------------------
20
  # Custom attention layers
21
  # -----------------------------
@@ -48,6 +64,10 @@ class ChannelAttention(layers.Layer):
48
  config.update({'ratio': self.ratio})
49
  return config
50
 
 
 
 
 
51
 
52
  class SpatialAttention(layers.Layer):
53
  def __init__(self, **kwargs):
@@ -64,6 +84,10 @@ class SpatialAttention(layers.Layer):
64
  def get_config(self):
65
  return super(SpatialAttention, self).get_config()
66
 
 
 
 
 
67
 
68
  # -----------------------------
69
  # Load model + tokenizer
@@ -75,7 +99,6 @@ def load_caption_model(model_path):
75
  'SpatialAttention': SpatialAttention
76
  }
77
  model = load_model(model_path, custom_objects=custom_objects)
78
- print("✅ Đã load model thành công!")
79
  return model
80
 
81
 
@@ -99,7 +122,6 @@ def load_feature_extractor():
99
  def extract_features_from_image(image_path, extractor):
100
  image = cv2.imread(image_path)
101
  if image is None:
102
- print(f"❌ Không đọc được ảnh: {image_path}")
103
  return None
104
  image = cv2.resize(image, (224, 224))
105
  image = img_to_array(image)
@@ -131,20 +153,6 @@ def generate_caption(model, tokenizer, image_features, max_length):
131
  # App initialization
132
  # -----------------------------
133
 
134
- MODEL_REPO = "slyviee/img_cap"
135
-
136
- # Khởi tạo tài nguyên toàn cục khi app start
137
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.keras")
138
- tokenizer_path = hf_hub_download(repo_id=MODEL_REPO, filename="tokenizer.pkl")
139
- config_path = hf_hub_download(repo_id=MODEL_REPO, filename="model_config.pkl")
140
-
141
- model = None
142
- tokenizer = None
143
- max_length = None
144
- vocab_size = None
145
- extractor = None
146
- ready = False
147
- startup_error = ""
148
 
149
  def _startup():
150
  global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
@@ -188,7 +196,6 @@ def predict(pil_image: Image.Image):
188
 
189
  DESCRIPTION = (
190
  "Upload ảnh và nhận caption sinh ra bởi mô hình. "
191
- "Cần có các tệp: best_model.keras, tokenizer.pkl, model_config.pkl."
192
  )
193
 
194
  demo = gr.Interface(
 
16
  from huggingface_hub import hf_hub_download
17
  import gradio as gr
18
 
19
+
20
+ MODEL_REPO = "slyviee/img_cap"
21
+
22
+ # Khởi tạo tài nguyên toàn cục khi app start
23
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.keras")
24
+ tokenizer_path = hf_hub_download(repo_id=MODEL_REPO, filename="tokenizer.pkl")
25
+ config_path = hf_hub_download(repo_id=MODEL_REPO, filename="model_config.pkl")
26
+
27
+ model = None
28
+ tokenizer = None
29
+ max_length = None
30
+ vocab_size = None
31
+ extractor = None
32
+ ready = False
33
+ startup_error = ""
34
+
35
  # -----------------------------
36
  # Custom attention layers
37
  # -----------------------------
 
64
  config.update({'ratio': self.ratio})
65
  return config
66
 
67
+ @classmethod
68
+ def from_config(cls, config):
69
+ return cls(**config)
70
+
71
 
72
  class SpatialAttention(layers.Layer):
73
  def __init__(self, **kwargs):
 
84
  def get_config(self):
85
  return super(SpatialAttention, self).get_config()
86
 
87
+ @classmethod
88
+ def from_config(cls, config):
89
+ return cls(**config)
90
+
91
 
92
  # -----------------------------
93
  # Load model + tokenizer
 
99
  'SpatialAttention': SpatialAttention
100
  }
101
  model = load_model(model_path, custom_objects=custom_objects)
 
102
  return model
103
 
104
 
 
122
  def extract_features_from_image(image_path, extractor):
123
  image = cv2.imread(image_path)
124
  if image is None:
 
125
  return None
126
  image = cv2.resize(image, (224, 224))
127
  image = img_to_array(image)
 
153
  # App initialization
154
  # -----------------------------
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  def _startup():
158
  global model, tokenizer, max_length, vocab_size, extractor, ready, startup_error
 
196
 
197
  DESCRIPTION = (
198
  "Upload ảnh và nhận caption sinh ra bởi mô hình. "
 
199
  )
200
 
201
  demo = gr.Interface(