bulatko committed on
Commit
daaac94
·
1 Parent(s): 6c099d4

chore: cleanup for HF Spaces (ignore pth)

Browse files
.gitignore CHANGED
@@ -153,4 +153,4 @@ temp/
153
  **/*.glb
154
  **/*.bin
155
  data/
156
- **/*.pth
 
153
  **/*.glb
154
  **/*.bin
155
  data/
156
+ **/*.pth
MaskClustering/graph/construction.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from utils.mask_backprojection import frame_backprojection
5
+ from graph.node import Node
6
+
7
def mask_graph_construction(args, scene_points, frame_list, dataset):
    '''
    Build the mask graph in three steps:
    1. Construct the point-in-mask matrix (a lookup table that speeds up the
       later view-consensus-rate computation).
    2. For every mask, determine the frames in which it is visible and the
       masks that contain it; undersegmented masks are flagged along the way.
    3. Create the initial graph nodes.
    '''
    if args.debug:
        print('start building point in mask matrix')
    (boundary_points, point_in_mask_matrix, mask_point_clouds,
     point_frame_matrix, global_frame_mask_list) = build_point_in_mask_matrix(
        args, scene_points, frame_list, dataset)
    visible_frames, contained_masks, undersegment_mask_ids = process_masks(
        frame_list, global_frame_mask_list, point_in_mask_matrix,
        boundary_points, mask_point_clouds, args)
    observer_num_thresholds = get_observer_num_thresholds(visible_frames)
    nodes = init_nodes(global_frame_mask_list, visible_frames,
                       contained_masks, undersegment_mask_ids, mask_point_clouds)
    return nodes, observer_num_thresholds, mask_point_clouds, point_frame_matrix
21
+
22
def build_point_in_mask_matrix(args, scene_points, frame_list, dataset):
    '''
    Precompute a 'point in mask' matrix, trading space for time so the view
    consensus rate can be computed quickly later. The matrix has shape
    (scene_points_num, frame_num): M[i, j] = k when point i lies in the k-th
    mask of frame j, and 0 otherwise (mask ids start from 1).

    Returns:
        boundary_points: points claimed by several masks within a single frame
            (mask-boundary points); excluded from the view consensus computation.
        point_in_mask_matrix: the matrix described above.
        mask_point_clouds: dict mapping '<frame_id>_<mask_id>' to the set of
            point ids covered by that mask.
        point_frame_matrix: (scene_points_num, frame_num) bool matrix where
            M[i, j] is True iff point i is visible in frame j.
        global_frame_mask_list: list of (frame_id, mask_id) tuples covering the
            whole sequence.
    '''
    num_points = len(scene_points)
    num_frames = len(frame_list)

    scene_points = torch.tensor(scene_points).float().cuda()
    boundary_points = set()
    point_in_mask_matrix = np.zeros((num_points, num_frames), dtype=np.uint16)
    point_frame_matrix = np.zeros((num_points, num_frames), dtype=bool)
    global_frame_mask_list = []
    mask_point_clouds = {}

    frame_iter = enumerate(frame_list)
    if args.debug:
        frame_iter = tqdm(frame_iter, total=num_frames)
    for frame_cnt, frame_id in frame_iter:
        mask_dict, frame_point_cloud_ids = frame_backprojection(dataset, scene_points, frame_id)
        if len(frame_point_cloud_ids) == 0:
            continue
        point_frame_matrix[frame_point_cloud_ids, frame_cnt] = True

        seen_point_ids = set()   # points already claimed by an earlier mask of this frame
        frame_boundary = set()   # points claimed by more than one mask of this frame
        for mask_id, mask_point_ids in mask_dict.items():
            frame_boundary.update(mask_point_ids.intersection(seen_point_ids))
            mask_point_clouds[f'{frame_id}_{mask_id}'] = mask_point_ids
            point_in_mask_matrix[list(mask_point_ids), frame_cnt] = mask_id
            seen_point_ids.update(mask_point_ids)
            global_frame_mask_list.append((frame_id, mask_id))

        # Boundary points are ambiguous between masks: clear them for this frame.
        point_in_mask_matrix[list(frame_boundary), frame_cnt] = 0
        boundary_points.update(frame_boundary)

    return boundary_points, point_in_mask_matrix, mask_point_clouds, point_frame_matrix, global_frame_mask_list
