chore: cleanup for HF Spaces (ignore pth)
Browse files- .gitignore +1 -1
- MaskClustering/graph/construction.py +164 -0
- MaskClustering/graph/iterative_clustering.py +44 -0
- MaskClustering/graph/node.py +49 -0
- mvp.py +45 -26
.gitignore
CHANGED
|
@@ -153,4 +153,4 @@ temp/
|
|
| 153 |
**/*.glb
|
| 154 |
**/*.bin
|
| 155 |
data/
|
| 156 |
-
**/*.pth
|
|
|
|
| 153 |
**/*.glb
|
| 154 |
**/*.bin
|
| 155 |
data/
|
| 156 |
+
**/*.pth
|
MaskClustering/graph/construction.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from utils.mask_backprojection import frame_backprojection
|
| 5 |
+
from graph.node import Node
|
| 6 |
+
|
| 7 |
+
def mask_graph_construction(args, scene_points, frame_list, dataset):
    """Build the mask graph for a scene.

    Pipeline:
      1. Build the point-in-mask matrix (speeds up the later view-consensus
         rate computation).
      2. For every mask, find the frames it is visible in and the masks that
         contain it, flagging undersegmented masks along the way.
      3. Wrap each surviving mask into a graph node.

    Returns:
        (nodes, observer_num_thresholds, mask_point_clouds, point_frame_matrix)
    """
    if args.debug:
        print('start building point in mask matrix')

    (boundary_points, point_in_mask_matrix, mask_point_clouds,
     point_frame_matrix, global_frame_mask_list) = build_point_in_mask_matrix(
        args, scene_points, frame_list, dataset)

    visible_frames, contained_masks, undersegment_mask_ids = process_masks(
        frame_list, global_frame_mask_list, point_in_mask_matrix,
        boundary_points, mask_point_clouds, args)

    observer_num_thresholds = get_observer_num_thresholds(visible_frames)

    nodes = init_nodes(global_frame_mask_list, visible_frames, contained_masks,
                       undersegment_mask_ids, mask_point_clouds)

    return nodes, observer_num_thresholds, mask_point_clouds, point_frame_matrix
|
| 21 |
+
|
| 22 |
+
def build_point_in_mask_matrix(args, scene_points, frame_list, dataset):
    """Precompute which mask (if any) covers each scene point in each frame.

    Trades space for time so the view-consensus computation can be vectorized.

    Returns:
        boundary_points: set of point ids covered by more than one mask within
            a single frame (mask-boundary points); these are excluded from the
            view-consensus computation.
        point_in_mask_matrix: (num_points, num_frames) uint16 matrix. Entry
            [i, j] == k means point i lies in mask k of frame j; 0 means the
            point is in no mask there (mask ids start from 1).
        mask_point_clouds: dict mapping '<frame_id>_<mask_id>' to the set of
            point ids belonging to that mask.
        point_frame_matrix: (num_points, num_frames) bool visibility matrix.
        global_frame_mask_list: list of (frame_id, mask_id) tuples for every
            mask in the whole sequence.
    """
    num_points = len(scene_points)
    num_frames = len(frame_list)

    scene_points = torch.tensor(scene_points).float().cuda()
    boundary_points = set()
    point_in_mask_matrix = np.zeros((num_points, num_frames), dtype=np.uint16)
    point_frame_matrix = np.zeros((num_points, num_frames), dtype=bool)
    global_frame_mask_list = []
    mask_point_clouds = {}

    frames = enumerate(frame_list)
    if args.debug:
        frames = tqdm(frames, total=num_frames)

    for frame_idx, frame_id in frames:
        mask_dict, visible_point_ids = frame_backprojection(dataset, scene_points, frame_id)
        if len(visible_point_ids) == 0:
            continue
        point_frame_matrix[visible_point_ids, frame_idx] = True

        claimed_points = set()   # points already assigned to an earlier mask in this frame
        frame_boundary = set()   # points claimed by more than one mask
        for mask_id, mask_points in mask_dict.items():
            frame_boundary.update(mask_points.intersection(claimed_points))
            mask_point_clouds[f'{frame_id}_{mask_id}'] = mask_points
            point_in_mask_matrix[list(mask_points), frame_idx] = mask_id
            claimed_points.update(mask_points)
            global_frame_mask_list.append((frame_id, mask_id))

        # Boundary points are ambiguous; drop their mask assignment for this frame.
        point_in_mask_matrix[list(frame_boundary), frame_idx] = 0
        boundary_points.update(frame_boundary)

    return boundary_points, point_in_mask_matrix, mask_point_clouds, point_frame_matrix, global_frame_mask_list
