Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import uuid | |
| from feat_ext import VitLaionFeatureExtractor | |
| import shutil | |
| from queue import Queue, Full | |
| from utils import HFPetDatasetManager, load_enc_cls_model | |
| import os | |
# Placeholders for runtime objects; the __main__ block assigns the real
# model, feature extractor, image transforms and dataset manager at startup.
model_cls = feat_extractor = processor = ds_manager = None

# Configuration read from the environment (each is None when the variable is unset).
HF_API_TOKEN = os.environ.get('HF_API_TOKEN')
ENC_KEY = os.environ.get('ENC_KEY')
dataset_name = os.environ.get('DATASET_NAME')

# Single-slot queue: at most one pending upload request is held at a time.
ds_manager_queue = Queue(maxsize=1)

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def push_files_async():
    """Request a dataset upload without blocking the caller.

    Drops a token into ``ds_manager_queue``. The queue holds one item, so
    when a request is already pending the new one is discarded; only a
    message is printed in that case.
    """
    try:
        ds_manager_queue.put_nowait('Ok')
    except Full:
        print('Pull already started!')
    else:
        print('DS upload requested!')
def predict_diff(img_a, img_b):
    """Score whether two images show the same dog, and save the pair.

    Args:
        img_a: First image (PIL, per the interface's ``type="pil"`` inputs),
            or None when the user did not upload one.
        img_b: Second image, same conventions as ``img_a``.

    Returns:
        A tuple ``(confidences, base_name)``. ``confidences`` maps ``'Same'``
        to the model's sigmoid output and ``'Different'`` to its complement,
        for the Gradio ``label`` output. ``base_name`` is the file stem under
        which the pair was saved; it is routed to a hidden textbox so the
        flagging callback can find the files later.

    Raises:
        gr.Error: If either image is missing.
    """
    if img_a is None or img_b is None:
        raise gr.Error("Please upload both images.")

    # Pure inference: disable autograd so no computation graph is built per request.
    with torch.no_grad():
        x = (
            processor(img_a).unsqueeze(dim=0).to(device),
            processor(img_b).unsqueeze(dim=0).to(device),
        )
        a, b = feat_extractor(x)
        proba = torch.sigmoid(model_cls((a, b))).item()

    # Confidence of the predicted class. A round()-based blend would give 0 at
    # proba == 0.5, because Python rounds half to even (round(0.5) == 0).
    score_str = "{:.2f}".format(max(proba, 1 - proba))
    base_name = f"{uuid.uuid4().hex}-{score_str}"
    save_image_pairs(img_a, img_b, proba, base_name)
    return {'Same': proba, 'Different': 1 - proba}, base_name
def save_image_pairs(img_a, img_b, proba, base_name):
    """Save an image pair under ``collected/normal/<same|different>/``.

    The sub-directory follows the prediction: ``proba > 0.5`` means ``same``,
    anything else means ``different``. move_to_flagged looks the files up in
    the same per-label directory. The images are written as
    ``<base_name>_a.png`` and ``<base_name>_b.png``, then an asynchronous
    dataset upload is requested.
    """
    sub_dir = 'same' if proba > 0.5 else 'different'
    target_dir = os.path.join('collected', 'normal', sub_dir)
    # Create the directory if missing so the first save does not fail.
    os.makedirs(target_dir, exist_ok=True)
    img_a.save(os.path.join(target_dir, f'{base_name}_a.png'))
    img_b.save(os.path.join(target_dir, f'{base_name}_b.png'))
    push_files_async()
def move_to_flagged(base_name: str, label: str):
    """Move a saved image pair from ``collected/normal`` to ``collected/mistakes``.

    Args:
        base_name: File stem the pair was saved under (files
            ``<base_name>_a.png`` / ``<base_name>_b.png``).
        label: Predicted label, ``'Same'`` or ``'Different'`` (any case). It
            selects the per-label sub-directory on both sides of the move.

    Raises:
        ValueError: If ``label`` is not one of the two known labels.
    """
    sub_dir = label.lower()
    # Both values come back from the client through the flag event, so never
    # let them steer the path outside the collection tree.
    if sub_dir not in ('same', 'different'):
        raise ValueError(f'Unexpected label: {label!r}')
    safe_name = os.path.basename(base_name)

    source_dir = os.path.join('collected', 'normal', sub_dir)
    destination = os.path.join('collected', 'mistakes', sub_dir)
    # shutil.move only moves *into* a directory that already exists.
    os.makedirs(destination, exist_ok=True)
    for suffix in ('_a.png', '_b.png'):
        shutil.move(os.path.join(source_dir, f'{safe_name}{suffix}'), destination)
    push_files_async()
class PetFlaggingCallback(gr.FlaggingCallback):
    """Flagging callback that files a prediction's image pair as a mistake.

    No log file is written. Flagging a result moves the already-saved pair
    from ``collected/normal`` into ``collected/mistakes`` via move_to_flagged.
    """

    def setup(self, components, flagging_dir: str):
        # Nothing to prepare: files live under collected/, not flagging_dir.
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Handle a flag event for the current interface values.

        ``flag_data`` is unpacked as (image A, image B, label output, hidden
        base name), matching the interface's two inputs and two outputs.
        """
        _, _, label, base_name = flag_data
        # Flagging before any prediction has run yields no label and no saved
        # pair; ignore the event instead of crashing on label['label'].
        if not label or not label.get('label') or not base_name:
            return
        move_to_flagged(base_name, label['label'])
# Gradio UI: two image inputs, a label with same/different confidences, and a
# hidden textbox that carries the saved pair's base name to the flagging
# callback.
demo = gr.Interface(
    title="Dog Recognition",
    description="Model that compares two images and identifies whether they belong to the same dog or to different dogs.",
    fn=predict_diff,
    inputs=[gr.Image(label="Image A", type="pil"), gr.Image(label="Image B", type="pil")],
    outputs=["label", gr.Text(visible=False)],
    flagging_callback=PetFlaggingCallback(),
)
if __name__ == "__main__":
    # Create the collection directories up front so the first prediction or
    # flag does not fail on a missing directory.
    for stage in ('normal', 'mistakes'):
        for verdict in ('same', 'different'):
            os.makedirs(os.path.join('collected', stage, verdict), exist_ok=True)

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    # NOTE(review): these appear to be full pickled objects rather than state
    # dicts. torch.load unpickles arbitrary code, so only load trusted files.
    # Recent torch releases default to weights_only=True, which may reject
    # them; confirm against the deployed torch version.
    vit_model = torch.load('vit_model_complete.pt', map_location=device)
    vit_processor = torch.load('vit_processor_complete.pt', map_location=device)
    model_cls = load_enc_cls_model('model_scripted.pt_enc', ENC_KEY)
    feat_extractor = VitLaionFeatureExtractor(vit_model, vit_processor)
    processor = feat_extractor.transforms

    # Dataset manager fed by ds_manager_queue (see push_files_async). It
    # presumably runs as a thread (it exposes .daemon/.start()); daemon so it
    # does not keep the process alive on exit.
    ds_manager = HFPetDatasetManager(dataset_name, hf_token=HF_API_TOKEN, queue=ds_manager_queue)
    ds_manager.daemon = True
    ds_manager.start()

    model_cls.to(device)
    feat_extractor.to(device)
    model_cls.eval()
    feat_extractor.eval()

    demo.queue()
    demo.launch()