# Source: app.py by Amould, commit 58bfce0 ("Update app.py").
import gradio as gr
import os
# Best-effort runtime installation of the heavy dependencies (exit codes are
# ignored, as before). pip is invoked through the interpreter that is running
# this app (sys.executable) instead of whatever `pip`/`python` happens to be
# first on PATH, so the packages land in the environment that the imports
# below actually use. Arguments are passed as a list, so no shell is involved.
# NOTE(review): a requirements.txt is the more conventional place for these.
import subprocess
import sys

for _pip_args in (
    ["torch", "-q"],
    ["-U", "scikit-learn", "-q"],
    ["torchvision", "-q"],
    ["scipy", "-q"],
):
    subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args], check=False)
import sys
import numpy as np
from PIL import Image
import torch
from codes import *
## Load the `dinov2_vits14` model from the facebookresearch/dinov2 hub repo once
## at start-up. The app only runs forward passes, so gradients are switched off
## for every parameter and the module is put in evaluation mode.
dino_model = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model="dinov2_vits14")
dino_model.requires_grad_(False)
dino_model.eval()
def dino_seg(img1, img2, model_selected, output_size=(200, 200)):
    """Render a DINOv2 patch-feature visualisation for each of two images.

    Both images are passed together to ``extract_dino_features``. Both
    renderings are produced with the same ``standard_array`` returned by
    ``get_projections_and_standardarray`` for the stacked tokens.

    Args:
        img1: First ("Fixed") image. The Gradio input uses
            ``type="filepath"``, so this is presumably a file path.
        img2: Second ("Moving") image, same form as ``img1``.
        model_selected: Value of the model dropdown. It is currently unused,
            because only the module-level ``dino_model`` is ever loaded.
        output_size: Size passed to ``.resize`` for each rendering. The
            default keeps the previous hard-coded 200x200.

    Returns:
        A two-element list of resized renderings, one per input, in input
        order.

    Raises:
        gr.Error: If either image is missing (e.g. Submit pressed before both
            uploads), instead of failing inside feature extraction.
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please upload both images.")

    img_path_list = [img1, img2]
    stackedtokens_, _stack_image_batch, grid_size = extract_dino_features(
        img_path_list, dino_model, smaller_edge_size=448
    )
    # Only the shared standard_array is needed for rendering. The previous
    # version also called get_masks() on the projections but never used the
    # result, so that extra work is no longer done.
    _projections, standard_array = get_projections_and_standardarray(stackedtokens_)

    # NOTE(review): assumes render_patch_pca3 returns a PIL image; its
    # `.resize(size, resample)` call matches PIL's API. TODO confirm.
    # Image.NEAREST is Pillow's resample code 0, the value used before.
    return [
        render_patch_pca3(stackedtokens_[idx], standard_array, grid_size).resize(
            output_size, Image.NEAREST
        )
        for idx in range(len(img_path_list))
    ]
# --- Gradio UI ---------------------------------------------------------------
# The Interface is built directly. The previous version first created these
# components inside a `gr.Blocks()` context and then passed them to a
# `gr.Interface` nested in that same context.

# Both inputs deliver file paths (type="filepath") to dino_seg. Each component
# gets its own elem_id, because duplicate ids are invalid in the page DOM.
image_1 = gr.Image(
    label="Fixed Image",
    source="upload",
    type="filepath",
    elem_id="image-in-1",
)
image_2 = gr.Image(
    label="Moving Image",
    source="upload",
    type="filepath",
    elem_id="image-in-2",
)
# Only one backbone is loaded (dino_model), so offer a single choice and
# preselect it instead of listing "small" twice with no default.
model_list = gr.Dropdown(
    ["small"], value="small", label="Model", info="select a model"
)
out_image1 = gr.Image(
    placeholder="Output",
    label="seg image (fixed)",
    elem_id="image-out-1",
)
out_image2 = gr.Image(
    placeholder="Output",
    label="seg image (moving)",
    elem_id="image-out-2",
)

inputs = [image_1, image_2, model_list]
outputs = [out_image1, out_image2]
iface = gr.Interface(
    fn=dino_seg,
    inputs=inputs,
    outputs=outputs,
    title="Foreground background separation",
    description="Upload 2 images to generate a similarity map:",
    # Each example row supplies one value per input component, including the
    # model dropdown.
    examples=[["./examples/ex3.jpg", "./examples/ex2.png", "small"]],
)
# `demo` is kept as an alias so the module still exposes the name it had before.
demo = iface
demo.queue(default_enabled=True).launch(debug=True)