|
| 62 |
+
|
| 63 |
+
def init_nodes(global_frame_mask_list, mask_project_on_all_frames, contained_masks, undersegment_mask_ids, mask_point_clouds):
    """Create one leaf Node per mask, skipping masks flagged as undersegmented."""
    nodes = []
    for global_mask_id, (frame_id, mask_id) in enumerate(global_frame_mask_list):
        if global_mask_id in undersegment_mask_ids:
            continue
        node = Node(
            [(frame_id, mask_id)],                       # singleton cluster
            mask_project_on_all_frames[global_mask_id],  # frames where the mask is visible
            contained_masks[global_mask_id],             # masks that contain this mask
            mask_point_clouds[f'{frame_id}_{mask_id}'],  # 3D point ids of the mask
            (0, len(nodes)),                             # (iteration 0, node index) debug info
            None,                                        # leaf node: no merged sons
        )
        nodes.append(node)
    return nodes
|
| 76 |
+
|
| 77 |
+
def get_observer_num_thresholds(visible_frames):
    """Compute the observer-number threshold schedule for iterative clustering.

    Thresholds are the percentiles (95% down to 0%, step 5) of the pairwise
    observer counts, i.e. the number of frames in which two masks co-occur.

    Args:
        visible_frames: (mask_num, frame_num) 0/1 float tensor.

    Returns:
        List of observer-count thresholds, one per clustering iteration.
        Empty when there are no co-observations at all.
    """
    observer_num_matrix = torch.matmul(visible_frames, visible_frames.transpose(0, 1))
    observer_num_list = observer_num_matrix.flatten()
    observer_num_list = observer_num_list[observer_num_list > 0].cpu().numpy()
    # Guard: np.percentile raises on an empty array (e.g. when there are no
    # masks at all); an empty schedule simply means no clustering iterations.
    if observer_num_list.size == 0:
        return []
    observer_num_thresholds = []
    for percentile in range(95, -5, -5):
        observer_num = np.percentile(observer_num_list, percentile)
        if observer_num <= 1:
            # Once the threshold degenerates to <=1, stop for the lower half of
            # the schedule; for the upper half clamp to 1 to keep iterating.
            if percentile < 50:
                break
            else:
                observer_num = 1
        observer_num_thresholds.append(observer_num)
    return observer_num_thresholds
|
| 94 |
+
|
| 95 |
+
def process_one_mask(point_in_mask_matrix, boundary_points, mask_point_cloud, frame_list, global_frame_mask_list, args):
    """For one mask, find the frames it is visible in and the masks containing it.

    Returns:
        (valid, visible_frame, contained_mask). `valid` is False when the mask
        looks undersegmented, i.e. it is split across several masks in too
        large a fraction of the frames where it is visible.
    """
    visible_frame = torch.zeros(len(frame_list))
    contained_mask = torch.zeros(len(global_frame_mask_list))

    # Drop ambiguous boundary points before counting.
    core_points = mask_point_cloud - boundary_points
    per_frame_mask_ids = point_in_mask_matrix[list(core_points), :]

    # Frames in which at least one core point carries a nonzero mask id.
    candidate_frames = np.where(np.sum(per_frame_mask_ids, axis=0) > 0)[0]

    split_num = 0
    visible_num = 0
    for frame_idx in candidate_frames:
        counts = np.bincount(per_frame_mask_ids[:, frame_idx])
        total = np.sum(counts)
        invisible_ratio = counts[0] / total  # id 0 marks points invisible in this frame
        # Skip frames where most of the mask is missing (unless many points remain).
        if 1 - invisible_ratio < args.mask_visible_threshold and (total - counts[0]) < 500:
            continue
        visible_num += 1
        counts[0] = 0
        dominant_mask_id = np.argmax(counts)
        contained_ratio = counts[dominant_mask_id] / np.sum(counts)
        if contained_ratio > args.contained_threshold:
            visible_frame[frame_idx] = 1
            owner_idx = global_frame_mask_list.index((frame_list[frame_idx], dominant_mask_id))
            contained_mask[owner_idx] = 1
        else:
            split_num += 1  # the mask is split into several masks in this frame

    if visible_num == 0 or split_num / visible_num > args.undersegment_filter_threshold:
        return False, visible_frame, contained_mask
    else:
        return True, visible_frame, contained_mask
