| import json |
| from functools import lru_cache |
|
|
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
| from autofaiss import build_index |
| from hfutils.operate import get_hf_fs |
| from huggingface_hub import hf_hub_download |
| from imgutils.data import load_image |
| from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold |
|
|
# Hugging Face dataset repository that stores the character index files
# (used below as ``repo_id`` with ``repo_type='dataset'``).
| SRC_REPO = 'deepghs/character_index' |
|
|
# Shared file-system handle; used below to read text files from the repository.
| hf_fs = get_hf_fs() |
|
|
|
|
@lru_cache()
def _make_index():
    """Load the tag metadata and embeddings, then build the search index.

    The result is memoized by ``lru_cache``, so the remote reads and the
    index construction happen only on the first call.

    :return: ``((index, index_infos), tag_infos)`` where ``index`` and
        ``index_infos`` are the two values returned by ``build_index`` and
        ``tag_infos`` is a numpy array built from the parsed
        ``index/tag_infos.json`` document.
    """
    # Tag metadata is read first, exactly as a text file through the HF file system.
    tag_infos_text = hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json')
    tag_infos = np.array(json.loads(tag_infos_text))

    # The embedding matrix is fetched through the hub download helper.
    embeddings_file = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_file)

    # Keep the index in memory only; nothing is written to disk.
    index, index_infos = build_index(embeddings, save_on_disk=False)
    return (index, index_infos), tag_infos
|
|
|
|
def gender_predict(p):
    """Turn a pair of ``boy`` / ``girl`` scores into a gender label.

    :param p: Mapping that provides numeric values under the keys ``'boy'``
        and ``'girl'``.
    :return: ``'male'`` when the boy score leads by at least 0.1,
        ``'female'`` when the girl score leads by at least 0.1, and
        ``'not_sure'`` otherwise.
    """
    boy_score = p['boy']
    girl_score = p['girl']

    # One side must lead by at least 0.1 before a label is committed to.
    if boy_score - girl_score >= 0.1:
        return 'male'
    if girl_score - boy_score >= 0.1:
        return 'female'
    return 'not_sure'
|
|
|
|
def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
    """Look up the indexed characters that best match ``image``.

    The image's CCIP feature is normalized and searched against the index
    from ``_make_index``. For every hit, the character's preview image
    (``1.webp``) and stored features (``feat.npy``) are downloaded and the
    CCIP differences between the query and those features are computed.

    :param image: Query image.
    :param count: Number of nearest neighbours requested from the index.
    :param order_by: Record column used both for sorting (descending, with
        ``index_score`` as tie-breaker) and for the ``>= threshold`` filter.
        The record columns are ``id``, ``tag``, ``gender``, ``copyright``,
        ``index_score``, ``mean_diff`` and ``same_ratio``.
    :param threshold: Minimum value of the ``order_by`` column a record needs
        in order to be kept.
    :return: ``(ret_images, df_records)`` where ``ret_images`` is a list of
        ``(image, label)`` tuples, the label being ``'<tag> (<score>)'`` with
        three decimals, and ``df_records`` is the sorted, filtered DataFrame
        of records.
    :raises KeyError: If ``order_by`` is not a record column while at least
        one hit was found.

    NOTE(review): sorting is always descending and the filter is always
    ``>=``; that looks intended for ``same_ratio`` / ``index_score``.
    Presumably a lower ``mean_diff`` means a closer match, so ordering by it
    may not give the expected result -- confirm before relying on it.
    """
    (index, index_infos), tag_infos = _make_index()
    query = ccip_batch_extract_features([image])
    assert query.shape == (1, 768)
    query = query / np.linalg.norm(query)
    all_dists, all_indices = index.search(query, k=count)
    dists, indices = all_dists[0], all_indices[0]

    # Loop-invariant: fetch the "same character" threshold only once.
    same_threshold = ccip_default_threshold()

    images, records = {}, []
    for dist, idx in zip(dists, indices):
        if idx < 0:
            # Faiss pads missing neighbours with label -1; indexing
            # ``tag_infos`` with -1 would silently pick the last entry.
            continue

        info = tag_infos[idx]
        item_dir = f'{info["hprefix"]}/{info["short_tag"]}'
        current_image = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{item_dir}/1.webp'
        ))
        feats = np.load(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{item_dir}/feat.npy'
        ))
        # Row 0 of the pairwise matrix is the query; drop its self-comparison.
        diffs = ccip_batch_differences([query[0], *feats])[0, 1:]
        images[info['tag']] = current_image
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'index_score': dist,
            'mean_diff': diffs.mean(),
            'same_ratio': (diffs < same_threshold).mean(),
        })

    if not records:
        # An empty DataFrame has no columns, so sort_values() would raise
        # KeyError; return an empty, correctly shaped result instead.
        empty_columns = ['id', 'tag', 'gender', 'copyright', 'index_score', 'mean_diff', 'same_ratio']
        return [], pd.DataFrame(columns=empty_columns)

    sort_keys = ['index_score'] if order_by == 'index_score' else [order_by, 'index_score']
    df_records = pd.DataFrame(records).sort_values(
        by=sort_keys,
        ascending=[False] * len(sort_keys),
    )
    df_records = df_records[df_records[order_by] >= threshold]

    ret_images = [
        (images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})')
        for row_item in df_records.to_dict('records')
    ]
    return ret_images, df_records
|
|