Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
|
@@ -208,29 +208,6 @@ def prediction_to_tag(prediction, tag_dict, class_num):
|
|
| 208 |
|
| 209 |
return general, character, artist, date, rating
|
| 210 |
|
| 211 |
-
def prediction_to_retrieval(prediction, tag_dict, class_num, top_k):
|
| 212 |
-
prediction = prediction.view(class_num)
|
| 213 |
-
predicted_ids = (prediction>=0.005).nonzero(as_tuple=True)[0].cpu().numpy() + 1
|
| 214 |
-
|
| 215 |
-
artist = {}
|
| 216 |
-
date = {}
|
| 217 |
-
|
| 218 |
-
for tag, value in tag_dict.items():
|
| 219 |
-
if value[2] in predicted_ids:
|
| 220 |
-
tag_value = round(prediction[value[2] - 1].item(), 6)
|
| 221 |
-
if value[1] == "artist":
|
| 222 |
-
artist[tag] = tag_value
|
| 223 |
-
elif value[1] == "date":
|
| 224 |
-
date[tag] = tag_value
|
| 225 |
-
|
| 226 |
-
artist = dict(sorted(artist.items(), key=lambda item: item[1], reverse=True))
|
| 227 |
-
artist = dict(list(artist.items())[:top_k])
|
| 228 |
-
|
| 229 |
-
if date:
|
| 230 |
-
date = {max(date, key=date.get): date[max(date, key=date.get)]}
|
| 231 |
-
|
| 232 |
-
return artist, date
|
| 233 |
-
|
| 234 |
def process_image(image):
|
| 235 |
try:
|
| 236 |
image = image.convert('RGBA')
|
|
@@ -283,7 +260,7 @@ def process_image(image):
|
|
| 283 |
tags_list = [tag for tag in tags_list if tag not in remove_list]
|
| 284 |
tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]
|
| 285 |
|
| 286 |
-
tags_str = ", ".join(tags_list).replace("(", r"\(").replace(")", r"\)")
|
| 287 |
|
| 288 |
return tags_str, artist_tags, character_tags, general_tags, rating, date
|
| 289 |
|
|
|
|
| 208 |
|
| 209 |
return general, character, artist, date, rating
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
def process_image(image):
|
| 212 |
try:
|
| 213 |
image = image.convert('RGBA')
|
|
|
|
| 260 |
tags_list = [tag for tag in tags_list if tag not in remove_list]
|
| 261 |
tags_list = [tag.replace("_", " ") if tag not in kaomojis else tag for tag in tags_list]
|
| 262 |
|
| 263 |
+
tags_str = ", ".join(character_tags_list + tags_list).replace("(", r"\(").replace(")", r"\)")
|
| 264 |
|
| 265 |
return tags_str, artist_tags, character_tags, general_tags, rating, date
|
| 266 |
|