62
+
63
def init_nodes(global_frame_mask_list, mask_project_on_all_frames, contained_masks, undersegment_mask_ids, mask_point_clouds):
    '''
    Create one leaf Node per valid mask (iteration 0 of the clustering).

    Args:
        global_frame_mask_list: list of (frame_id, mask_id) over the sequence.
        mask_project_on_all_frames: per-mask visible-frame vectors, indexed by
            global mask id.
        contained_masks: per-mask contained-mask vectors, indexed by global mask id.
        undersegment_mask_ids: global ids of undersegmented masks to skip.
        mask_point_clouds: dict mapping '<frame_id>_<mask_id>' to point ids.

    Returns:
        list of Node objects.
    '''
    # Membership is tested once per mask; a set makes that O(1) instead of an
    # O(len(undersegment_mask_ids)) list scan (accidental O(n^2) overall).
    undersegment_ids = set(undersegment_mask_ids)
    nodes = []
    for global_mask_id, (frame_id, mask_id) in enumerate(global_frame_mask_list):
        if global_mask_id in undersegment_ids:
            continue
        mask_list = [(frame_id, mask_id)]
        visible_frame = mask_project_on_all_frames[global_mask_id]
        contained_mask = contained_masks[global_mask_id]
        point_ids = mask_point_clouds[f'{frame_id}_{mask_id}']
        node_info = (0, len(nodes))  # (iteration, index within iteration)
        nodes.append(Node(mask_list, visible_frame, contained_mask, point_ids, node_info, None))
    return nodes
76
+
77
def get_observer_num_thresholds(visible_frames):
    '''
    Observer-number thresholds for the clustering iterations, taken at
    percentiles from 95% down to 0% of the positive pairwise observer counts.
    '''
    # pair_counts[i, j] = number of frames in which masks i and j both appear.
    pair_counts = torch.matmul(visible_frames, visible_frames.transpose(0, 1))
    flat_counts = pair_counts.flatten()
    flat_counts = flat_counts[flat_counts > 0].cpu().numpy()

    thresholds = []
    for percentile in range(95, -5, -5):
        value = np.percentile(flat_counts, percentile)
        if value <= 1:
            # Below the median the tail is uninformative; clamp to 1 above it.
            if percentile < 50:
                break
            value = 1
        thresholds.append(value)
    return thresholds
94
+
95
def process_one_mask(point_in_mask_matrix, boundary_points, mask_point_cloud, frame_list, global_frame_mask_list, args):
    '''
    For a single mask, find the frames where it is visible and the masks that
    contain it, and decide whether it looks undersegmented.

    Returns:
        (valid, visible_frame, contained_mask): valid is False when the mask is
        judged undersegmented; the two tensors are one-hot over frames / masks.
    '''
    visible_frame = torch.zeros(len(frame_list))
    contained_mask = torch.zeros(len(global_frame_mask_list))

    # Boundary points are ambiguous between masks, so drop them first.
    core_points = mask_point_cloud - boundary_points
    core_info = point_in_mask_matrix[list(core_points), :]

    # Frames in which at least one of this mask's points falls into some mask.
    candidate_frames = np.where(np.sum(core_info, axis=0) > 0)[0]

    split_num = 0
    visible_num = 0

    for frame_idx in candidate_frames:
        id_counts = np.bincount(core_info[:, frame_idx])
        total = np.sum(id_counts)
        # Bin 0 counts points not covered by any mask in this frame (invisible).
        invisible_ratio = id_counts[0] / total
        # Skip frames where nearly all of the mask's points are missing.
        if 1 - invisible_ratio < args.mask_visible_threshold and (total - id_counts[0]) < 500:
            continue
        visible_num += 1
        id_counts[0] = 0
        dominant_id = np.argmax(id_counts)
        # Denominator counts visible points only (bin 0 was zeroed above).
        contained_ratio = id_counts[dominant_id] / np.sum(id_counts)
        if contained_ratio > args.contained_threshold:
            visible_frame[frame_idx] = 1
            global_idx = global_frame_mask_list.index((frame_list[frame_idx], dominant_id))
            contained_mask[global_idx] = 1
        else:
            # The mask is split across several masks in this frame.
            split_num += 1

    undersegmented = visible_num == 0 or split_num / visible_num > args.undersegment_filter_threshold
    return not undersegmented, visible_frame, contained_mask
