applispee commited on
Commit
40a0441
·
verified ·
1 Parent(s): fb1fe93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -12
app.py CHANGED
@@ -121,21 +121,65 @@ class Predictor:
121
  return
122
 
123
  csv_path, model_path = self.download_model(model_repo)
124
-
 
125
  tags_df = pd.read_csv(csv_path)
126
- sep_tags = load_labels(tags_df)
127
-
128
- self.tag_names = sep_tags[0]
129
- self.rating_indexes = sep_tags[1]
130
- self.general_indexes = sep_tags[2]
131
- self.character_indexes = sep_tags[3]
132
-
133
- model = rt.InferenceSession(model_path)
134
- _, height, width, _ = model.get_inputs()[0].shape
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  self.model_target_size = height
136
 
137
  self.last_loaded_repo = model_repo
138
- self.model = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  def prepare_image(self, image):
141
  target_size = self.model_target_size
@@ -179,7 +223,7 @@ class Predictor:
179
  ):
180
  self.load_model(model_repo)
181
 
182
- image = self.prepare_image(image)
183
 
184
  input_name = self.model.get_inputs()[0].name
185
  label_name = self.model.get_outputs()[0].name
 
121
  return
122
 
123
  csv_path, model_path = self.download_model(model_repo)
124
+
125
+ # タグデータのロード最適化
126
  tags_df = pd.read_csv(csv_path)
127
+ self.tag_names = tags_df["name"].tolist()
128
+ self.tag_names = [x.replace("_", " ") if x not in kaomojis else x for x in self.tag_names]
129
+
130
+ # カテゴリインデックスの効率的な抽出
131
+ categories = tags_df["category"].to_numpy()
132
+ self.rating_indexes = np.where(categories == 9)[0].tolist()
133
+ self.general_indexes = np.where(categories == 0)[0].tolist()
134
+ self.character_indexes = np.where(categories == 4)[0].tolist()
135
+
136
+ # ONNX実行時の最適化オプションを設定
137
+ sess_options = rt.SessionOptions()
138
+ sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
139
+ sess_options.enable_mem_pattern = True
140
+ sess_options.enable_cpu_mem_arena = True
141
+
142
+ # マルチスレッド設定
143
+ sess_options.intra_op_num_threads = 4
144
+ sess_options.inter_op_num_threads = 2
145
+
146
+ # 最適化されたモデルロード
147
+ self.model = rt.InferenceSession(model_path, sess_options)
148
+ _, height, width, _ = self.model.get_inputs()[0].shape
149
  self.model_target_size = height
150
 
151
  self.last_loaded_repo = model_repo
152
+
153
+ def prepare_image_optimized(self, image):
154
+ target_size = self.model_target_size
155
+
156
+ # メモリ効率を高めるためにRGBAからRGBへの変換を最適化
157
+ if image.mode == 'RGBA':
158
+ canvas = Image.new("RGB", image.size, (255, 255, 255))
159
+ canvas.paste(image, mask=image.split()[3])
160
+ image = canvas
161
+ elif image.mode != 'RGB':
162
+ image = image.convert('RGB')
163
+
164
+ # 正方形パディングの最適化
165
+ image_shape = image.size
166
+ max_dim = max(image_shape)
167
+
168
+ # リサイズが必要な場合のみパディングを適用
169
+ if image_shape[0] != image_shape[1]:
170
+ pad_left = (max_dim - image_shape[0]) // 2
171
+ pad_top = (max_dim - image_shape[1]) // 2
172
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
173
+ padded_image.paste(image, (pad_left, pad_top))
174
+ image = padded_image
175
+
176
+ # リサイズの最適化 - 必要な場合のみ実行
177
+ if max_dim != target_size:
178
+ image = image.resize((target_size, target_size), Image.BICUBIC)
179
+
180
+ # NumPy配列への変換とBGR変換を一度に行う
181
+ image_array = np.asarray(image, dtype=np.float32)[:, :, ::-1]
182
+ return np.expand_dims(image_array, axis=0)
183
 
184
  def prepare_image(self, image):
185
  target_size = self.model_target_size
 
223
  ):
224
  self.load_model(model_repo)
225
 
226
+ image = self.prepare_image_optimized(image)
227
 
228
  input_name = self.model.get_inputs()[0].name
229
  label_name = self.model.get_outputs()[0].name