# Hugging Face Space status at time of capture: Build error
| import gradio as gr | |
| import gradio.components as grc | |
| import onnxruntime | |
| import numpy as np | |
| from torchvision.transforms import Normalize, Compose, Resize, ToTensor | |
# NOTE(review): not referenced by any visible code; the model file name below ends in "_b1",
# which may mean the exported graph expects a batch of exactly 1 — confirm before relying on it.
batch_size: int = 1
def convert_to_rgb(image):
    """Return *image* converted to the "RGB" mode via its ``convert`` method.

    Used as the first step of the preprocessing pipeline so that images in
    other modes (e.g. grayscale or RGBA) reach the model as three channels.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def get_transform(image_size=384):
    """Build the preprocessing pipeline applied to every input image.

    The steps run in order: force RGB mode, resize to a square
    ``image_size`` x ``image_size``, convert to a tensor, then normalise each
    channel with the standard ImageNet mean/std statistics.

    Args:
        image_size: side length (pixels) of the square resize target.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        convert_to_rgb,
        Resize((image_size, image_size)),
        ToTensor(),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return Compose(steps)
def load_tag_list(tag_list_file):
    """Read one tag per line from a UTF-8 text file.

    Args:
        tag_list_file: path to the tag list; line ``i`` is the tag for model
            output position ``i``.

    Returns:
        A NumPy array of the tag strings, so it can be indexed with an array
        of positions.
    """
    with open(tag_list_file, encoding="utf-8") as handle:
        lines = handle.read().splitlines()
    return np.array(lines)
def load_word_vocabulary(word_vocabulary_file):
    """Map every word in a vocabulary file to the line it appears on.

    Each line of the file is a comma-separated group of words; every word on
    line ``i`` (zero-based) maps to ``i``. If a word appears on several lines,
    the last occurrence wins. Words are not whitespace-stripped, so
    ``"a, b"`` yields the key ``" b"``.

    Args:
        word_vocabulary_file: path to a UTF-8 text file.

    Returns:
        dict mapping word -> zero-based line index.
    """
    with open(word_vocabulary_file, 'r', encoding="utf-8") as f:
        lines = f.read().splitlines()
    # Later lines overwrite earlier ones for repeated words, matching the
    # original nested-loop assignment order.
    return {
        word: line_no
        for line_no, line in enumerate(lines)
        for word in line.split(',')
    }
| from huggingface_hub import hf_hub_download | |
# Fetch the exported ONNX model into ./resources. hf_hub_download returns the
# local file path, which is used directly rather than re-typed by hand.
model_path = hf_hub_download(
    repo_id="Inf009/ram-tagger",
    repo_type="model",
    local_dir="resources",
    filename="ram_swin_large_14m_b1.onnx",
    local_dir_use_symlinks=False,
)
# Providers are tried in order of precedence: prefer CUDA, fall back to CPU.
# CPUExecutionProvider is always available, so the app still starts on
# CPU-only hardware or with a CPU-only onnxruntime build.
_available_providers = onnxruntime.get_available_providers()
ort_session = onnxruntime.InferenceSession(
    model_path,
    providers=[
        p
        for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if p in _available_providers
    ],
)
transform = get_transform()
# NOTE(review): these two files are read from ./resources but are not
# downloaded above — presumably shipped alongside this script; confirm.
tag_list = load_tag_list("resources/ram_tag_list.txt")
word_index = load_word_vocabulary("resources/word_vocabulary_english.txt")
def inference_by_image_pil(image):
    """Tag a PIL image with the ONNX model and return the tags as one string.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without an image.

    Returns:
        Comma-separated tags ordered by ``rerank_tags``; an empty string when
        ``image`` is ``None``.
    """
    # Gradio passes None for an empty image input; the transform would fail on it.
    if image is None:
        return ""
    # The transform yields a single-image tensor; unsqueeze(0) adds the batch axis.
    image_arrays = transform(image).unsqueeze(0).numpy()
    # Compute the ONNX Runtime prediction, feeding the model's first input.
    ort_inputs = {ort_session.get_inputs()[0].name: image_arrays}
    ort_outs = ort_session.run(None, ort_inputs)
    # NOTE(review): assumes ort_outs[0][0] is a 1-D per-tag vector aligned with
    # tag_list in which a value of exactly 1 marks a predicted tag — confirm
    # against the exported model's outputs.
    hit_positions = np.flatnonzero(ort_outs[0][0] == 1)
    token = tag_list[hit_positions].tolist()
    return ",".join(rerank_tags(token))
def rerank_tags(tags, vocabulary=None):
    """Order tags by vocabulary line index, keeping input order within a line.

    Args:
        tags: tag strings to reorder.
        vocabulary: mapping tag -> line index, such as the result of
            ``load_word_vocabulary``. Defaults to the module-level
            ``word_index``.

    Returns:
        A new list sorted by ascending line index. Tags missing from the
        vocabulary are kept and placed last, in their original order, instead
        of raising ``KeyError``.
    """
    if vocabulary is None:
        vocabulary = word_index
    # sorted() is stable, so tags sharing an index keep their relative order.
    # This gives the same result as bucketing by index and concatenating the
    # buckets, without allocating one bucket per vocabulary line per call.
    unknown_rank = float("inf")
    return sorted(tags, key=lambda tag: vocabulary.get(tag, unknown_rank))
# Gradio UI: a PIL image goes in, and the comma-separated tag string comes out as text.
image_input = grc.Image(type='pil')
text_output = grc.Text()
app = gr.Interface(
    fn=inference_by_image_pil,
    inputs=image_input,
    outputs=text_output,
    title="RAM Tagger",
    description="A tagger for images.",
)
app.launch()