131
+
132
def process_masks(frame_list, global_frame_mask_list, point_in_mask_matrix, boundary_points, mask_point_clouds, args):
    '''
    For each mask, compute the frames in which it is visible and the masks that
    contain it; flag undersegmented masks along the way.

    Returns:
        visible_frames: (mask_num, frame_num) float tensor on GPU.
        contained_masks: (mask_num, mask_num) float tensor on GPU.
        undersegment_mask_ids: global ids of masks judged undersegmented.
    '''
    if args.debug:
        print('start processing masks')
    visible_frames = []
    contained_masks = []
    undersegment_mask_ids = []

    iterator = tqdm(global_frame_mask_list) if args.debug else global_frame_mask_list
    # The enumerate index IS the global mask id (we iterate the list itself);
    # the previous global_frame_mask_list.index(...) was an O(n) scan per mask.
    for global_mask_id, (frame_id, mask_id) in enumerate(iterator):
        valid, visible_frame, contained_mask = process_one_mask(point_in_mask_matrix, boundary_points, mask_point_clouds[f'{frame_id}_{mask_id}'], frame_list, global_frame_mask_list, args)
        visible_frames.append(visible_frame)
        contained_masks.append(contained_mask)
        if not valid:
            undersegment_mask_ids.append(global_mask_id)

    # Degenerate case: no masks at all in the sequence.
    if len(visible_frames) == 0:
        return torch.zeros(0, len(frame_list)).cuda(), torch.zeros(0, len(global_frame_mask_list)).cuda(), undersegment_mask_ids

    visible_frames = torch.stack(visible_frames, dim=0).cuda()    # (mask_num, frame_num)
    contained_masks = torch.stack(contained_masks, dim=0).cuda()  # (mask_num, mask_num)

    # Undo the effect of undersegmented observer masks to avoid merging two
    # objects that are actually separate.
    for global_mask_id in undersegment_mask_ids:
        frame_id, _ = global_frame_mask_list[global_mask_id]
        global_frame_id = frame_list.index(frame_id)
        mask_projected_idx = torch.where(contained_masks[:, global_mask_id])[0]
        contained_masks[:, global_mask_id] = False
        visible_frames[mask_projected_idx, global_frame_id] = False

    return visible_frames, contained_masks, undersegment_mask_ids
MaskClustering/graph/iterative_clustering.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ from graph.node import Node
3
+ import torch
4
+
5
def cluster_into_new_nodes(iteration, old_nodes, graph):
    '''Merge each connected component of the graph into one new node.'''
    merged_nodes = []
    for component in nx.connected_components(graph):
        members = [old_nodes[idx] for idx in component]
        info = (iteration, len(merged_nodes))
        merged_nodes.append(Node.create_node_from_list(members, info))
    return merged_nodes
11
+
12
+
13
def update_graph(nodes, observer_num_threshold, connect_threshold):
    '''
    Recompute the view consensus rates between nodes and return a new graph
    whose edges connect node pairs that should be merged.
    '''
    visible = torch.stack([node.visible_frame for node in nodes], dim=0)
    contained = torch.stack([node.contained_mask for node in nodes], dim=0)

    # observers[i, j]: number of frames in which nodes i and j both appear.
    observers = torch.matmul(visible, visible.transpose(0, 1))
    # supporters[i, j]: number of frames supporting the merge of i and j.
    supporters = torch.matmul(contained, contained.transpose(0, 1))

    consensus_rate = supporters / (observers + 1e-7)

    # Never connect a node to itself, nor pairs observed too rarely.
    blocked = torch.eye(len(nodes), dtype=bool).cuda()
    blocked = blocked | (observers < observer_num_threshold)

    adjacency = (consensus_rate >= connect_threshold) & ~blocked
    return nx.from_numpy_array(adjacency.cpu().numpy())
34
+
35
+
36
def iterative_clustering(nodes, observer_num_thresholds, connect_threshold, debug):
    '''Repeatedly rebuild the graph and merge connected components, one pass per threshold.'''
    if debug:
        print('====> Start iterative clustering')
    for iterate_id, observer_num_threshold in enumerate(observer_num_thresholds):
        if debug:
            print(f'Iterate {iterate_id}: observer_num', observer_num_threshold, ', number of nodes', len(nodes))
        graph = update_graph(nodes, observer_num_threshold, connect_threshold)
        nodes = cluster_into_new_nodes(iterate_id + 1, nodes, graph)
    return nodes
