Spaces:
Running
Running
import torch

# Disable autograd for this (main) thread: the app only runs inference.
# NOTE: grad mode is thread-local, so this covers model loading below but not
# code that later runs on other threads; those need their own no-grad context.
torch.set_grad_enabled(False)

import gradio as gr
from datetime import datetime
from inference_tagger_standalone import *

# Alternative download path (requires huggingface_hub):
# from huggingface_hub import hf_hub_download
# hf_hub_download(repo_id="lodestones/tagger-experiment", filename="tagger_proto.safetensors", local_dir=".")
import os
import subprocess

CHECKPOINT_URL = "https://huggingface.co/lodestones/tagger-experiment/resolve/main/tagger_proto.safetensors"
CHECKPOINT_PATH = "./tagger_proto.safetensors"

# Fetch the checkpoint only when it is not already on disk. Running wget
# unconditionally re-downloads it on every start, and because wget does not
# overwrite existing files, each extra copy lands as
# `tagger_proto.safetensors.1`, `.2`, ...
if not os.path.isfile(CHECKPOINT_PATH):
    _partial_path = CHECKPOINT_PATH + ".part"
    # List-form argv (no shell). check=True fails fast with a clear error
    # instead of letting Tagger crash later on a missing file.
    subprocess.run(["wget", "-nv", "-O", _partial_path, CHECKPOINT_URL], check=True)
    # Rename only after a complete download, so an interrupted transfer never
    # leaves a truncated checkpoint that a later start would accept as valid.
    os.replace(_partial_path, CHECKPOINT_PATH)

model = Tagger(checkpoint_path=CHECKPOINT_PATH, vocab_path="./tagger_vocab.json", max_size=1024)
def _timestamp():
    """Return the current local time as ``YYYY-MM-DD_HH-MM-SS`` for log lines."""
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def get_tags(image, threshold, top_k):
    """Tag one image and format the predictions for the Gradio outputs.

    Args:
        image: Image handed straight to ``model.predict`` (a file path when the
            input widget uses ``type='filepath'``), or ``None`` when the form
            is submitted without an upload.
        threshold: Score cut-off forwarded to ``model.predict``.
        top_k: Tag-count limit forwarded to ``model.predict`` as ``topk``.

    Returns:
        ``(tag_string, scores)``: ``tag_string`` is the predicted tags joined by
        single spaces, with spaces inside each tag replaced by ``_``; ``scores``
        maps each tag (original spelling) to its score. Both are empty when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing uploaded: clear the outputs instead of passing None to the model.
        return "", {}

    print(f"{_timestamp()}: started.")

    # Grad mode is thread-local, and Gradio runs synchronous handlers on worker
    # threads, so a module-level torch.set_grad_enabled(False) does not apply
    # here. Disable autograd explicitly so inference builds no graphs.
    with torch.no_grad():
        # int(): tolerate a float from the slider (e.g. 30.0); topk is a count.
        results = model.predict(image, topk=int(top_k), threshold=threshold)

        scores = {}
        underscored_tags = []
        # Consume results inside the no-grad context in case they are produced lazily.
        for tag, score in results:
            scores[tag] = score
            underscored_tags.append(tag.replace(" ", "_"))

    tag_string = " ".join(underscored_tags)
    print(f"{_timestamp()}: finished.\n")
    return tag_string, scores
# Input widgets, in the positional order get_tags takes them: image, threshold, top_k.
input_components = [
    gr.Image(label="Source", sources=["upload"], type="filepath"),
    gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold"),
    gr.Slider(minimum=0, maximum=500, step=1, value=30, label="Top K"),
]

# Output widgets, one per element of get_tags' (tag string, tag->score dict) return.
output_components = [
    gr.Textbox(label="Tag String"),
    # The Label widget displays at most num_top_classes of the scored tags.
    gr.Label(label="Tag Predictions", num_top_classes=200),
]

demo = gr.Interface(
    fn=get_tags,
    inputs=input_components,
    outputs=output_components,
)
demo.launch(ssr_mode=False)