"""Gradio demo: tag an uploaded image with the lodestones tagger-experiment model."""

import os
import subprocess
from datetime import datetime

import torch
import gradio as gr

from inference_tagger_standalone import *  # provides Tagger

# Disables autograd for the main thread (model loading below). Grad mode is
# thread-local, so this does NOT cover Gradio's worker threads; get_tags
# therefore re-disables it around inference.
torch.set_grad_enabled(False)

CHECKPOINT_URL = "https://huggingface.co/lodestones/tagger-experiment/resolve/main/tagger_proto.safetensors"
CHECKPOINT_PATH = "./tagger_proto.safetensors"
VOCAB_PATH = "./tagger_vocab.json"


def _ensure_checkpoint(url=CHECKPOINT_URL, path=CHECKPOINT_PATH):
    """Download the checkpoint to *path* with wget unless it already exists.

    The download goes to a temporary ``.part`` file that is renamed into place
    only after wget succeeds, so an interrupted download is not mistaken for
    a complete checkpoint on the next start. Raises
    ``subprocess.CalledProcessError`` if wget fails.
    (Alternative: huggingface_hub.hf_hub_download(repo_id=..., filename=...).)
    """
    if os.path.exists(path):
        return
    partial = path + ".part"
    subprocess.run(["wget", "-nv", "-O", partial, url], check=True)
    os.replace(partial, path)


_ensure_checkpoint()
model = Tagger(checkpoint_path=CHECKPOINT_PATH, vocab_path=VOCAB_PATH, max_size=1024)


def _timestamp():
    """Return the current local time formatted for the progress log lines."""
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def get_tags(image, threshold, top_k):
    """Run the tagger on one image and format the predictions for Gradio.

    Args:
        image: path to the uploaded image (the input component uses
            ``type='filepath'``), or None when nothing was uploaded.
        threshold: score threshold forwarded to ``model.predict``.
        top_k: maximum number of tags forwarded to ``model.predict`` as ``topk``.

    Returns:
        A tuple ``(tag_string, tag_scores)``:
        - ``tag_string`` contains the predicted tags with spaces replaced by
          underscores, joined by single spaces.
        - ``tag_scores`` maps each original tag to its score, for ``gr.Label``.
    """
    if image is None:
        # Submit pressed without an upload: clear the outputs instead of crashing.
        return "", {}

    print(f"{_timestamp()}: started.")
    # Gradio calls this in a worker thread, where the module-level
    # set_grad_enabled(False) does not apply, so disable autograd here.
    with torch.no_grad():
        results = model.predict(image, topk=int(top_k), threshold=threshold)

    tag_scores = {}
    underscored = []
    for tag, score in results:
        tag_scores[tag] = score
        underscored.append(tag.replace(" ", "_"))

    print(f"{_timestamp()}: finished.\n")
    return " ".join(underscored), tag_scores


demo = gr.Interface(
    get_tags,
    inputs=[
        gr.Image(label="Source", sources=['upload',], type='filepath'),
        gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold"),
        gr.Slider(minimum=0, maximum=500, step=1, value=30, label="Top K"),
    ],
    outputs=[
        gr.Textbox(label="Tag String"),
        gr.Label(label="Tag Predictions", num_top_classes=200),
    ],
)
demo.launch(ssr_mode=False)