MaskClustering/graph/node.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open3d as o3d
3
+
4
class Node:
    '''A cluster of 2D masks treated as a single node of the mask graph.'''

    def __init__(self, mask_list, visible_frame, contained_mask, point_ids, node_info, son_node_info):
        '''
        Args:
            mask_list: (frame_id, mask_id) tuples belonging to this cluster.
            visible_frame: one-hot vector, 1 where the node appears in a frame.
            contained_mask: one-hot vector, 1 where the node is contained by a mask.
            point_ids: ids of the corresponding 3D points.
            node_info: (iteration, index within iteration) — for debugging.
            son_node_info: node_infos of the previous-iteration nodes merged
                into this one — for debugging.
        '''
        self.mask_list = mask_list
        self.visible_frame = visible_frame
        self.contained_mask = contained_mask
        self.point_ids = point_ids
        self.node_info = node_info
        self.son_node_info = son_node_info

    @staticmethod
    def create_node_from_list(node_list, node_info):
        '''Merge several nodes into one by unioning masks, frames, containers and points.'''
        visible = torch.zeros(len(node_list[0].visible_frame), dtype=bool).cuda()
        contained = torch.zeros(len(node_list[0].contained_mask), dtype=bool).cuda()
        merged_masks = []
        merged_points = set()
        sons = set()
        for node in node_list:
            merged_masks.extend(node.mask_list)
            visible = visible | node.visible_frame.bool()
            contained = contained | node.contained_mask.bool()
            merged_points = merged_points.union(node.point_ids)
            sons.add(node.node_info)
        return Node(merged_masks, visible.float(), contained.float(), merged_points, node_info, sons)

    def get_point_cloud(self, scene_points):
        '''
        Returns:
            pcld: open3d.geometry.PointCloud holding the node's points.
            point_ids: list of int, the node's 3D point ids.
        '''
        point_ids = list(self.point_ids)
        pcld = o3d.geometry.PointCloud()
        pcld.points = o3d.utility.Vector3dVector(scene_points[point_ids])
        return pcld, point_ids
mvp.py CHANGED
@@ -52,23 +52,45 @@ _METRIC3D_MODEL = None
52
  _CLIP_MODEL = None
53
 
54
 
55
- def _download_vggt_weights(dst_path: str) -> str:
 
 
 
56
  """
57
- Download VGGT weights from Google Drive to dst_path.
58
- The user provided:
59
- https://drive.google.com/file/d/10G7s6bVMwN__bcrR2fBal3goo69Y5Do4/view?usp=sharing
 
 
60
  """
61
  if os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
62
  return dst_path
63
 
64
- if gdown is None:
65
- raise RuntimeError("Не найден пакет gdown. Добавь gdown в requirements.txt для загрузки весов из Google Drive.")
66
-
67
  os.makedirs(os.path.dirname(dst_path), exist_ok=True)
68
- url = "https://drive.google.com/uc?id=10G7s6bVMwN__bcrR2fBal3goo69Y5Do4"
69
- out = gdown.download(url, dst_path, quiet=False)
70
- if out is None or not os.path.exists(dst_path) or os.path.getsize(dst_path) == 0:
71
- raise RuntimeError("Не удалось скачать веса VGGT из Google Drive (проверь доступ/квоты/публичность).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  return dst_path
73
 
74
 
@@ -83,13 +105,13 @@ def _init_models():
83
 
84
  if _VGGT_MODEL is None:
85
  print("Initializing and loading VGGT model...")
86
- m = VGGT()
87
- weights_path = os.environ.get("VGGT_WEIGHTS_PATH")
88
- if not weights_path:
89
- weights_path = os.path.join(WORK_DIR, "weights", "vggt_model.pt")
90
- _download_vggt_weights(weights_path)
91
- state = torch.load(weights_path, map_location="cpu")
92
- m.load_state_dict(state)
93
  m.eval()
94
  _VGGT_MODEL = m.to(device)
95
 
@@ -118,15 +140,9 @@ cropformer_name = "Mask2Former_hornet_3x_576d0b.pth"
118
  def check_weights():
119
  if not os.path.exists(os.path.join(MK_PATH, cropformer_name)):
120
  print(f"Downloading {cropformer_name}...")
121
- # Prefer HF cache over `wget` for Spaces compatibility.
122
- cached = hf_hub_download(
123
- repo_id="qqlu1992/Adobe_EntitySeg",
124
- repo_type="dataset",
125
- filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
126
- )
127
  os.makedirs(MK_PATH, exist_ok=True)
128
  dst = os.path.join(MK_PATH, cropformer_name)
129
- shutil.copyfile(cached, dst)
130
  print(f"Downloaded {cropformer_name}...")
131
  else:
132
  print(f"{cropformer_name} already exists...")
@@ -222,7 +238,7 @@ def run_model(target_dir, model, metric3d_model=None) -> dict:
222
 
223
  # Scale alignment: scale = median(Depths_VGGT / Depths_Metric3D)
224
  # We need to make sure we use valid depths (e.g. > 0) to avoid numerical issues
225
- vggt_depth = predictions["depth"] # (B, H, W, 1) or similar
226
  metric_depth = predictions["metric3d_depth"] # (B, 1, H, W) presumably
227
 
228
  # Ensure shapes match for broadcasting or direct division
@@ -251,6 +267,9 @@ def run_model(target_dir, model, metric3d_model=None) -> dict:
251
  valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)
