Johnny-Z commited on
Commit
cb64fb3
·
verified ·
1 Parent(s): 1aee7bf

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -151,6 +151,9 @@ with open(os.path.join(repo_dir, 'character_threshold.json'), 'r', encoding='utf
151
  with open(os.path.join(repo_dir, 'general_threshold.json'), 'r', encoding='utf-8') as f:
152
  general_thresholds = json.load(f)
153
 
 
 
 
154
  model_map = MultiheadAttentionPoolingHead(2048)
155
  model_map.load_state_dict(torch.load(os.path.join(repo_dir, "map_head.pth"), map_location=device, weights_only=True))
156
  model_map.to(device).to(dtype).eval()
@@ -253,6 +256,13 @@ def process_image(image):
253
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
254
  character_tags = character_[1]
255
 
 
 
 
 
 
 
 
256
  artist_prediction = mlp_artist(embedding)
257
  artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
258
  artist_tags = artist_[2]
@@ -266,6 +276,10 @@ def process_image(image):
266
  if tag in implications_list:
267
  for implication in implications_list[tag]:
268
  remove_list.append(implication)
 
 
 
 
269
  tags_list = [tag for tag in tags_list if tag not in remove_list]
270
  tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]
271
 
 
151
  with open(os.path.join(repo_dir, 'general_threshold.json'), 'r', encoding='utf-8') as f:
152
  general_thresholds = json.load(f)
153
 
154
+ with open(os.path.join(repo_dir, 'character_feature.json'), 'r', encoding='utf-8') as f:
155
+ character_features = json.load(f)
156
+
157
  model_map = MultiheadAttentionPoolingHead(2048)
158
  model_map.load_state_dict(torch.load(os.path.join(repo_dir, "map_head.pth"), map_location=device, weights_only=True))
159
  model_map.to(device).to(dtype).eval()
 
256
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
257
  character_tags = character_[1]
258
 
259
+ remove_list = []
260
+ for tag in character_tags:
261
+ if tag in implications_list:
262
+ remove_list.extend([implication for implication in implications_list[tag]])
263
+
264
+ character_tags_list = [tag for tag in character_tags if tag not in remove_list]
265
+
266
  artist_prediction = mlp_artist(embedding)
267
  artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
268
  artist_tags = artist_[2]
 
276
  if tag in implications_list:
277
  for implication in implications_list[tag]:
278
  remove_list.append(implication)
279
+ for char_tag in character_tags_list:
280
+ if char_tag in character_features:
281
+ for character_feature in character_features[char_tag]:
282
+ remove_list.append(character_feature)
283
  tags_list = [tag for tag in tags_list if tag not in remove_list]
284
  tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]
285