Spaces:
Running
Running
| # coding: utf-8 | |
| # Copyright (C) 2023, [Breezedeus](https://github.com/breezedeus). | |
| # Licensed to the Apache Software Foundation (ASF) under one | |
| # or more contributor license agreements. See the NOTICE file | |
| # distributed with this work for additional information | |
| # regarding copyright ownership. The ASF licenses this file | |
| # to you under the Apache License, Version 2.0 (the | |
| # "License"); you may not use this file except in compliance | |
| # with the License. You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an | |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
| # KIND, either express or implied. See the License for the | |
| # specific language governing permissions and limitations | |
| # under the License. | |
| import os | |
| import sys | |
| import logging | |
| from typing import List | |
| import yaml | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| from datasets import load_dataset | |
| import chromadb | |
| from chromadb import Settings | |
| from coin_clip.utils import resize_img | |
| from coin_clip.chroma_embedding import ChromaEmbeddingFunction | |
| from coin_clip.detect import Detector | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Runtime environment selector: 'hf' uses the Hugging Face config file, any
# other value (default 'local') uses the local config file.
env = os.environ.get('COIN_ENV', 'local')
config_fp = 'hf_config.yaml' if env == 'hf' else 'local_config.yaml'

# Always bind hf_token: prepare_chromadb() reads it whenever env != 'local',
# so leaving it undefined outside the 'hf' branch would raise NameError for
# any other COIN_ENV value. It is None when HF_TOKEN is not set.
hf_token = os.environ.get('HF_TOKEN')

logger.info('Use config file: %s', config_fp)
# Use a context manager so the config file handle is closed after parsing.
with open(config_fp, encoding='utf-8') as _config_file:
    total_config = yaml.safe_load(_config_file)

# Shared detector instance used by detect(); model and device come from the
# 'detector' section of the config.
DETECTOR = Detector(
    model_name=total_config['detector']['model_name'],
    device=total_config['detector']['device'],
)
# Size passed to resize_img() before detection; defaults to 300 when the
# config does not provide 'resized_to'.
RESIZED_TO_BEFORE_DETECT = total_config['detector'].get('resized_to', 300)
def prepare_chromadb():
    """Download the prebuilt Chroma database unless running locally.

    When the module-level ``env`` is ``'local'`` this is a no-op. Otherwise
    the ``breezedeus/usa-coins-chromadb`` model repo is fetched from the
    Hugging Face Hub into the current directory (``./``), using the
    module-level ``hf_token`` for authentication.
    """
    if env != 'local':
        # Only imported when a download is actually performed.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id='breezedeus/usa-coins-chromadb',
            repo_type='model',
            token=hf_token,
            local_dir='./',
        )
def _load_dataset(data_path):
    """Load the ``train`` split of the coin image dataset.

    Args:
        data_path: In the ``'hf'`` environment, the dataset identifier passed
            straight to ``load_dataset`` (with ``hf_token``). In any other
            environment, a directory handed to the ``imagefolder`` loader as
            ``data_dir``.

    Returns:
        The dataset object returned by ``load_dataset``.
    """
    logger.info('Load dataset from %s', data_path)
    if env != 'hf':
        return load_dataset("imagefolder", data_dir=data_path, split='train')
    return load_dataset(data_path, split='train', token=hf_token)
def detect(images):
    """Run the coin detector on each image and crop the first detection.

    Each image is resized with ``resize_img`` to ``RESIZED_TO_BEFORE_DETECT``
    before being passed to ``DETECTOR.detect``. Only the first returned
    detection is kept; its ``'label'`` entry is dropped and its ``'box'``
    entry is renamed to ``'position'``.

    Args:
        images: Sequence of PIL images.

    Returns:
        A tuple ``(outs, box_images)`` of two lists aligned with ``images``:
        ``outs`` holds one dict per image (with ``'position'`` set to ``None``
        and ``'scores'`` to ``0.0`` when nothing was detected, plus
        ``'from_image_idx'``), and ``box_images`` holds the cropped region of
        the original image, or ``None`` when nothing was detected.
    """

    def _crop(img, position):
        # No detection for this image: keep a placeholder so that the output
        # stays aligned with the input list.
        if position is None:
            return None
        # The box holds ratios of the image size; convert them to absolute
        # pixel coordinates on the original (un-resized) image.
        width, height = img.size
        pixel_box = (
            int(position[0] * width),
            int(position[1] * height),
            int(position[2] * width),
            int(position[3] * height),
        )
        return img.crop(pixel_box)

    outs = []
    for idx, img in enumerate(images):
        resized = resize_img(img, RESIZED_TO_BEFORE_DETECT)
        detections = DETECTOR.detect(np.array(resized))
        if detections:
            out = detections[0]
            out.pop('label')
            out['position'] = out.pop('box')
        else:
            out = {'position': None, 'scores': 0.0}
        out['from_image_idx'] = idx
        outs.append(out)

    box_images = [_crop(img, out['position']) for out, img in zip(outs, images)]
    return outs, box_images
