Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| from collections import defaultdict | |
| from typing import List, Dict | |
| import faiss | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \ | |
| KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool | |
| from hfutils.operate import get_hf_fs, get_hf_client | |
| from hfutils.utils import TemporaryDirectory | |
| from imgutils.tagging import wd14 | |
| from imgutils.utils import ts_lru_cache | |
| from pools import quick_webp_pool | |
# Hub model repository holding the prebuilt indices; every index lives in its
# own sub-directory that contains (at least) a ``knn.index`` file.
_REPO_ID = 'deepghs/anime_sites_indices'

# Shared Hugging Face filesystem / API clients used for discovery and downloads.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index pre-selected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_iqdb_10_46796044_8GB'


def _list_index_names() -> List[str]:
    """Return the names of all index sub-directories found in ``_REPO_ID``.

    A sub-directory counts as an index when it holds a ``knn.index`` file; the
    returned name is that sub-directory's path relative to the repository root.
    """
    names = []
    for index_file in hf_fs.glob(f'{_REPO_ID}/*/knn.index'):
        relative_file = os.path.relpath(index_file, _REPO_ID)
        names.append(os.path.dirname(relative_file))
    return names


# Evaluated once at import time (this queries the Hub).
_ALL_MODEL_NAMES = _list_index_names()

# Site prefix of a raw index id -> cheesechaser data pool class. Sites that are
# not listed here fall back to ``quick_webp_pool`` in ``_get_from_ids``.
_SITE_CLS = {
    'danbooru': DanbooruNewestWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
    'konachan': KonachanWebpDataPool,
    'anime_pictures': AnimePicturesWebpDataPool,
    'rule34': Rule34WebpDataPool,
}
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download posts of one site and load them as PIL images.

    The data pool class comes from ``_SITE_CLS``; for unlisted sites one is
    built with ``quick_webp_pool(site_name, 3)`` (meaning of ``3`` is defined by
    ``pools.quick_webp_pool`` — not visible here). Files are downloaded into a
    temporary directory that is removed on exit, so every image is fully
    decoded before that happens.

    :param site_name: Site prefix, e.g. ``'danbooru'``.
    :param ids: Numeric post ids to download from that site.
    :return: Mapping from the integer id (parsed from each downloaded file's
        stem) to its image. Only files that actually appear in the download
        directory are included.
    """
    with TemporaryDirectory() as workdir:
        pool_cls = _SITE_CLS.get(site_name)
        if not pool_cls:
            pool_cls = quick_webp_pool(site_name, 3)
        pool_cls().batch_download_to_directory(
            resource_ids=ids,
            dst_dir=workdir,
        )

        images = {}
        for filename in os.listdir(workdir):
            post_id = int(os.path.splitext(filename)[0])
            img = Image.open(os.path.join(workdir, filename))
            # Decode now: the file disappears when the temp directory is cleaned up.
            img.load()
            images[post_id] = img
        return images
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Fetch images for raw index ids of the form ``<site>_<numeric id>``.

    Ids are split at their last underscore, so site names may themselves
    contain underscores (``anime_pictures``). Ids are grouped per site so each
    site is fetched with a single ``_get_from_ids`` batch.

    :param ids: Raw ids such as ``"danbooru_123"``.
    :return: Mapping from ``"<site>_<id>"`` to the loaded image; ids that
        ``_get_from_ids`` did not return are absent.
    :raises ValueError: If an id has no underscore or a non-integer suffix.
    """
    ids_by_site = defaultdict(list)
    for raw_id in ids:
        site, numeric_part = raw_id.rsplit('_', maxsplit=1)
        ids_by_site[site].append(int(numeric_part))

    images = {}
    for site, numeric_ids in ids_by_site.items():
        for numeric_id, img in _get_from_ids(site, numeric_ids).items():
            images[f'{site}_{numeric_id}'] = img
    return images
@ts_lru_cache()
def _get_index_info(repo_id: str, model_name: str):
    """Download and load the id table and FAISS index of one prebuilt index.

    Three files are fetched from ``<model_name>/`` in the model repo ``repo_id``:

    * ``ids.npy`` - raw image ids, indexed by the positions FAISS returns;
    * ``knn.index`` - the FAISS index itself;
    * ``infos.json`` - its ``"index_param"`` entry is applied to the index via
      ``faiss.ParameterSpace().set_index_parameters``.

    Results are memoised per ``(repo_id, model_name)`` with imgutils'
    ``ts_lru_cache`` so a search does not re-read the whole index from disk.
    NOTE(review): every index that has been selected stays resident in memory;
    if RAM is tight, bound the cache (e.g. ``ts_lru_cache(maxsize=1)``).

    :param repo_id: Hugging Face model repository id.
    :param model_name: Index sub-directory inside ``repo_id``.
    :return: ``(image_ids, knn_index)``. Both objects are shared between callers
        and must not be mutated.
    """
    image_ids = np.load(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/ids.npy',
    ))
    knn_index = faiss.read_index(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/knn.index',
    ))
    infos_path = hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/infos.json',
    )
    # Context manager closes the handle (the original ``open(...).read()`` leaked it).
    with open(infos_path, 'r', encoding='utf-8') as f:
        config = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index
def search(model_name: str, img_input, n_neighbours: int):
    """Return the indexed images closest to ``img_input`` for a Gradio gallery.

    The query image is embedded with the WD14 ``SwinV2_v3`` model
    (``fmt="embedding"``), L2-normalised, and searched in the FAISS index of
    ``model_name``.

    :param model_name: Index sub-directory of ``_REPO_ID`` to search.
    :param img_input: Query image from ``gr.Image(type="pil")``, or ``None``
        when the user has not provided one.
    :param n_neighbours: Number of neighbours to request; coerced with
        ``int()`` in case the slider delivers a float.
    :return: List of ``(image, caption)`` tuples in FAISS result order, where the
        caption is ``"<raw id>/<distance>"`` (distance with two decimals).
        Neighbours whose image could not be fetched are skipped.
    :raises gr.Error: If ``img_input`` is ``None``.
    """
    if img_input is None:
        raise gr.Error('Please provide an input image.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    embedding = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # FAISS' Python API needs a C-contiguous float32 (n_queries, dim) matrix;
    # assumes the embedding is 1-D, as the original expand_dims did.
    queries = np.ascontiguousarray(np.expand_dims(embedding, 0), dtype=np.float32)
    faiss.normalize_L2(queries)  # normalises in place
    dists, indexes = knn_index.search(queries, k=int(n_neighbours))

    # FAISS pads results with -1 when fewer than k neighbours exist; indexing
    # ``images_ids`` with -1 would silently yield the last stored id.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbours_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)
    return [
        (ids_to_images[image_id], f"{image_id}/{dist:.2f}")
        for image_id, dist in zip(neighbours_ids, neighbours_dists)
        if image_id in ids_to_images
    ]
if __name__ == "__main__":
    # The index list is discovered from the Hub at import time, so the
    # hard-coded default may be missing; fall back to the first discovered
    # index (or no selection when none were found).
    if _DEFAULT_MODEL_NAME in _ALL_MODEL_NAMES:
        _initial_model = _DEFAULT_MODEL_NAME
    else:
        _initial_model = _ALL_MODEL_NAMES[0] if _ALL_MODEL_NAMES else None

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: the query image.
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")
            # Right column: search settings and the trigger button.
            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_initial_model,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Input order must match search(model_name, img_input, n_neighbours).
        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()