Gaojunyao commited on
Commit
6b305d4
·
verified ·
1 Parent(s): b46a95b

Delete process.py

Browse files
Files changed (1) hide show
  1. process.py +0 -126
process.py DELETED
@@ -1,126 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import os
4
- import re
5
- from packaging import version as pver
6
-
7
def custom_meshgrid(*args):
    """Version-safe wrapper around ``torch.meshgrid``.

    PyTorch 1.10 introduced the ``indexing`` keyword (and warns when it is
    omitted), while older releases reject it outright.  This always yields
    the classic 'ij' (matrix) indexing either way.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html
    """
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
15
def ray_condition(K, c2w, H, W, device, flip_flag=None):
    """Build per-pixel Plucker-coordinate embeddings for a batch of cameras.

    Args:
        K: [B, V, 4] intrinsics packed as (fx, fy, cx, cy) per view.
        c2w: [B, V, 4, 4] camera-to-world extrinsic matrices.
        H, W: ray-grid height and width.
        device: torch device on which the pixel grids are created.
        flip_flag: optional [V] bool mask; views marked True get a
            horizontally mirrored pixel grid.

    Returns:
        [B, V, H, W, 6] tensor of (o x d, d) Plucker coordinates, with the
        direction part d normalized to unit length.
    """
    B, V = K.shape[:2]

    # Pixel-center grid ('ij' indexing: j indexes rows, i indexes columns);
    # +0.5 shifts from corner to pixel center.
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # BUGFIX: `i`/`j` above come from .expand(), so all views alias one
        # memory row; PyTorch refuses masked in-place writes into such
        # tensors ("more than one element ... refers to a single memory
        # location").  Materialize real copies before the index-put.
        i = i.clone()
        j = j.clone()
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype),
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        # The [B, 1, HW] right-hand side broadcasts over the n_flip views.
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Back-project pixel centers to camera-space directions (z = 1 plane).
    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # [B, V, HW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)  # unit length

    # Rotate directions into world space; origins are the camera centers.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HW, 3]
    rays_o = c2w[..., :3, 3]  # [B, V, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # [B, V, HW, 3]

    # Plucker moment o x d.  BUGFIX: pass dim=-1 explicitly — the dim-less
    # torch.cross picks the *first* size-3 dimension (deprecated behavior),
    # which is wrong whenever H*W happens to equal 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # [B, V, HW, 3]
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # [B, V, H, W, 6]
    return plucker
59
def plucker_from_camera_params(
    K, R, t, W0, H0, cw_np,
    H, W,
    device='cpu',
    normalize_intrinsics=True
):
    """
    Build the Plucker-ray embedding for a single camera.

    K, R, t, W0, H0 are as returned by get_camera_params;
    cw_np is camera.matrix_world converted to a 4x4 numpy.ndarray (c2w).
    H, W are the desired ray-grid dimensions.

    Note: R and t are currently unused — the full c2w matrix cw_np is
    used for the extrinsics instead.

    Returns a [H, W, 6] tensor (channels last).
    """
    # 1) Normalize intrinsics by the original image size, mapping focal
    #    lengths and principal point into normalized image coordinates.
    if normalize_intrinsics:
        fx = K[0,0] / W0
        fy = K[1,1] / H0
        cx = K[0,2] / W0
        cy = K[1,2] / H0
    else:
        fx, fy = K[0,0], K[1,1]
        cx, cy = K[0,2], K[1,2]

    intrinsics = torch.tensor(
        [[[fx, fy, cx, cy]]],
        dtype=torch.float32,
        device=device
    )  # [1,1,4] — the (fx, fy, cx, cy) packing ray_condition expects

    # 2) cw_np already *is* the c2w matrix.
    #    Convert it to a tensor directly; do not call Matrix methods on it.
    c2w = torch.tensor(cw_np, dtype=torch.float32, device=device)
    c2w = c2w.unsqueeze(0).unsqueeze(0)  # [1,1,4,4] — add B and V dims

    # 3) Delegate to ray_condition; an all-False flip_flag disables the
    #    horizontal-mirror branch entirely.
    flip_flag = torch.zeros(1, dtype=torch.bool, device=device)
    plucker = ray_condition(
        intrinsics,
        c2w,
        H, W,
        device=device,
        flip_flag=flip_flag
    )  # [1,1,H,W,6]

    return plucker[0,0]  # [H, W, 6] — channels LAST; callers permute to [6, H, W]
104
# --- Script entry: convert per-view camera dumps into Plucker embeddings ---
dir_path = "1_19_Idle"
# Camera-parameter dumps: "*dict*.npy" files, each holding one pickled dict.
dict_files = sorted(
    os.path.join(dir_path, file)
    for file in os.listdir(dir_path)
    if file.endswith(".npy") and "dict" in file
)

for dict_file in dict_files:
    # SECURITY NOTE: allow_pickle can execute arbitrary code — only safe
    # because these .npy files are locally generated, trusted dumps.
    params = np.load(dict_file, allow_pickle=True).item()

    match_view = re.search(r'view(\d+)', str(dict_file))
    if match_view is None:
        # BUGFIX: previously `view` stayed None and the script wrote a file
        # literally named "None.pt"; skip unnameable inputs instead.
        print(f"skipping {dict_file}: no 'view<N>' in filename")
        continue
    view = int(match_view.group(1))

    plucker_embedding = plucker_from_camera_params(
        params['K'], params['R'], params['t'], params['W0'], params['H0'], params['cw'],
        480, 720,
        device='cpu',
        normalize_intrinsics=True
    )  # [480, 720, 6]

    # BUGFIX: the result is already a tensor — torch.tensor(tensor) copies
    # and raises a UserWarning.  Permute to channels-first and make the
    # copy explicit with .contiguous().
    plucker_embedding = plucker_embedding.permute(2, 0, 1).contiguous()  # [6, H, W]

    # Scale the moment part by the camera's distance from the world origin.
    o = params['cw'][:3, 3]
    r = np.linalg.norm(o)
    if r > 0:  # BUGFIX: guard against division by zero for a camera at the origin
        plucker_embedding[:3] = plucker_embedding[:3] / r
    torch.save(plucker_embedding, f"{view}.pt")