|
|
import tqdm |
|
|
import torch |
|
|
from dust3r.utils.device import to_cpu, collate_with_cat |
|
|
from dust3r.utils.misc import invalid_to_nans |
|
|
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf |
|
|
from dust3r.model import ARCroco3DStereo |
|
|
from accelerate import Accelerator |
|
|
import re |
|
|
|
|
|
|
|
|
def custom_sort_key(key): |
|
|
text = key.split("/") |
|
|
if len(text) > 1: |
|
|
text, num = text[0], text[-1] |
|
|
return (text, int(num)) |
|
|
else: |
|
|
return (key, -1) |
|
|
|
|
|
|
|
|
def merge_chunk_dict(old_dict, curr_dict, add_number):
    """Merge a per-chunk loss-detail dict into the accumulated one.

    Keys in ``curr_dict`` that end with an integer (e.g. "loss_0") have that
    trailing integer shifted by ``add_number`` so per-chunk indices become
    global view indices; keys without a numeric suffix are kept as-is.
    Entries from ``curr_dict`` override same-named entries of ``old_dict``.

    Returns a new dict sorted with :func:`custom_sort_key`; inputs are not
    mutated.
    """
    shifted = {}
    for key, value in curr_dict.items():
        match = re.search(r"(\d+)$", key)
        if match:
            num_part = int(match.group()) + add_number
            # fix: pass count as a keyword — the positional form is
            # deprecated since Python 3.13
            new_key = re.sub(r"(\d+)$", str(num_part), key, count=1)
            shifted[new_key] = value
        else:
            shifted[key] = value
    merged = old_dict | shifted
    return {k: merged[k] for k in sorted(merged.keys(), key=custom_sort_key)}
|
|
|
|
|
|
|
|
def _interleave_imgs(img1, img2): |
|
|
res = {} |
|
|
for key, value1 in img1.items(): |
|
|
value2 = img2[key] |
|
|
if isinstance(value1, torch.Tensor): |
|
|
value = torch.stack((value1, value2), dim=1).flatten(0, 1) |
|
|
else: |
|
|
value = [x for pair in zip(value1, value2) for x in pair] |
|
|
res[key] = value |
|
|
return res |
|
|
|
|
|
|
|
|
def make_batch_symmetric(batch):
    """Duplicate a 2-view batch so every pair also appears in swapped order.

    Returns (view1', view2') where each view is the interleaving of the two
    inputs — view2' uses the reversed order, yielding symmetric pairs.
    """
    first, second = batch
    return _interleave_imgs(first, second), _interleave_imgs(second, first)
|
|
|
|
|
|
|
|
def loss_of_one_batch(
    batch,
    model,
    criterion,
    accelerator: Accelerator,
    symmetrize_batch=False,
    use_amp=False,
    ret=None,
    img_mask=None,
    inference=False,
):
    """Run the model on one batch and (in training mode) compute the loss.

    In inference mode, returns a tuple ``(result, state_args)`` where
    ``result`` is ``dict(views=..., pred=...)`` (or just ``result[ret]`` when
    ``ret`` is given) and ``state_args`` is the model's recurrent state.
    In training mode, returns ``dict(views=..., pred=..., loss=...)`` (or
    ``result[ret]``).

    NOTE(review): ``use_amp``, ``img_mask`` and ``accelerator`` are accepted
    but never read in this body — autocast is driven by ``inference`` alone.
    """
    if len(batch) > 2:
        assert (
            symmetrize_batch is False
        ), "cannot symmetrize batch with more than 2 views"
    if symmetrize_batch:
        # duplicate the two views in swapped order so pairs are symmetric
        batch = make_batch_symmetric(batch)

    # autocast is ON for training and OFF for inference
    with torch.cuda.amp.autocast(enabled=not inference):
        if inference:
            output, state_args = model(batch, ret_state=True)
            preds, batch = output.ress, output.views
            result = dict(views=batch, pred=preds)
            # parses as ((result[ret] if ret else result), state_args):
            # callers unpack two values
            return result[ret] if ret else result, state_args
        else:
            output = model(batch)
            preds, batch = output.ress, output.views

            # loss is always evaluated in full precision
            with torch.cuda.amp.autocast(enabled=False):
                loss = criterion(batch, preds) if criterion is not None else None

    result = dict(views=batch, pred=preds, loss=loss)
    return result[ret] if ret else result
