Johnny-Z committed on
Commit
d52be25
·
verified ·
1 Parent(s): beefff8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py CHANGED
@@ -127,6 +127,25 @@ class MLP(nn.Module):
127
  x = self.sigmoid(x)
128
  return x
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
131
  general_dict = json.load(f)
132
 
@@ -167,6 +186,10 @@ mlp_artist = MLP(2048, artist_class)
167
  mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
168
  mlp_artist.to(device).to(dtype).eval()
169
 
 
 
 
 
170
  def prediction_to_tag(prediction, tag_dict, class_num):
171
  prediction = prediction.view(class_num)
172
  predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
@@ -202,6 +225,29 @@ def prediction_to_tag(prediction, tag_dict, class_num):
202
 
203
  return general, character, artist, date, rating
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def process_image(image):
206
  try:
207
  image = image.convert('RGBA')
@@ -227,10 +273,17 @@ def process_image(image):
227
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
228
  character_tags = character_[1]
229
 
 
230
  artist_prediction = mlp_artist(embedding)
231
  artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
232
  artist_tags = artist_[2]
233
  date = artist_[3]
 
 
 
 
 
 
234
 
235
  combined_tags = {**general_tags}
236
 
 
127
  x = self.sigmoid(x)
128
  return x
129
 
130
class MLP_Retrieval(nn.Module):
    """Two-layer MLP head whose output is two independent softmax groups.

    The network maps ``input_size`` features to ``class_num`` logits through
    one hidden layer of width ``input_size // 2`` with a SiLU activation.
    The logits are then split at ``split_index`` along dim 1 and each part is
    normalized with its own softmax, so every row of the output contains two
    probability distributions laid side by side.

    Args:
        input_size: Number of input features per sample.
        class_num: Total number of output classes (both groups together).
        split_index: Column at which the logits are split into the two
            softmax groups. Defaults to 15, the value that was previously
            hard-coded in ``forward``.
            NOTE(review): which group holds which tag category is not
            visible in this block -- confirm against the tag dictionary.
    """

    def __init__(self, input_size, class_num, split_index=15):
        super().__init__()
        self.mlp_layer0 = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.SiLU()
        )
        self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
        # Stored as a plain int (not a buffer/parameter) so the state_dict
        # keys stay exactly the same as before this parameter existed.
        self.split_index = split_index

    def forward(self, x):
        """Return per-group softmax probabilities for a 2-D batch ``x``.

        ``x`` is indexed as ``x[:, ...]`` after the linear layers, so it must
        have a batch dimension; the output has ``class_num`` columns.
        """
        x = self.mlp_layer0(x)
        x = self.mlp_layer1(x)

        # Each group is normalized on its own so both sum to 1 per row.
        head = torch.softmax(x[:, :self.split_index], dim=1)
        tail = torch.softmax(x[:, self.split_index:], dim=1)

        return torch.cat([head, tail], dim=1)
148
+
149
  with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
150
  general_dict = json.load(f)
151
 
 
186
  mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
187
  mlp_artist.to(device).to(dtype).eval()
188
 
189
+ mlp_artist_retrieval = MLP_Retrieval(2048, artist_class)
190
+ mlp_artist_retrieval.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist_retrieval.pth"), map_location=device, weights_only=True))
191
+ mlp_artist_retrieval.to(device).to(dtype).eval()
192
+
193
  def prediction_to_tag(prediction, tag_dict, class_num):
194
  prediction = prediction.view(class_num)
195
  predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
 
225
 
226
  return general, character, artist, date, rating
227
 
228
def prediction_to_retrieval(prediction, tag_dict, class_num, top_k, *, threshold=0.005):
    """Turn a retrieval-head prediction into ranked artist tags and one date tag.

    Args:
        prediction: Tensor holding ``class_num`` scores; it is reshaped with
            ``view(class_num)`` before use.
        tag_dict: Mapping of tag name to a sequence whose element ``[1]`` is
            the tag category (``"artist"`` or ``"date"``; anything else is
            ignored) and whose element ``[2]`` is the tag's 1-based class id.
        class_num: Number of scores in ``prediction``.
        top_k: Maximum number of artist tags to return.
        threshold: Minimum score a class needs to be considered at all.
            Defaults to 0.005, the value that was previously hard-coded.

    Returns:
        A tuple ``(artist, date)``. ``artist`` maps up to ``top_k`` artist
        tags to their score, ordered from highest to lowest. ``date`` maps the
        single best-scoring date tag to its score, or is empty when no date
        tag reached the threshold. Scores are rounded to 6 decimals.
    """
    scores = prediction.view(class_num)

    # Class ids in tag_dict are 1-based while tensor positions are 0-based,
    # hence the +1. A set makes the per-tag membership test O(1) instead of
    # scanning an array once for every entry of tag_dict.
    hit_positions = (scores >= threshold).nonzero(as_tuple=True)[0]
    kept_ids = set((hit_positions + 1).tolist())

    artist = {}
    date = {}

    for tag, value in tag_dict.items():
        class_id = value[2]
        if class_id not in kept_ids:
            continue
        tag_value = round(scores[class_id - 1].item(), 6)
        if value[1] == "artist":
            artist[tag] = tag_value
        elif value[1] == "date":
            date[tag] = tag_value

    # sorted() is stable, so equally scored artists keep tag_dict order.
    ranked = sorted(artist.items(), key=lambda item: item[1], reverse=True)
    artist = dict(ranked[:top_k])

    if date:
        # Only the most likely date is reported; ties go to the first seen.
        best_date = max(date, key=date.get)
        date = {best_date: date[best_date]}

    return artist, date
250
+
251
  def process_image(image):
252
  try:
253
  image = image.convert('RGBA')
 
273
  character_ = prediction_to_tag(character_prediction, character_dict, character_class)
274
  character_tags = character_[1]
275
 
276
+ """
277
  artist_prediction = mlp_artist(embedding)
278
  artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
279
  artist_tags = artist_[2]
280
  date = artist_[3]
281
+ """
282
+
283
+ artist_retrieval_prediction = mlp_artist_retrieval(embedding)
284
+ artist_retrieval_ = prediction_to_retrieval(artist_retrieval_prediction, artist_dict, artist_class, 10)
285
+ artist_tags = artist_retrieval_[0]
286
+ date = artist_retrieval_[1]
287
 
288
  combined_tags = {**general_tags}
289