def load_chroma_db(db_dir, collection_name, model_name, device='cpu'):
    """Open an existing Chroma collection backed by an on-disk database.

    Args:
        db_dir: Directory of the persistent Chroma database.
        collection_name: Name of the collection to fetch.
        model_name: Model name forwarded to ``ChromaEmbeddingFunction``.
        device: Device forwarded to ``ChromaEmbeddingFunction``.

    Returns:
        The collection returned by ``client.get_collection``, bound to the
        embedding function built from ``model_name`` and ``device``.
    """
    logger.info('Load chroma db from %s', db_dir)
    settings = Settings(anonymized_telemetry=False)
    client = chromadb.PersistentClient(path=db_dir, settings=settings)
    embedder = ChromaEmbeddingFunction(model_name, device)
    return client.get_collection(
        name=collection_name,
        embedding_function=embedder,
    )
def retrieve(query_image: Image.Image, collection, top_k=20) -> List[Image.Image]:
    """Query a collection with an image and return the matching dataset images.

    Args:
        query_image: Image to search with; converted to a NumPy array first.
        collection: Collection object exposing ``query``.
        top_k: Number of results requested from the collection.

    Returns:
        The ``'image'`` field of the ``ds_dict`` records whose ids were
        returned for the query, in the order the collection returned them.
    """
    result = collection.query(
        query_images=[np.array(query_image)],
        include=['metadatas', 'distances'],
        n_results=top_k,
    )
    # A single query image was sent, so only the first result row is used.
    hit_ids = result['ids'][0]
    logger.info('retrieved ids: %s', hit_ids)
    logger.info('retrieved distances: %s', result['distances'][0])
    return [ds_dict[hit_id]['image'] for hit_id in hit_ids]
# Load the dataset at import time and index its records by their 'id' field;
# retrieve() uses ds_dict to map ids returned by a query back to records.
dataset = _load_dataset(**total_config['dataset'])
ds_dict = {_d['id']: _d for _d in dataset}

# Download the database snapshot (non-local environments only) before opening
# the two collections: one configured under 'coin_clip_db', one under 'clip_db'.
prepare_chromadb()
cc_collection = load_chroma_db(**total_config['coin_clip_db'])
clip_collection = load_chroma_db(**total_config['clip_db'])
def _no_detection_updates():
    """Return the UI updates that hide all results and show the warning."""
    return [
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    ]


def search(image_file: Image.Image):
    """Detect the coin in an uploaded image and retrieve similar coins.

    Args:
        image_file: The uploaded PIL image, or ``None`` when the user submits
            without providing an image.

    Returns:
        A list of four ``gr.update`` objects, in this order: the detected
        coin image, the "no coins detected" warning, the Coin-CLIP result
        gallery and the CLIP result gallery. When there is no input image or
        no coin is detected, only the warning is made visible.
    """
    # Gradio passes None when no image was provided; treat that like a failed
    # detection instead of crashing on ``None.convert``.
    if image_file is None:
        return _no_detection_updates()

    _, box_images = detect([image_file.convert('RGB')])
    box_images = [img for img in box_images if img is not None]
    if not box_images:
        return _no_detection_updates()

    box_image = box_images[0]
    cc_results = retrieve(box_image, cc_collection, top_k=30)
    clip_results = retrieve(box_image, clip_collection, top_k=30)
    return [
        gr.update(value=box_image, visible=True),
        gr.update(visible=False),
        gr.update(value=cc_results, visible=True),
        gr.update(value=clip_results, visible=True),
    ]
