|
|
|
|
|
|
|
|
import json
import pickle
from io import BytesIO

import gradio as gr
import PIL
import requests
import timm
import timm.layers.ml_decoder
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
# HTTP headers sent with every request this demo makes to Danbooru.
headers = {
    "User-Agent": "Gradio 0-shot classification demo",
}

# Page title, also rendered as the heading of the interface.
TITLE = "Danbooru 0-shot classification demo"

# Markdown blurb displayed under the heading.
DESCRIPTION = """
Demo for 0-shot classification on Danbooru images.
Davit-tiny backbone, ML-Decoder classification head, Alibaba-NLP/gte-large-en-v1.5 text embedding model
"""
|
|
|
|
|
def scrape_img(postID):
    """Download the image attached to a Danbooru post.

    Args:
        postID: Numeric post id (int or string); surrounding whitespace is
            ignored.

    Returns:
        A fully loaded ``PIL.Image.Image`` decoded from the post's
        ``file_url``.

    Raises:
        requests.HTTPError: If Danbooru answers either request with an error
            status.
        ValueError: If the post metadata carries no ``file_url``.
    """
    # Bound both requests so a stalled connection cannot hang the UI worker.
    timeout = 30

    postURL = "https://danbooru.donmai.us/posts/" + str(postID).strip() + ".json"
    postResponse = requests.get(postURL, headers=headers, timeout=timeout)
    postResponse.raise_for_status()
    postData = postResponse.json()

    # Fail with a readable message instead of a bare KeyError when the
    # metadata has no downloadable file.
    imageURL = postData.get('file_url') if isinstance(postData, dict) else None
    if not imageURL:
        raise ValueError(f"Post {postID} has no downloadable file_url")

    print("Getting image from " + imageURL)
    response = requests.get(imageURL, headers=headers, timeout=timeout)
    response.raise_for_status()

    image = Image.open(BytesIO(response.content))
    # Force decoding now, while the in-memory buffer is guaranteed alive.
    image.load()
    return image
|
|
|
|
|
def scrape_wiki(tagName):
    """Fetch the wiki text Danbooru stores for a tag.

    Args:
        tagName: Tag title to look up in the wiki page versions.

    Returns:
        ``"<tagName>: <body>"`` using the first wiki version returned by the
        API, or just ``tagName`` when the API returns no versions.

    Raises:
        requests.HTTPError: If Danbooru answers with an error status.
    """
    # Pass the title through ``params`` so characters such as ``+``, ``&`` or
    # ``#`` in tag names are percent-encoded instead of corrupting the query.
    response = requests.get(
        "https://danbooru.donmai.us/wiki_page_versions.json",
        params={"search[title]": tagName},
        headers=headers,
        timeout=30,
    )
    response.raise_for_status()
    wikiHistory = response.json()

    if not wikiHistory:
        return tagName
    return tagName + ": " + wikiHistory[0]['body']