252
 
253
  if valid_mask.sum() > 0:
 
 
 
254
  ratio = metric_depth[valid_mask] / vggt_depth[valid_mask]
255
  scale_factor = torch.median(ratio)
256
  print(f"Computed scale factor (VGGT / Metric3D): {scale_factor.item():.4f}")
 
52
  _CLIP_MODEL = None
53
 
54
 
55
+ _MASK2FORMER_GDRIVE_FILE_ID = "10G7s6bVMwN__bcrR2fBal3goo69Y5Do4"
56
+
57
+
58
def _ensure_mask2former_weights(dst_path: str) -> str:
    """
    Ensure the Mask2Former/CropFormer weights exist at dst_path and return dst_path.

    Resolution order:
    1) Reuse an existing non-empty file at dst_path.
    2) Copy from a local file pointed to by the MASK2FORMER_WEIGHTS_PATH env var.
    3) Download from Google Drive (file id _MASK2FORMER_GDRIVE_FILE_ID) via gdown.
    4) Fallback: download from the HF dataset qqlu1992/Adobe_EntitySeg.
    """
    # 1) Already present and non-empty
    if os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
        return dst_path

    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # 2) Local override via environment variable
    override_path = os.environ.get("MASK2FORMER_WEIGHTS_PATH")
    if override_path and os.path.exists(override_path) and os.path.getsize(override_path) > 0:
        shutil.copyfile(override_path, dst_path)
        return dst_path

    # 3) Google Drive (best effort; gdown may be absent or the download may fail)
    if gdown is not None:
        url = f"https://drive.google.com/uc?id={_MASK2FORMER_GDRIVE_FILE_ID}"
        out = gdown.download(url, dst_path, quiet=False)
        if out is not None and os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
            return dst_path
        print("Warning: gdown download failed for Mask2Former weights; falling back to HF dataset...")
    else:
        print("Warning: gdown is not available; falling back to HF dataset for Mask2Former weights...")

    # 4) HF fallback: fetch into the HF cache, then copy into place
    cached = hf_hub_download(
        repo_id="qqlu1992/Adobe_EntitySeg",
        repo_type="dataset",
        filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
    )
    shutil.copyfile(cached, dst_path)
    return dst_path
 
96
 
 
105
 
106
  if _VGGT_MODEL is None:
107
  print("Initializing and loading VGGT model...")
108
+ # Prefer Hugging Face weights for VGGT
109
+ try:
110
+ m = VGGT.from_pretrained("facebook/VGGT-1B")
111
+ except Exception:
112
+ m = VGGT()
113
+ _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
114
+ m.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
115
  m.eval()
116
  _VGGT_MODEL = m.to(device)
117
 
 
140
  def check_weights():
141
  if not os.path.exists(os.path.join(MK_PATH, cropformer_name)):
142
  print(f"Downloading {cropformer_name}...")
 
 
 
 
 
 
143
  os.makedirs(MK_PATH, exist_ok=True)
144
  dst = os.path.join(MK_PATH, cropformer_name)
145
+ _ensure_mask2former_weights(dst)
146
  print(f"Downloaded {cropformer_name}...")
147
  else:
148
  print(f"{cropformer_name} already exists...")
 
238
 
239
  # Scale alignment: scale = median(Depths_VGGT / Depths_Metric3D)
240
  # We need to make sure we use valid depths (e.g. > 0) to avoid numerical issues
241
+ vggt_depth = predictions["depth"][0] # (B, H, W, 1) or similar
242
  metric_depth = predictions["metric3d_depth"] # (B, 1, H, W) presumably
243
 
244
  # Ensure shapes match for broadcasting or direct division
 
267
  valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)
268
 
269
  if valid_mask.sum() > 0:
270
+ print(f"Valid mask shape: {valid_mask.shape}")
271
+ print(f"Metric depth shape: {metric_depth.shape}")
272
+ print(f"VGGT depth shape: {vggt_depth.shape}")
273
  ratio = metric_depth[valid_mask] / vggt_depth[valid_mask]
274
  scale_factor = torch.median(ratio)
275
  print(f"Computed scale factor (VGGT / Metric3D): {scale_factor.item():.4f}")