def main():
    """Build the Gradio Blocks interface for coin retrieval and launch it.

    The page has a header, a first row with the description, the input image
    (plus Submit button) and the detected-coin preview, and a second row with
    two result galleries. Both the Submit button and the example images call
    ``search`` with the input image and write to the same four outputs.
    """
    title = 'USA Coin Retrieval by'
    # Earlier HTML version of the description, currently disabled:
    # desc = (
    # '<p style="text-align: center">Coin-CLIP: '
    # '<a href="https://huggingface.co/breezedeus/coin-clip-vit-base-patch32" target="_blank">Model</a>, '
    # '<a href="https://github.com/breezedeus/coin-clip" target="_blank">Github</a>; '
    # 'Author: <a href="https://www.breezedeus.com" target="_blank">Breezedeus</a> , '
    # '<a href="https://github.com/breezedeus" target="_blank">Github</a> </p>'
    # )
    # NOTE(review): the badge line below is a link with empty link text and
    # the table header row has a single empty cell; presumably some markdown
    # (badge image, blank lines, a header cell) was lost -- confirm against
    # the rendered page before relying on this text.
    desc = """
<div align="center">
<img src="https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F9341931a-53f0-48e1-b026-0f1ad17b457c%2F003ffb61-964f-4e1a-bfc1-6a8516fc90ac%2FUntitled.png?table=block&id=553ac3c2-1f88-450a-b06a-8ebb4001b29f" width="120px"/>
[](https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbreezedeus%2FUSA-Coin-Retrieval)
| |
| ------------------------------- | --------------------------------------- |
| 🪙 **Model** | [Coin-CLIP](https://huggingface.co/breezedeus/coin-clip-vit-base-patch32) |
| 💼 **Code** | [Github](https://github.com/breezedeus/coin-clip) |
| 👨🏻💻 **Author** | [Breezedeus](https://www.breezedeus.com) |
| 💬 **Questions** | [GitHub Discussions](https://github.com/breezedeus/coin-clip/issues) |
<br/>
Leave a star 🌟 on the Github [Coin-CLIP 🪙](https://github.com/breezedeus/coin-clip) .
If you're interested in retrieving coins from other countries, please leave a comment on the Github.
</div>
"""
    # Example images offered below the interface (paths relative to the
    # working directory).
    examples = [
        'examples/c2.jpeg',
        'examples/c20.jpg',
        'examples/c21.jpg',
        'examples/c22.png',
        'examples/c1.jpg',
        'examples/c11.jpg',
        'examples/c3.png',
        'examples/c4.jpg',
        'examples/c5.jpeg',
        'examples/c6.jpeg',
        'examples/c7.jpg',
        'examples/c8.jpeg',
    ]
    with gr.Blocks() as demo:
        # Page header: the title text followed by a link to the project.
        gr.Markdown(
            f'<h1 style="text-align: center; margin-bottom: 1rem;">{title} <a href="https://github.com/breezedeus/coin-clip" target="_blank">Coin-CLIP</a></h1>'
        )
        # First row: description | input image + submit | detected coin.
        with gr.Row(equal_height=False):
            with gr.Column(variant='compact', scale=3):
                # gr.HTML('<img src="examples/coin-clip-logo.jpg" width="150px"/>')
                gr.Markdown(desc)
            with gr.Column(variant='compact', scale=7):
                gr.Markdown('### Image within a coin')
                image_file = gr.Image(
                    label='Coin Image to Search',
                    type="pil",
                    image_mode='RGB',
                    height=400,
                )
                sub_btn = gr.Button("Submit", variant="primary")
            with gr.Column(variant='compact', scale=4):
                gr.Markdown('### Detected Coin')
                detected_image = gr.Image(
                    label='Detected Coin',
                    type="pil",
                    interactive=False,
                    image_mode='RGB',
                    height=400,
                )
                # Hidden by default; search() makes it visible when no coin
                # is detected.
                no_detect_warn = gr.Markdown(
                    '**⚠️ Warning**: No coins detected in image', visible=False
                )
        # Second row: the two result galleries side by side, hidden until
        # search() returns results.
        with gr.Row(equal_height=False):
            with gr.Column(variant='compact', scale=1):
                gr.Markdown('### Results from Coin-CLIP')
                cc_results = gr.Gallery(
                    label='Coin-CLIP Results',
                    columns=3,
                    height=2200,
                    show_share_button=True,
                    visible=False,
                )
            with gr.Column(variant='compact', scale=1):
                gr.Markdown('### Results from CLIP')
                # Despite its name, this gallery shows the plain CLIP results.
                coin_results = gr.Gallery(
                    label='CLIP Results',
                    columns=3,
                    height=2200,
                    show_share_button=True,
                    visible=False,
                )
        # The outputs order must match the order of the four updates
        # returned by search().
        sub_btn.click(
            search,
            inputs=[image_file,],
            outputs=[detected_image, no_detect_warn, cc_results, coin_results],
        )
        gr.Examples(
            label='Examples',
            examples=examples,
            inputs=image_file,
            outputs=[detected_image, no_detect_warn, cc_results, coin_results],
            fn=search,
            examples_per_page=12,
            cache_examples=True,
        )
    demo.queue(max_size=20)
    demo.launch()
# Build and launch the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()