|
| 131 |
+
|
| 132 |
+
def process_masks(frame_list, global_frame_mask_list, point_in_mask_matrix, boundary_points, mask_point_clouds, args):
    """For every mask, compute visibility/containment and flag undersegmented masks.

    Returns:
        visible_frames: (mask_num, frame_num) cuda tensor.
        contained_masks: (mask_num, mask_num) cuda tensor.
        undersegment_mask_ids: global ids of masks judged undersegmented.
    """
    if args.debug:
        print('start processing masks')
    visible_frames = []
    contained_masks = []
    undersegment_mask_ids = []

    masks = tqdm(global_frame_mask_list) if args.debug else global_frame_mask_list
    for frame_id, mask_id in masks:
        valid, visible_frame, contained_mask = process_one_mask(
            point_in_mask_matrix, boundary_points,
            mask_point_clouds[f'{frame_id}_{mask_id}'],
            frame_list, global_frame_mask_list, args)
        visible_frames.append(visible_frame)
        contained_masks.append(contained_mask)
        if not valid:
            undersegment_mask_ids.append(global_frame_mask_list.index((frame_id, mask_id)))

    if not visible_frames:
        # No masks at all: return empty tensors with the correct second dimension.
        return (torch.zeros(0, len(frame_list)).cuda(),
                torch.zeros(0, len(global_frame_mask_list)).cuda(),
                undersegment_mask_ids)

    visible_frames = torch.stack(visible_frames, dim=0).cuda()    # (mask_num, frame_num)
    contained_masks = torch.stack(contained_masks, dim=0).cuda()  # (mask_num, mask_num)

    # Undo the effect of undersegmented observer masks so two separate objects
    # spanned by such a mask are not merged through it.
    for global_mask_id in undersegment_mask_ids:
        frame_id, _ = global_frame_mask_list[global_mask_id]
        frame_idx = frame_list.index(frame_id)
        supporters = torch.where(contained_masks[:, global_mask_id])[0]
        contained_masks[:, global_mask_id] = False
        visible_frames[supporters, frame_idx] = False

    return visible_frames, contained_masks, undersegment_mask_ids
|
MaskClustering/graph/iterative_clustering.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
from graph.node import Node
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
def cluster_into_new_nodes(iteration, old_nodes, graph):
    """Merge each connected component of `graph` into a single new Node."""
    new_nodes = []
    for component in nx.connected_components(graph):
        members = [old_nodes[idx] for idx in component]
        new_nodes.append(Node.create_node_from_list(members, (iteration, len(new_nodes))))
    return new_nodes
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def update_graph(nodes, observer_num_threshold, connect_threshold):
    """Recompute view-consensus rates between nodes and return the new graph.

    Two nodes are connected when their view-consensus rate reaches
    `connect_threshold` and they share at least `observer_num_threshold`
    observing frames.
    """
    node_visible_frames = torch.stack([node.visible_frame for node in nodes], dim=0)
    node_contained_masks = torch.stack([node.contained_mask for node in nodes], dim=0)

    # observer_nums[i, j]: number of frames in which nodes i and j both appear.
    observer_nums = torch.matmul(node_visible_frames, node_visible_frames.transpose(0, 1))
    # supporter_nums[i, j]: number of frames supporting the merge of i and j.
    supporter_nums = torch.matmul(node_contained_masks, node_contained_masks.transpose(0, 1))

    view_concensus_rate = supporter_nums / (observer_nums + 1e-7)

    # Never connect a node to itself, nor pairs with too few shared observers.
    disconnect = torch.eye(len(nodes), dtype=bool).cuda() | (observer_nums < observer_num_threshold)

    adjacency = (view_concensus_rate >= connect_threshold) & ~disconnect
    return nx.from_numpy_array(adjacency.cpu().numpy())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def iterative_clustering(nodes, observer_num_thresholds, connect_threshold, debug):
    """Repeatedly rebuild the consensus graph and merge its connected components.

    One pass per threshold: later iterations require fewer shared observers,
    so clusters grow progressively coarser.
    """
    if debug:
        print('====> Start iterative clustering')
    for iterate_id, observer_num_threshold in enumerate(observer_num_thresholds):
        if debug:
            print(f'Iterate {iterate_id}: observer_num', observer_num_threshold, ', number of nodes', len(nodes))
        consensus_graph = update_graph(nodes, observer_num_threshold, connect_threshold)
        nodes = cluster_into_new_nodes(iterate_id + 1, nodes, consensus_graph)
    return nodes
