Spaces:
Running
Running
| #!/usr/bin/env -S uv run | |
| # /// script | |
| # requires-python = "<= 3.12" | |
| # dependencies = [ | |
| # "torchvision", | |
| # "huggingface_hub", | |
| # "timm", | |
| # "opencv-python", | |
| # "mediapipe", | |
| # "timm", | |
| # ] | |
| # /// | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from huggingface_hub import hf_hub_download | |
| import cv2 | |
| from utils import align_crop | |
| from timmfrv2 import TimmFRWrapperV2, model_configs | |
# Positional command-line arguments: MODEL_NAME IMAGE_1 IMAGE_2.
# Validate up front so a missing argument or an unknown model name produces a
# usage message instead of an IndexError here or a KeyError when the model
# config is looked up later.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} MODEL_NAME IMAGE_1 IMAGE_2")
model_name = sys.argv[1]
image_1 = sys.argv[2]
image_2 = sys.argv[3]
if model_name not in model_configs:
    sys.exit(
        f"unknown model {model_name!r}; available: "
        + ", ".join(map(str, model_configs))
    )
def load_and_crop(image_filename_1, image_filename_2):
    """Read two images from disk and return their aligned, channel-swapped crops.

    Each image is read with ``cv2.imread``, passed through ``align_crop`` and
    then converted with ``cv2.COLOR_RGB2BGR``.

    Args:
        image_filename_1: Path of the first image.
        image_filename_2: Path of the second image.

    Returns:
        A ``(crop_1, crop_2)`` tuple with the processed crop of each image.

    Raises:
        OSError: If OpenCV cannot read one or both of the images.
    """
    paths = (image_filename_1, image_filename_2)
    images = [cv2.imread(path) for path in paths]
    # cv2.imread does not raise on a missing/unreadable/unsupported file; it
    # returns None, which would otherwise fail obscurely inside align_crop.
    unreadable = [str(path) for path, img in zip(paths, images) if img is None]
    if unreadable:
        raise OSError("cv2.imread could not read: " + ", ".join(unreadable))
    # cv2.imread yields BGR data; COLOR_RGB2BGR performs the same R<->B swap as
    # COLOR_BGR2RGB. NOTE(review): presumably this hands the model RGB crops —
    # confirm that align_crop keeps the input channel order.
    crop_1, crop_2 = [
        cv2.cvtColor(align_crop(img), cv2.COLOR_RGB2BGR) for img in images
    ]
    return crop_1, crop_2
# Preprocessing applied to each crop before it reaches the model.
# ToTensor turns an HWC uint8 array into a CHW float tensor scaled to [0, 1].
# Normalize with mean = std = 0.5 on all three channels then computes
# (x - 0.5) / 0.5, mapping that range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-model settings: hub repo, checkpoint filename, timm architecture name and
# a post-construction hook.
model_config = model_configs[model_name]

print("Download model")
model_path = hf_hub_download(
    repo_id=model_config["repo"],
    filename=model_config["filename"],
    local_dir="models",
)
print(f"Model downloaded in {model_path}")

print("Create model")
model = TimmFRWrapperV2(model_config["timm_model"], batchnorm=False)
# The hook's return value replaces the module, so it may modify or wrap it.
model = model_config["post_setup"](model)
# The checkpoint is fetched from a remote repo. weights_only=True restricts
# unpickling to tensors and plain containers instead of running arbitrary
# pickled code. It is loaded onto the CPU first; the module moves to `device` below.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.eval()
model.to(device)
crop_a, crop_b = load_and_crop(image_1, image_2)


def _embed(crop):
    """Preprocess one crop, run the model on it and return its first output entry.

    The crop is wrapped in a batch axis of size 1 before inference. The first
    entry of the model output then gets a leading axis re-added, so the two
    results can be compared with ``F.cosine_similarity``.
    """
    batch = transform(crop)[None].to(device)
    return model(batch)[0][None]


with torch.no_grad():
    ea = _embed(crop_a)
    eb = _embed(crop_b)
    # Report cosine similarity as a percentage. Negative similarities show as 0
    # and the value is capped at 100. Keep the max(0, min(100, ...)) order:
    # it decides how a NaN score is clamped.
    similarity = F.cosine_similarity(ea, eb).item()
    pct = float(similarity * 100)
    pct = max(0, min(100, pct))
print(f"{pct:.2f}% match")

# Debug: uncomment to display the two crops in OpenCV windows.
# cv2.imshow("crop left", crop_a)
# cv2.imshow("crop right", crop_b)
# cv2.waitKey(0)
# cv2.destroyAllWindows()