Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -127,6 +127,25 @@ class MLP(nn.Module):
|
|
| 127 |
x = self.sigmoid(x)
|
| 128 |
return x
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
|
| 131 |
general_dict = json.load(f)
|
| 132 |
|
|
@@ -167,6 +186,10 @@ mlp_artist = MLP(2048, artist_class)
|
|
| 167 |
mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
|
| 168 |
mlp_artist.to(device).to(dtype).eval()
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def prediction_to_tag(prediction, tag_dict, class_num):
|
| 171 |
prediction = prediction.view(class_num)
|
| 172 |
predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
|
|
@@ -202,6 +225,29 @@ def prediction_to_tag(prediction, tag_dict, class_num):
|
|
| 202 |
|
| 203 |
return general, character, artist, date, rating
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
def process_image(image):
|
| 206 |
try:
|
| 207 |
image = image.convert('RGBA')
|
|
@@ -227,10 +273,17 @@ def process_image(image):
|
|
| 227 |
character_ = prediction_to_tag(character_prediction, character_dict, character_class)
|
| 228 |
character_tags = character_[1]
|
| 229 |
|
|
|
|
| 230 |
artist_prediction = mlp_artist(embedding)
|
| 231 |
artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
|
| 232 |
artist_tags = artist_[2]
|
| 233 |
date = artist_[3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
combined_tags = {**general_tags}
|
| 236 |
|
|
|
|
| 127 |
x = self.sigmoid(x)
|
| 128 |
return x
|
| 129 |
|
| 130 |
+
class MLP_Retrieval(nn.Module):
|
| 131 |
+
def __init__(self, input_size, class_num):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.mlp_layer0 = nn.Sequential(
|
| 134 |
+
nn.Linear(input_size, input_size // 2),
|
| 135 |
+
nn.SiLU()
|
| 136 |
+
)
|
| 137 |
+
self.mlp_layer1 = nn.Linear(input_size // 2, class_num)
|
| 138 |
+
|
| 139 |
+
def forward(self, x):
|
| 140 |
+
x = self.mlp_layer0(x)
|
| 141 |
+
x = self.mlp_layer1(x)
|
| 142 |
+
x1, x2 = x[:, :15], x[:, 15:]
|
| 143 |
+
x1 = torch.softmax(x1, dim=1)
|
| 144 |
+
x2 = torch.softmax(x2, dim=1)
|
| 145 |
+
x = torch.cat([x1, x2], dim=1)
|
| 146 |
+
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
with open(os.path.join(repo_dir, 'general_tag_dict.json'), 'r', encoding='utf-8') as f:
|
| 150 |
general_dict = json.load(f)
|
| 151 |
|
|
|
|
| 186 |
mlp_artist.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist.pth"), map_location=device, weights_only=True))
|
| 187 |
mlp_artist.to(device).to(dtype).eval()
|
| 188 |
|
| 189 |
+
mlp_artist_retrieval = MLP_Retrieval(2048, artist_class)
|
| 190 |
+
mlp_artist_retrieval.load_state_dict(torch.load(os.path.join(repo_dir, "cls_predictor_artist_retrieval.pth"), map_location=device, weights_only=True))
|
| 191 |
+
mlp_artist_retrieval.to(device).to(dtype).eval()
|
| 192 |
+
|
| 193 |
def prediction_to_tag(prediction, tag_dict, class_num):
|
| 194 |
prediction = prediction.view(class_num)
|
| 195 |
predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
|
|
|
|
| 225 |
|
| 226 |
return general, character, artist, date, rating
|
| 227 |
|
| 228 |
+
def prediction_to_retrieval(prediction, tag_dict, class_num, top_k):
|
| 229 |
+
prediction = prediction.view(class_num)
|
| 230 |
+
predicted_ids = (prediction>=0.005).nonzero(as_tuple=True)[0].cpu().numpy() + 1
|
| 231 |
+
|
| 232 |
+
artist = {}
|
| 233 |
+
date = {}
|
| 234 |
+
|
| 235 |
+
for tag, value in tag_dict.items():
|
| 236 |
+
if value[2] in predicted_ids:
|
| 237 |
+
tag_value = round(prediction[value[2] - 1].item(), 6)
|
| 238 |
+
if value[1] == "artist":
|
| 239 |
+
artist[tag] = tag_value
|
| 240 |
+
elif value[1] == "date":
|
| 241 |
+
date[tag] = tag_value
|
| 242 |
+
|
| 243 |
+
artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
|
| 244 |
+
artist = dict(list(artist.items())[:top_k])
|
| 245 |
+
|
| 246 |
+
if date:
|
| 247 |
+
date = {max(date, key=date.get): date[max(date, key=date.get)]}
|
| 248 |
+
|
| 249 |
+
return artist, date
|
| 250 |
+
|
| 251 |
def process_image(image):
|
| 252 |
try:
|
| 253 |
image = image.convert('RGBA')
|
|
|
|
| 273 |
character_ = prediction_to_tag(character_prediction, character_dict, character_class)
|
| 274 |
character_tags = character_[1]
|
| 275 |
|
| 276 |
+
"""
|
| 277 |
artist_prediction = mlp_artist(embedding)
|
| 278 |
artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class)
|
| 279 |
artist_tags = artist_[2]
|
| 280 |
date = artist_[3]
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
artist_retrieval_prediction = mlp_artist_retrieval(embedding)
|
| 284 |
+
artist_retrieval_ = prediction_to_retrieval(artist_retrieval_prediction, artist_dict, artist_class, 10)
|
| 285 |
+
artist_tags = artist_retrieval_[0]
|
| 286 |
+
date = artist_retrieval_[1]
|
| 287 |
|
| 288 |
combined_tags = {**general_tags}
|
| 289 |
|