PHarder's picture
test new
8a814b1
import torch
import kornia
import cv2
import torchvision.transforms as T
def feature_detection_mapping(images, processor, model, device, threshold=0.2):
    """Run the keypoint-matching model on ``images`` and post-process the result.

    Args:
        images: Sequence of images, each exposing ``height`` and ``width``
            attributes (presumably PIL images -- confirm against the caller).
        processor: Callable preprocessor; it is invoked as
            ``processor(images, return_tensors="pt")`` and must also provide
            ``post_process_keypoint_matching``.
        model: Callable invoked with the preprocessed inputs as keyword
            arguments.
        device: Device the preprocessed inputs are moved to before inference.
        threshold: Value forwarded as ``threshold`` to
            ``processor.post_process_keypoint_matching``. Defaults to ``0.2``,
            the value that used to be hard-coded here.

    Returns:
        The value returned by ``processor.post_process_keypoint_matching``.
    """
    inputs = processor(images, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        raw_outputs = model(**inputs)

    # All images are passed as a single group of (height, width) tuples.
    image_sizes = [[(image.height, image.width) for image in images]]
    return processor.post_process_keypoint_matching(
        raw_outputs, image_sizes, threshold=threshold
    )
def _warp_with_coverage(image, transform, out_size):
    """Warp ``image`` (C x H x W) and return ``(warped_image, coverage)``.

    An all-ones channel is warped together with the image so the coverage map
    reflects where the image actually lands on the canvas, independent of the
    pixel values (black pixels still count as covered).
    """
    stacked = torch.cat([image, torch.ones_like(image[:1])], dim=0)
    warped = kornia.geometry.transform.warp_perspective(
        stacked.unsqueeze(0), transform.unsqueeze(0), dsize=out_size, align_corners=True
    ).squeeze(0)
    return warped[:-1], warped[-1:]


def stitch_images(img0_pil, img1_pil, output, device):
    """Stitch two images into one canvas using matched keypoints.

    A homography mapping image-1 coordinates onto image-0 coordinates is
    estimated with OpenCV (MAGSAC), both images are warped onto a canvas large
    enough to contain them, and overlapping regions are averaged.

    Args:
        img0_pil: Reference image, converted with ``torchvision``'s
            ``ToTensor`` (presumably a PIL image -- confirm against caller).
        img1_pil: Image to be warped into the reference frame.
        output: Mapping with ``"keypoints0"`` and ``"keypoints1"`` tensors;
            row ``i`` of each is assumed to be one matched point pair in
            (x, y) pixel coordinates -- TODO confirm against the matcher.
        device: Device on which warping and blending are performed.

    Returns:
        ``(stitched, num_matches)`` where ``stitched`` is a C x H x W tensor,
        or ``None`` when fewer than 4 matches are available, no homography
        could be estimated, or the estimated homography is degenerate.
    """
    pts0 = output["keypoints0"].float()
    pts1 = output["keypoints1"].float()
    num_matches = pts0.shape[0]

    # A homography needs at least 4 correspondences. Checked before any image
    # conversion so the cheap failure path does no unnecessary work.
    if num_matches < 4:
        return None, num_matches

    # Source points come from image 1, destination points from image 0, so the
    # resulting matrix maps image-1 coordinates into the image-0 frame.
    H_np, _ = cv2.findHomography(
        pts1.detach().cpu().numpy(),
        pts0.detach().cpu().numpy(),
        method=cv2.USAC_MAGSAC,
        ransacReprojThreshold=5.0,
        confidence=0.999,
        maxIters=100000,
    )
    if H_np is None:
        return None, num_matches

    to_tensor = T.ToTensor()
    image0 = to_tensor(img0_pil).to(device)
    image1 = to_tensor(img1_pil).to(device)

    # OpenCV returns float64; cast on the CPU first because some backends
    # (e.g. MPS) cannot hold float64 tensors.
    H = torch.from_numpy(H_np).float().to(device)

    _, h0, w0 = image0.shape
    _, h1, w1 = image1.shape

    # Project the corners of image 1 into the image-0 frame (homogeneous coords).
    corners1_homo = torch.tensor(
        [
            [0.0, 0.0, 1.0],
            [float(w1), 0.0, 1.0],
            [float(w1), float(h1), 1.0],
            [0.0, float(h1), 1.0],
        ],
        device=device,
        dtype=H.dtype,
    ).T
    warped_homo = H @ corners1_homo
    scale = warped_homo[2]

    # If the homogeneous scale is non-finite, zero, or changes sign between
    # corners, image 1 crosses the plane at infinity and no finite canvas
    # exists; treat this like a failed estimation.
    if not bool(torch.isfinite(warped_homo).all()) or float(scale.min() * scale.max()) <= 0.0:
        return None, num_matches

    warped_corners1 = (warped_homo[:2] / scale).T
    corners0 = torch.tensor(
        [[0.0, 0.0], [float(w0), 0.0], [float(w0), float(h0)], [0.0, float(h0)]],
        device=device,
        dtype=H.dtype,
    )

    # Bounding box of both images in the image-0 frame.
    all_coords = torch.cat([warped_corners1, corners0], dim=0)
    min_xy = all_coords.min(dim=0).values
    max_xy = all_coords.max(dim=0).values

    # Shift everything so the bounding box starts at the canvas origin.
    translation = torch.eye(3, device=device, dtype=H.dtype)
    translation[0, 2] = -min_xy[0]
    translation[1, 2] = -min_xy[1]
    H_final = translation @ H

    # dsize is (height, width).
    out_size = (int(max_xy[1] - min_xy[1]), int(max_xy[0] - min_xy[0]))

    warped0, coverage0 = _warp_with_coverage(image0, translation, out_size)
    warped1, coverage1 = _warp_with_coverage(image1, H_final, out_size)

    # Coverage-weighted average: pixels covered by one image keep its value,
    # overlapping pixels are averaged, uncovered pixels stay zero. The clamp
    # only avoids division by zero where nothing was drawn.
    stitched = (warped0 + warped1) / (coverage0 + coverage1).clamp_min(1e-8)
    return stitched, num_matches