|
|
|
|
|
|
|
|
def loss_of_one_batch_tbptt(
    batch,
    model,
    criterion,
    chunk_size,
    loss_scaler,
    optimizer,
    accelerator: Accelerator,
    log_writer=None,
    symmetrize_batch=False,
    use_amp=False,
    ret=None,
    img_mask=None,
    inference=False,
):
    """Truncated backprop-through-time (TBPTT) training step over a view sequence.

    The batch is processed in chunks of ``chunk_size`` views. All chunks
    except the last four run under ``torch.no_grad()`` (forward only); each
    of the last four chunks runs with gradients and is immediately
    backpropagated via ``loss_scaler`` + ``optimizer``. Recurrent state is
    detached at every chunk boundary, truncating the gradient flow.

    Returns ``dict(views=..., pred=..., loss=(mean_chunk_loss, details),
    already_backprop=True)`` (or ``result[ret]`` when ``ret`` is given).

    NOTE(review): ``log_writer``, ``use_amp``, ``img_mask`` and ``inference``
    are accepted but not read here (``inference`` only toggles autocast).
    NOTE(review): if ``criterion`` is None the tuple unpacking of the loss
    would raise — a criterion is effectively required.
    """
    if len(batch) > 2:
        assert (
            symmetrize_batch is False
        ), "cannot symmetrize batch with more than 2 views"
    if symmetrize_batch:
        batch = make_batch_symmetric(batch)
    all_preds = []
    all_loss = 0.0
    all_loss_details = {}
    with torch.cuda.amp.autocast(enabled=not inference):
        # encode all views once, without gradients; the encoder is frozen
        # for this step (everything below is detached)
        with torch.no_grad():
            (feat, pos, shape), (
                init_state_feat,
                init_mem,
                state_feat,
                state_pos,
                mem,
            ) = accelerator.unwrap_model(model)._forward_encoder(batch)
        feat = [f.detach() for f in feat]
        pos = [p.detach() for p in pos]
        shape = [s.detach() for s in shape]
        init_state_feat = init_state_feat.detach()
        init_mem = init_mem.detach()

        # number of chunks = ceil(len(batch) / chunk_size)
        for chunk_id in range((len(batch) - 1) // chunk_size + 1):
            preds = []
            chunk = []
            # truncate gradients at every chunk boundary (the T in TBPTT)
            state_feat = state_feat.detach()
            state_pos = state_pos.detach()
            mem = mem.detach()
            if chunk_id < ((len(batch) - 1) // chunk_size + 1) - 4:
                # early chunk: forward only, track the loss for logging
                with torch.no_grad():
                    for in_chunk_idx in range(chunk_size):
                        i = chunk_id * chunk_size + in_chunk_idx
                        if i >= len(batch):
                            break
                        res, (state_feat, mem) = accelerator.unwrap_model(
                            model
                        )._forward_decoder_step(
                            batch,
                            i,
                            feat_i=feat[i],
                            pos_i=pos[i],
                            shape_i=shape[i],
                            init_state_feat=init_state_feat,
                            init_mem=init_mem,
                            state_feat=state_feat,
                            state_pos=state_pos,
                            mem=mem,
                        )
                        preds.append(res)
                        all_preds.append({k: v.detach() for k, v in res.items()})
                        chunk.append(batch[i])
                    # loss in full precision
                    with torch.cuda.amp.autocast(enabled=False):
                        loss, loss_details = (
                            criterion(chunk, preds, camera1=batch[0]["camera_pose"])
                            if criterion is not None
                            else None
                        )
                    all_loss += float(loss)
                    # shift per-chunk detail indices to global view indices
                    all_loss_details = merge_chunk_dict(
                        all_loss_details, loss_details, chunk_id * chunk_size
                    )
                    del loss
            else:
                # one of the last 4 chunks: forward with grad, then backprop
                # immediately for this chunk only
                for in_chunk_idx in range(chunk_size):
                    i = chunk_id * chunk_size + in_chunk_idx
                    if i >= len(batch):
                        break
                    res, (state_feat, mem) = accelerator.unwrap_model(
                        model
                    )._forward_decoder_step(
                        batch,
                        i,
                        feat_i=feat[i],
                        pos_i=pos[i],
                        shape_i=shape[i],
                        init_state_feat=init_state_feat,
                        init_mem=init_mem,
                        state_feat=state_feat,
                        state_pos=state_pos,
                        mem=mem,
                    )
                    preds.append(res)
                    all_preds.append({k: v.detach() for k, v in res.items()})
                    chunk.append(batch[i])
                with torch.cuda.amp.autocast(enabled=False):
                    loss, loss_details = (
                        criterion(chunk, preds, camera1=batch[0]["camera_pose"])
                        if criterion is not None
                        else None
                    )
                all_loss += float(loss)
                all_loss_details = merge_chunk_dict(
                    all_loss_details, loss_details, chunk_id * chunk_size
                )
                # backprop this chunk's loss with grad scaling + clipping
                loss_scaler(
                    loss,
                    optimizer,
                    parameters=model.parameters(),
                    update_grad=True,
                    clip_grad=1.0,
                )
                optimizer.zero_grad()
                del loss
    result = dict(
        views=batch,
        pred=all_preds,
        # mean loss over all chunks, paired with the merged details
        loss=(all_loss / ((len(batch) - 1) // chunk_size + 1), all_loss_details),
        already_backprop=True,
    )
    return result[ret] if ret else result
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference(groups, model, device, verbose=True):
    """Move all views to `device`, run the model once, and return CPU results.

    Returns ``(result, state_args)`` where result is dict(views=..., pred=...)
    moved to CPU, and state_args is the model's recurrent state.
    """
    # metadata entries that must not be moved to the device
    skip = {"depthmap", "dataset", "label", "instance", "idx", "true_shape", "rng"}
    for view in groups:
        for name, value in view.items():
            if name in skip:
                continue
            if isinstance(value, (tuple, list)):
                view[name] = [x.to(device, non_blocking=True) for x in value]
            else:
                view[name] = value.to(device, non_blocking=True)

    if verbose:
        print(f">> Inference with model on {len(groups)} image/raymaps")

    res, state_args = loss_of_one_batch(groups, model, None, None, inference=True)
    return to_cpu(res), state_args
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference_step(view, state_args, model, device, verbose=True):
    """Run one incremental inference step for a single view.

    Moves the view's tensors to `device`, feeds them together with the
    carried recurrent state to ``model.inference_step``, and returns the
    prediction moved to CPU as ``dict(pred=...)``.
    """
    # metadata entries that must not be moved to the device
    skip = {"depthmap", "dataset", "label", "instance", "idx", "true_shape", "rng"}
    for name, value in view.items():
        if name in skip:
            continue
        if isinstance(value, (tuple, list)):
            view[name] = [x.to(device, non_blocking=True) for x in value]
        else:
            view[name] = value.to(device, non_blocking=True)

    # step runs in full precision
    with torch.cuda.amp.autocast(enabled=False):
        state_feat, state_pos, init_state_feat, mem, init_mem = state_args
        pred, _ = model.inference_step(
            view, state_feat, state_pos, init_state_feat, mem, init_mem
        )

    return to_cpu(dict(pred=pred))
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference_recurrent(groups, model, device, verbose=True):
    """Run recurrent inference over a list of views.

    Moves every view's tensors to `device`, calls
    ``model.forward_recurrent``, and returns the CPU result
    ``dict(views=..., pred=...)`` together with the final recurrent state.
    """
    # metadata entries that must not be moved to the device
    skip = {"depthmap", "dataset", "label", "instance", "idx", "true_shape", "rng"}
    for view in groups:
        for name, value in view.items():
            if name in skip:
                continue
            if isinstance(value, (tuple, list)):
                view[name] = [x.to(device, non_blocking=True) for x in value]
            else:
                view[name] = value.to(device, non_blocking=True)

    if verbose:
        print(f">> Inference with model on {len(groups)} image/raymaps")

    # full-precision forward pass
    with torch.cuda.amp.autocast(enabled=False):
        preds, batch, state_args = model.forward_recurrent(
            groups, device, ret_state=True
        )

    return to_cpu(dict(views=batch, pred=preds)), state_args
|
|
|
|
|
|
|
|
def check_if_same_size(pairs):
    """True iff all first views share one H×W and all second views share one H×W.

    An empty `pairs` list trivially returns True.
    """
    first_sizes = {tuple(a["img"].shape[-2:]) for a, _ in pairs}
    second_sizes = {tuple(b["img"].shape[-2:]) for _, b in pairs}
    return len(first_sizes) <= 1 and len(second_sizes) <= 1
|
|
|
|
|
|
|
|
def get_pred_pts3d(gt, pred, use_pose=False, inplace=False):
    """Extract (or reconstruct) predicted 3D points from a prediction dict.

    Resolution order:
      1. "depth" + "pseudo_focal" -> unproject via depthmap_to_pts3d
      2. "pts3d"                  -> taken as-is (camera frame)
      3. "pts3d_in_other_view"    -> returned directly (requires use_pose=True)

    For cases 1-2, when ``use_pose`` is set the points are transformed with
    ``pred["camera_pose"]``. ``inplace`` only matters for case 3, where it
    skips the defensive clone.
    """
    if "depth" in pred and "pseudo_focal" in pred:
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            # no intrinsics available — let depthmap_to_pts3d use its default
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)
    elif "pts3d" in pred:
        pts3d = pred["pts3d"]
    elif "pts3d_in_other_view" in pred:
        # already expressed in the other view's frame: no pose transform here
        assert use_pose is True
        other_view_pts = pred["pts3d_in_other_view"]
        return other_view_pts if inplace else other_view_pts.clone()

    if use_pose:
        camera_pose = pred.get("camera_pose")
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d
|
|
|
|
|
|
|
|
def find_opt_scaling(
    gt_pts1,
    gt_pts2,
    pr_pts1,
    pr_pts2=None,
    fit_mode="weiszfeld_stop_grad",
    valid1=None,
    valid2=None,
):
    """Estimate a per-batch scale s minimizing the misfit between pr and s*gt.

    Inputs are (B, H, W, 3) point maps; invalid pixels (per `valid1`/`valid2`)
    are turned into NaNs and ignored by the nan-aware reductions. `fit_mode`
    selects the estimator ("avg", "median" or "weiszfeld" prefix) and an
    optional "stop_grad" suffix detaches the result. The returned scale is
    clipped to >= 1e-3.
    """
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # mask out invalid pixels as NaNs and collapse H,W into one point axis
    gt1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    gt2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
    pr1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = gt1 if gt_pts2 is None else torch.cat((gt1, gt2), dim=1)
    all_pr = pr1 if pr_pts2 is None else torch.cat((pr1, pr2), dim=1)

    # per-point inner products used by every estimator below
    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith("avg"):
        # closed-form least-squares scale
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith("median"):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith("weiszfeld"):
        # init with least-squares, then IRLS iterations with 1/distance weights
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        for _ in range(10):
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            w = dis.clip_(min=1e-8).reciprocal()
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f"bad {fit_mode=}")

    if fit_mode.endswith("stop_grad"):
        scaling = scaling.detach()

    return scaling.clip(min=1e-3)
|
|
|