|
MaskClustering/graph/node.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import open3d as o3d
|
| 3 |
+
|
| 4 |
+
class Node:
    """A cluster of 2D masks treated as one candidate 3D object."""

    def __init__(self, mask_list, visible_frame, contained_mask, point_ids, node_info, son_node_info):
        """
        Args:
            mask_list: (frame_id, mask_id) pairs belonging to this cluster.
            visible_frame: one-hot vector; 1 where the node appears in a frame.
            contained_mask: one-hot vector; 1 where the node is contained by a mask.
            point_ids: ids of the node's 3D points.
            node_info: (iteration, index-in-iteration) pair, for debugging.
            son_node_info: node_info values merged into this node, for debugging.
        """
        self.mask_list = mask_list
        self.visible_frame = visible_frame
        self.contained_mask = contained_mask
        self.point_ids = point_ids
        self.node_info = node_info
        self.son_node_info = son_node_info

    @staticmethod
    def create_node_from_list(node_list, node_info):
        """Merge several nodes into one: union of masks, frames, containers, points."""
        merged_masks = []
        merged_points = set()
        son_infos = set()
        visible = torch.zeros(len(node_list[0].visible_frame), dtype=bool).cuda()
        contained = torch.zeros(len(node_list[0].contained_mask), dtype=bool).cuda()
        for member in node_list:
            merged_masks += member.mask_list
            visible = visible | member.visible_frame.bool()
            contained = contained | member.contained_mask.bool()
            merged_points = merged_points.union(member.point_ids)
            son_infos.add(member.node_info)
        return Node(merged_masks, visible.float(), contained.float(), merged_points, node_info, son_infos)

    def get_point_cloud(self, scene_points):
        """
        Return:
            pcld: open3d.geometry.PointCloud with this node's points.
            point_ids: list of int, the node's 3D point ids (same order as pcld).
        """
        point_ids = list(self.point_ids)
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(scene_points[point_ids])
        return cloud, point_ids
|
mvp.py
CHANGED
|
@@ -52,23 +52,45 @@ _METRIC3D_MODEL = None
|
|
| 52 |
_CLIP_MODEL = None
|
| 53 |
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
"""
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
"""
|
| 61 |
if os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
|
| 62 |
return dst_path
|
| 63 |
|
| 64 |
-
if gdown is None:
|
| 65 |
-
raise RuntimeError("Не найден пакет gdown. Добавь gdown в requirements.txt для загрузки весов из Google Drive.")
|
| 66 |
-
|
| 67 |
os.makedirs(os.path.dirname(dst_path), exist_ok=True)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
return dst_path
|
| 73 |
|
| 74 |
|
|
@@ -83,13 +105,13 @@ def _init_models():
|
|
| 83 |
|
| 84 |
if _VGGT_MODEL is None:
|
| 85 |
print("Initializing and loading VGGT model...")
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
m.eval()
|
| 94 |
_VGGT_MODEL = m.to(device)
|
| 95 |
|
|
@@ -118,15 +140,9 @@ cropformer_name = "Mask2Former_hornet_3x_576d0b.pth"
|
|
| 118 |
def check_weights():
|
| 119 |
if not os.path.exists(os.path.join(MK_PATH, cropformer_name)):
|
| 120 |
print(f"Downloading {cropformer_name}...")
|
| 121 |
-
# Prefer HF cache over `wget` for Spaces compatibility.
|
| 122 |
-
cached = hf_hub_download(
|
| 123 |
-
repo_id="qqlu1992/Adobe_EntitySeg",
|
| 124 |
-
repo_type="dataset",
|
| 125 |
-
filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
|
| 126 |
-
)
|
| 127 |
os.makedirs(MK_PATH, exist_ok=True)
|
| 128 |
dst = os.path.join(MK_PATH, cropformer_name)
|
| 129 |
-
|
| 130 |
print(f"Downloaded {cropformer_name}...")
|
| 131 |
else:
|
| 132 |
print(f"{cropformer_name} already exists...")
|
|
@@ -222,7 +238,7 @@ def run_model(target_dir, model, metric3d_model=None) -> dict:
|
|
| 222 |
|
| 223 |
# Scale alignment: scale = median(Depths_VGGT / Depths_Metric3D)
|
| 224 |
# We need to make sure we use valid depths (e.g. > 0) to avoid numerical issues
|
| 225 |
-
vggt_depth = predictions["depth"] # (B, H, W, 1) or similar
|
| 226 |
metric_depth = predictions["metric3d_depth"] # (B, 1, H, W) presumably
|
| 227 |
|
| 228 |
# Ensure shapes match for broadcasting or direct division
|
|
@@ -251,6 +267,9 @@ def run_model(target_dir, model, metric3d_model=None) -> dict:
|
|
| 251 |
valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)
|
| 252 |
|
| 253 |
if valid_mask.sum() > 0:
|
|
|
|
|
|
|
|
|
|
| 254 |
ratio = metric_depth[valid_mask] / vggt_depth[valid_mask]
|
| 255 |
scale_factor = torch.median(ratio)
|
| 256 |
print(f"Computed scale factor (VGGT / Metric3D): {scale_factor.item():.4f}")
|
|
|
|
| 52 |
_CLIP_MODEL = None
|
| 53 |
|
| 54 |
|
| 55 |
+
# Google Drive file id of the CropFormer checkpoint (user-provided mirror).
_MASK2FORMER_GDRIVE_FILE_ID = "10G7s6bVMwN__bcrR2fBal3goo69Y5Do4"


