"""Gradio demo: PCA renderings of DINOv2 patch features for two images.

Loads ``dinov2_vits14`` from torch hub and extracts patch tokens for two
uploaded images using helpers from the local ``codes`` module. Each image's
patches are rendered with ``render_patch_pca3`` and shown side by side.
"""
import gradio as gr
import os
import subprocess
import sys

# Best-effort runtime installs kept from the original deployment setup.
# Invoke pip through the *running* interpreter so packages land in this
# environment (a bare `pip` on PATH may belong to a different Python).
# check=False keeps the previous non-fatal behaviour of os.system.
_RUNTIME_PIP_ARGS = (
    ["torch"],
    ["-U", "scikit-learn"],
    ["torchvision"],
    ["scipy"],
)
for _pip_args in _RUNTIME_PIP_ARGS:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *_pip_args, "-q"],
        check=False,
    )

import numpy as np
from PIL import Image
import torch

from codes import (
    extract_dino_features,
    get_projections_and_standardarray,
    render_patch_pca3,
)

# Frozen backbone used only for inference: gradients disabled, eval mode.
dino_model = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model="dinov2_vits14")
for param in dino_model.parameters():
    param.requires_grad = False
dino_model.eval()


def dino_seg(img1, img2, model_selected):
    """Render PCA visualisations of DINOv2 patch features for two images.

    Args:
        img1: File path of the "fixed" image (the input component uses
            ``type="filepath"``).
        img2: File path of the "moving" image.
        model_selected: Value of the model dropdown. Currently ignored; the
            module-level ``dino_model`` is always used. Kept because Gradio
            passes one argument per input component.

    Returns:
        A two-element list with the rendering for ``img1`` and for ``img2``,
        each resized to 200x200 with nearest-neighbour resampling.
        NOTE(review): assumes ``render_patch_pca3`` returns a PIL image (the
        PIL-style ``resize(size, resample)`` call relies on it) -- confirm.
    """
    stacked_tokens, _image_batch, grid_size = extract_dino_features(
        [img1, img2], dino_model, smaller_edge_size=448
    )
    # Only standard_array is needed by render_patch_pca3 below. The
    # projections used to feed get_masks(), whose result was never used.
    _projections, standard_array = get_projections_and_standardarray(stacked_tokens)

    rendered = [
        render_patch_pca3(stacked_tokens[0], standard_array, grid_size),
        render_patch_pca3(stacked_tokens[1], standard_array, grid_size),
    ]
    return [img.resize((200, 200), Image.NEAREST) for img in rendered]


with gr.Blocks() as demo:
    image_1 = gr.Image(
        label="Fixed Image",
        source="upload",
        type="filepath",
        elem_id="image-in-1",  # HTML ids must be unique per page
    )
    image_2 = gr.Image(
        label="Moving Image",
        source="upload",
        type="filepath",
        elem_id="image-in-2",
    )
    # Only one backbone is wired up; the choice is not used by dino_seg yet.
    model_list = gr.Dropdown(
        ["small"], value="small", label="Model", info="select a model"
    )
    out_image1 = gr.Image(
        placeholder='Output',
        label="seg image",
        elem_id="image-out-1",
    )
    out_image2 = gr.Image(
        placeholder='Output',
        label="seg image",
        elem_id="image-out-2",
    )

    inputs = [image_1, image_2, model_list]
    outputs = [out_image1, out_image2]

    iface = gr.Interface(
        fn=dino_seg,
        inputs=inputs,
        outputs=outputs,
        title="Foreground background separation",
        description="Upload 2 images to generate a similarity map:",
        # One value per input component (fixed image, moving image, model).
        examples=[["./examples/ex3.jpg", "./examples/ex2.png", "small"]],
    )

demo.queue(default_enabled=True).launch(debug=True)