| import json |
| import os |
| from collections import defaultdict |
| from typing import List, Dict |
|
|
| import faiss |
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \ |
| KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool |
| from hfutils.operate import get_hf_fs, get_hf_client |
| from hfutils.utils import TemporaryDirectory |
| from imgutils.tagging import wd14 |
| from imgutils.utils import ts_lru_cache |
|
|
| from pools import quick_webp_pool |
|
|
# Hugging Face repository from which the index files are downloaded.
_REPO_ID = 'deepghs/anime_sites_indices'

# Shared Hugging Face filesystem / client handles used throughout this module.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index preselected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_iqdb_10_46796044_8GB'

# Every first-level directory of the repository that contains a ``knn.index``
# file is offered as a selectable index; the directory name is the model name.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(index_path, _REPO_ID))
    for index_path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]
|
|
# Site name (the prefix of a raw id such as ``danbooru_123``) -> data pool class
# used to download that site's images. Sites missing from this table are
# resolved through ``quick_webp_pool`` in ``_get_from_ids``.
_SITE_CLS = {
    'danbooru': DanbooruNewestWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
    'konachan': KonachanWebpDataPool,
    'anime_pictures': AnimePicturesWebpDataPool,
    'rule34': Rule34WebpDataPool,
}
|
|
|
|
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download images of one site by numeric id and return them loaded.

    :param site_name: Name of the source site. Looked up in ``_SITE_CLS``;
        unknown sites fall back to ``quick_webp_pool(site_name, 3)``.
    :param ids: Numeric resource ids to download.
    :return: Mapping from numeric id to the loaded PIL image. Only files that
        actually show up in the download directory are included, so ids that
        produced no file are simply absent.
    """
    with TemporaryDirectory() as workdir:
        pool_factory = _SITE_CLS.get(site_name)
        if not pool_factory:
            pool_factory = quick_webp_pool(site_name, 3)

        pool = pool_factory()
        pool.batch_download_to_directory(resource_ids=ids, dst_dir=workdir)

        loaded = {}
        for filename in os.listdir(workdir):
            # Downloaded files are expected to be named ``<numeric id>.<ext>``.
            stem, _ext = os.path.splitext(filename)
            resource_id = int(stem)
            picture = Image.open(os.path.join(workdir, filename))
            # Read the pixel data now, while the file is still on disk.
            picture.load()
            loaded[resource_id] = picture

        return loaded
|
|
|
|
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Fetch images for raw ids of the form ``<site_name>_<numeric_id>``.

    The ids are grouped by site so that each site is downloaded with a single
    ``_get_from_ids`` call.

    :param ids: Raw ids; the part after the last underscore must be an integer.
    :return: Mapping from raw id (rebuilt as ``<site_name>_<numeric_id>``) to
        the loaded image. Ids whose image could not be obtained are absent.
    """
    ids_by_site = defaultdict(list)
    for raw_id in ids:
        # Split on the LAST underscore: site names may contain underscores
        # themselves (e.g. ``anime_pictures``).
        site_name, numeric_part = raw_id.rsplit('_', maxsplit=1)
        ids_by_site[site_name].append(int(numeric_part))

    images = {}
    for site_name, numeric_ids in ids_by_site.items():
        for numeric_id, image in _get_from_ids(site_name, numeric_ids).items():
            images[f'{site_name}_{numeric_id}'] = image
    return images
|
|
|
|
@ts_lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Download and load the id table and FAISS index of one model.

    Three files are fetched from ``<model_name>/`` inside the repository:
    ``ids.npy`` (the id of every indexed image), ``knn.index`` (the FAISS
    index) and ``infos.json`` (whose ``"index_param"`` entry is applied to the
    index as its search parameters).

    The result is memoized by ``ts_lru_cache(maxsize=3)``, so repeated
    searches against the same index do not reload it.

    :param repo_id: Hugging Face repository id (model-type repository).
    :param model_name: Directory name of the index inside the repository.
    :return: Tuple ``(image_ids, knn_index)``.
    """

    def _fetch(filename: str) -> str:
        # Download one file of this model and return its local path.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_fetch('ids.npy'))
    knn_index = faiss.read_index(_fetch('knn.index'))

    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 explicitly instead of the platform default encoding.
    with open(_fetch('infos.json'), encoding='utf-8') as f:
        index_param = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, index_param)
    return image_ids, knn_index
|
|
|
|
def search(model_name: str, img_input, n_neighbours: int):
    """Find the images most similar to ``img_input`` in the selected index.

    :param model_name: Name of the index to query (a directory of ``_REPO_ID``).
    :param img_input: Query image as delivered by the Gradio image component
        (a PIL image here), or ``None`` when nothing was uploaded.
    :param n_neighbours: Number of neighbours to request from the index.
    :return: List of ``(image, caption)`` tuples for the gallery, in the order
        returned by the index; the caption is ``"<raw id>/<distance>"``.
        Neighbours whose image could not be downloaded are skipped.
    :raises gr.Error: If no input image was provided.
    """
    if img_input is None:
        # Clicking the button without an image would otherwise fail deep
        # inside the tagger with an unhelpful traceback.
        raise gr.Error('Please provide an input image first.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # FAISS expects a 2-D batch of query vectors; normalisation is in place.
    embeddings = np.expand_dims(embeddings, 0)
    faiss.normalize_L2(embeddings)

    # The slider value may arrive as a float, while FAISS requires an int k.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # FAISS pads the result with label -1 when it finds fewer than k
    # neighbours; indexing ``images_ids`` with -1 would silently return the
    # LAST id, so drop those slots first.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbour_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)
    return [
        (ids_to_images[image_id], f"{image_id}/{dist:.2f}")
        for image_id, dist in zip(neighbours_ids, neighbour_dists)
        if image_id in ids_to_images
    ]
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a copy of this file that had lost its indentation -- confirm the
    # placement of the button and the gallery against the intended layout.
    with gr.Blocks() as demo:
        # Top row: query image on the left, search controls on the right.
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")

            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")

        # Bottom row: gallery showing the search results.
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Input order must match the parameter order of ``search``.
        find_btn.click(
            fn=search,
            inputs=[n_model, img_input, n_neighbours],
            outputs=[similar_images],
        )

    demo.queue().launch()
|
|