def _ensure_mask2former_weights(dst_path: str) -> str:
    """
    Ensure Mask2Former/CropFormer weights exist at dst_path.
    Priority:
    1) Use existing file (if present)
    2) Download from Google Drive (user-provided link / file id)
    3) Fallback: download from HF dataset (qqlu1992/Adobe_EntitySeg)

    Returns dst_path once a non-empty weights file is in place.
    """
    # Reuse an already-downloaded, non-empty file.
    if os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
        return dst_path

    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # Allow user override via local path
    # NOTE(review): this env-var override runs before the Google Drive step,
    # i.e. ahead of priority 2 listed in the docstring.
    override_path = os.environ.get("MASK2FORMER_WEIGHTS_PATH")
    if override_path and os.path.exists(override_path) and os.path.getsize(override_path) > 0:
        shutil.copyfile(override_path, dst_path)
        return dst_path

    # 2) Google Drive
    if gdown is not None:
        url = f"https://drive.google.com/uc?id={_MASK2FORMER_GDRIVE_FILE_ID}"
        out = gdown.download(url, dst_path, quiet=False)
        # gdown returns None on failure; also verify a non-empty file landed.
        if out is not None and os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
            return dst_path
        print("Warning: gdown download failed for Mask2Former weights; falling back to HF dataset...")
    else:
        print("Warning: gdown is not available; falling back to HF dataset for Mask2Former weights...")

    # 3) HF fallback
    # hf_hub_download returns a path inside the HF cache; copy it to dst_path.
    cached = hf_hub_download(
        repo_id="qqlu1992/Adobe_EntitySeg",
        repo_type="dataset",
        filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
    )
    shutil.copyfile(cached, dst_path)
    return dst_path
|
| 95 |
|
| 96 |
|
|
|
|
| 105 |
|
| 106 |
if _VGGT_MODEL is None:
|
| 107 |
print("Initializing and loading VGGT model...")
|
| 108 |
+
# Prefer Hugging Face weights for VGGT
|
| 109 |
+
try:
|
| 110 |
+
m = VGGT.from_pretrained("facebook/VGGT-1B")
|
| 111 |
+
except Exception:
|
| 112 |
+
m = VGGT()
|
| 113 |
+
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
|
| 114 |
+
m.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
|
| 115 |
m.eval()
|
| 116 |
_VGGT_MODEL = m.to(device)
|
| 117 |
|
|
|
|
| 140 |
def check_weights():
    """Download the CropFormer checkpoint into MK_PATH if it is not already there."""
    if not os.path.exists(os.path.join(MK_PATH, cropformer_name)):
        print(f"Downloading {cropformer_name}...")
        os.makedirs(MK_PATH, exist_ok=True)
        dst = os.path.join(MK_PATH, cropformer_name)
        # Delegates to the local-override / Google Drive / HF fallback chain.
        _ensure_mask2former_weights(dst)
        print(f"Downloaded {cropformer_name}...")
    else:
        print(f"{cropformer_name} already exists...")
|
|
|
|
| 238 |
|
| 239 |
# Scale alignment: scale = median(Depths_VGGT / Depths_Metric3D)
|
| 240 |
# We need to make sure we use valid depths (e.g. > 0) to avoid numerical issues
|
| 241 |
+
vggt_depth = predictions["depth"][0] # (B, H, W, 1) or similar
|
| 242 |
metric_depth = predictions["metric3d_depth"] # (B, 1, H, W) presumably
|
| 243 |
|
| 244 |
# Ensure shapes match for broadcasting or direct division
|
|
|
|
| 267 |
valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)
|
| 268 |
|
| 269 |
if valid_mask.sum() > 0:
|
| 270 |
+
print(f"Valid mask shape: {valid_mask.shape}")
|
| 271 |
+
print(f"Metric depth shape: {metric_depth.shape}")
|
| 272 |
+
print(f"VGGT depth shape: {vggt_depth.shape}")
|
| 273 |
ratio = metric_depth[valid_mask] / vggt_depth[valid_mask]
|
| 274 |
scale_factor = torch.median(ratio)
|
| 275 |
print(f"Computed scale factor (VGGT / Metric3D): {scale_factor.item():.4f}")
|