File size: 736 Bytes
b678162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Tuple, Dict, Set

import torch
from torch import Tensor
from jaxtyping import Float, Bool, UInt8, Int32


def inside_image(
    pts2d: Float[Tensor, "n 2"],
    image_size: Tuple[int, ...]
) -> Bool[Tensor, " n"]:
    """Test which 2D pixel coordinates fall inside an image.

    Args:
        pts2d: Points with ``(x, y)`` in the last dimension; ``x`` is
            checked against the image width and ``y`` against the height.
        image_size: ``(H, W)``. Must contain exactly two entries, because it
            is unpacked as height then width.

    Returns:
        Boolean mask that is ``True`` where ``0 <= x < W`` and
        ``0 <= y < H``. The bounds are half-open, so a point with
        ``x == W`` or ``y == H`` counts as outside.
    """
    H, W = image_size
    px, py = pts2d.unbind(-1)
    return (
        (0 <= px) & (px < W) &
        (0 <= py) & (py < H)
    )


def get_uv_grid(
    image_size: Tuple[int, int],
    dtype=torch.float32,
    device=None,
) -> Float[Tensor, "h w 2"]:
    """Build a dense grid of integer pixel coordinates.

    Args:
        image_size: ``(H, W)``.
        dtype: dtype of the returned grid (default ``torch.float32``).
        device: Device to allocate the grid on. ``None`` (the default) uses
            PyTorch's default device, which is what the function always did
            before this parameter existed.

    Returns:
        Tensor of shape ``(H, W, 2)`` with ``grid[v, u] == (u, v)``. The last
        dimension holds the column index (x/u) followed by the row index
        (y/v).
    """
    H, W = image_size
    # With indexing="xy" the first input (length W) varies along the output's
    # last axis, so both meshgrid outputs have shape (H, W).
    xs = torch.arange(W, device=device)
    ys = torch.arange(H, device=device)
    meshgrid = torch.meshgrid(xs, ys, indexing="xy")
    id_coords = torch.stack(meshgrid, dim=-1).to(dtype)
    return id_coords


def persp_project(
    xyz: Float[Tensor, "*batch 3"],
) -> Tuple[Float[Tensor, "*batch 2"], Float[Tensor, "*batch 1"]]:
    """Divide the x/y components of 3D points by their z component.

    Args:
        xyz: Points with ``(x, y, z)`` in the last dimension. Any number of
            leading batch dimensions is accepted; ``(n, 3)`` behaves exactly
            as before.

    Returns:
        ``(uv, z)`` where ``uv = (x / z, y / z)`` and ``z`` keeps a trailing
        singleton dimension, so it broadcasts back against ``uv``.

    Note:
        There is no guard against ``z == 0``. Such points yield ``inf`` or
        ``nan`` in ``uv``, exactly as plain tensor division does.
    """
    # Slicing (not integer indexing) keeps the last axis, so z is (..., 1).
    z = xyz[..., 2:]
    uv = xyz[..., :2] / z
    return uv, z