|
|
|
|
|
class Predictor:
    """Tag classifier pairing an image model with a text-embedding model.

    The image model is a ``davit_tiny`` timm model with an ML-Decoder head
    whose class queries are embeddings. Known tags use the embeddings stored
    in the checkpoint; new tags are scored by embedding their description
    with the text model and using that as the query.
    """

    def __init__(self):
        """Initialise attributes and immediately load every model from disk/hub."""
        self.img_size = (288, 288)
        self.cls_model = None
        self.tokenizer = None
        self.text_emb_model = None
        self.class_embed = None
        self.tag_names = None

        self.load_model()

    def load_model(self):
        """Load tag names, classifier weights and the text-embedding model.

        Reads ``tags1588.pkl`` and ``model.pth`` from the working directory
        and fetches ``Alibaba-NLP/gte-large-en-v1.5`` through transformers.
        """
        # NOTE(security): pickle executes arbitrary code while loading; this
        # file must only ever be a trusted artifact shipped with the demo.
        with open('tags1588.pkl', 'rb') as f:
            classes = pickle.load(f)
        # presumably a pandas DataFrame whose column 0 holds the tag names
        # -- TODO confirm against the pickled artifact.
        self.tag_names = classes[0].to_list()

        pretrained_weights = torch.load('model.pth', map_location=torch.device('cpu'))
        self.class_embed = pretrained_weights['0.head.head.class_embed.weight']

        cls_model = timm.create_model('davit_tiny', num_classes=len(classes))
        cls_model = timm.layers.ml_decoder.add_ml_decoder_head(
            cls_model,
            num_groups=len(classes),
            class_embed=self.class_embed,
            class_embed_merge='',
            shared_fc=True,
        )
        # The checkpoint keys are prefixed with "0." (see the class_embed key
        # above), so the model is wrapped in a Sequential before loading.
        cls_model = nn.Sequential(cls_model)
        cls_model.load_state_dict(pretrained_weights, strict=True)
        self.cls_model = cls_model.eval()

        model_path = 'Alibaba-NLP/gte-large-en-v1.5'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.text_emb_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        self.text_emb_model = self.text_emb_model.eval()

    @torch.inference_mode()
    def embed_text(self, input_strings):
        """Embed each string with the text model.

        Args:
            input_strings: Non-empty iterable of strings; each one is
                tokenized and embedded on its own.

        Returns:
            A CPU tensor with one row per input string, taken from position 0
            of the model's ``last_hidden_state``.
        """
        embeddingList = []
        for text in input_strings:
            batch_dict = self.tokenizer(text, padding=True, truncation=False, return_tensors='pt')
            outputs = self.text_emb_model(**batch_dict.to(self.text_emb_model.device))
            embeddingList.append(outputs.last_hidden_state[:, 0].cpu())
        return torch.cat(embeddingList)

    def get_tag_embed(self, tag):
        """Return ``(wiki_text, embedding)`` for a tag, scraping its wiki entry."""
        wiki_text = scrape_wiki(tag)
        tag_embed = self.embed_text([wiki_text])
        return wiki_text, tag_embed

    def prepare_image(self, image):
        """Turn a PIL image into the tensor fed to the classifier.

        The image is flattened onto a white background using its alpha
        channel, resized to ``self.img_size`` with bicubic interpolation and
        converted with ``transforms.ToTensor``.
        """
        image.load()
        image = image.convert("RGBA")

        color = (255, 255, 255)
        background = Image.new('RGB', image.size, color)
        # Band 3 of an RGBA image is alpha; use it as the paste mask.
        background.paste(image, mask=image.split()[3])

        resize = transforms.Resize(
            self.img_size,
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        )
        return transforms.ToTensor()(resize(background))

    @staticmethod
    def _rank_tags(tag_names, scores, limit=50):
        """Pair names with scores and keep the ``limit`` highest, best first."""
        ranked = sorted(zip(tag_names, scores), key=lambda pair: pair[1], reverse=True)
        return dict(ranked[:limit])

    @torch.inference_mode()
    def predict(
        self,
        image,
        query=None,
        tag_names=None,
    ):
        """Score an image against a set of query embeddings.

        Args:
            image: PIL image to classify.
            query: Query embeddings passed to the head as ``q``; defaults to
                the class embeddings loaded from the checkpoint.
            tag_names: Names matching the rows of ``query``; defaults to the
                known tag list.

        Returns:
            Dict mapping up to 50 tag names to sigmoid scores, ordered from
            highest to lowest score.

        Raises:
            ValueError: If ``image`` is ``None`` (nothing was uploaded).
        """
        if image is None:
            raise ValueError("No image provided")
        # Defaults are resolved here because ``self`` does not exist when the
        # signature itself is evaluated.
        if query is None:
            query = self.class_embed
        if tag_names is None:
            tag_names = self.tag_names

        image = self.prepare_image(image)

        # Index 0 unwraps the nn.Sequential created in load_model.
        backbone = self.cls_model[0]
        image_features = backbone.forward_features(image.unsqueeze(0))
        outputs = backbone.forward_head(image_features, q=query).sigmoid().float()

        return self._rank_tags(tag_names, outputs[0].tolist())

    def predict_new_tag(
        self,
        image,
        query,
    ):
        """Score an image against a single tag that is not in the known list.

        Args:
            image: PIL image to classify.
            query: Either the tag description text (as supplied by the UI
                textbox) or an already computed embedding tensor.

        Returns:
            The sigmoid score for the described tag.

        Raises:
            ValueError: If ``query`` is blank text, or ``image`` is ``None``.
        """
        if isinstance(query, str):
            if not query.strip():
                raise ValueError("Tag description is empty")
            # The head consumes embeddings, not raw text.
            # NOTE(review): assumes the text embedding width matches the
            # checkpoint's class_embed width -- confirm.
            query = self.embed_text([query])
        return self.predict(image, query=query, tag_names=["embedding"])["embedding"]
|
|
|
|
|
|
|
|
def main():
    """Build the Gradio interface around a ``Predictor`` and launch it."""
    predictor = Predictor()

    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            heading = f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            gr.Markdown(value=heading)
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                # Left panel: image input, post lookup and known-tag prediction.
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        post_id = gr.Textbox(label="Post ID")
                        get_post = gr.Button(value="Get Post", variant="primary", size="lg")
                    with gr.Row():
                        clear = gr.ClearButton(
                            value="Clear image",
                            components=[image],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                # Right panel: tag lookup, description and prediction outputs.
                with gr.Column(variant="panel"):
                    with gr.Row():
                        tag_name = gr.Textbox(label="Tag Name")
                        clear_tag_data = gr.ClearButton(value="Clear tag", variant="secondary", size="lg")
                        get_tag_description = gr.Button(
                            value="Get tag description", variant="primary", size="lg"
                        )
                    tag_description = gr.Textbox(label="Tag description")
                    with gr.Row():
                        predict_on_description = gr.Button(value="Predict described tag:")
                        description_prediction = gr.Textbox(label="Probability")
                    general_bars = gr.Label(label="Known tags")

        # Components created after the clear buttons are attached afterwards.
        clear.add([general_bars, description_prediction, post_id])
        clear_tag_data.add([tag_description, tag_name, description_prediction])

        submit.click(predictor.predict, inputs=[image], outputs=[general_bars])
        predict_on_description.click(
            predictor.predict_new_tag,
            inputs=[image, tag_description],
            outputs=[description_prediction],
        )
        get_post.click(scrape_img, inputs=[post_id], outputs=[image])
        get_tag_description.click(scrape_wiki, inputs=[tag_name], outputs=[tag_description])

    demo.queue(max_size=10)
    demo.launch()
|
|
|
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()