Spaces:
Sleeping
Sleeping
| """ | |
| Debug script mimicking the learned optimizer training loop with stability loss. | |
| Simulates the meta-training loop with inner iterations and the stability loss. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
# ─────────────────────────────────────────────
# Minimal stubs mirroring your real classes
# ─────────────────────────────────────────────
@dataclass
class RenderOutput:
    """Minimal stand-in for a render result; holds only the colour image.

    Declared as a dataclass so that it can be built with
    ``RenderOutput(color=...)``; a plain class with only an annotation has
    no such constructor and raises ``TypeError``.
    """

    color: torch.Tensor  # [B, V, C, H, W]
@dataclass
class OptimizerOutput:
    """Renders and view-index tensors collected over the inner iterations.

    Declared as a dataclass so that the keyword construction used by the
    training loop works; without the decorator the class accepts no
    constructor arguments.
    """

    # Render outputs, one entry per recorded step.
    context_render_list: List["RenderOutput"]
    target_render_list: List["RenderOutput"]
    # Per-step view indices ([B, V] tensors); an entry may be None.
    context_index_list: List[Optional[torch.Tensor]]
    target_index_list: List[Optional[torch.Tensor]]

    def get_render_list(self, input_str: str) -> List["RenderOutput"]:
        """Return the context renders for ``"context"``, else the target renders."""
        return self.context_render_list if input_str == "context" else self.target_render_list

    def get_index_list(self, input_str: str) -> List[torch.Tensor]:
        """Return the selected index list with every ``None`` entry dropped.

        ``"context"`` selects the context indices; any other value selects
        the target indices. The result is empty when no indices were recorded.
        """
        lst = self.context_index_list if input_str == "context" else self.target_index_list
        return [x for x in lst if x is not None]
# ─────────────────────────────────────────────
# Tiny "optimizer network" that produces renders
# across inner iterations — all connected in graph
# ─────────────────────────────────────────────
class TinyOptimizerNet(nn.Module):
    """Toy learned optimizer that refines a rendering over several inner steps.

    Each step computes ``render = prev_render + delta(prev_render)`` with a
    small 1x1-conv network, so the autograd graph chains across iterations.
    """

    def __init__(self, hidden=16):
        """Build the refinement network (3 -> ``hidden`` -> 3 channels, 1x1 convs)."""
        super().__init__()
        self.refine = nn.Sequential(
            nn.Conv2d(3, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, 3, 1),
        )

    def forward(self, init_render: torch.Tensor, num_inner: int) -> List[torch.Tensor]:
        """Refine ``init_render`` for ``num_inner`` steps.

        Args:
            init_render: tensor of shape [B, V, C, H, W]; it is detached
                before use, so no gradient flows back into it.
            num_inner: number of refinement steps to run.

        Returns:
            A list of ``num_inner + 1`` tensors shaped like ``init_render``:
            the detached initial render followed by one render per step.
        """
        curr = init_render.detach()  # the initial render carries no graph
        renders = [curr]
        B, V, C, H, W = curr.shape
        for _ in range(num_inner):
            # reshape (not view) so a non-contiguous input, e.g. a transposed
            # or sliced tensor, is accepted instead of raising.
            flat = curr.reshape(B * V, C, H, W)
            delta = self.refine(flat)
            curr = curr + delta.reshape(B, V, C, H, W)  # graph chains across iterations
            renders.append(curr)
        return renders
# ─────────────────────────────────────────────
# Stability loss (copied from your code)
# ─────────────────────────────────────────────
class LossStability(nn.Module):
    """Penalise inner iterations whose per-view L1 error goes up.

    For each of the "context" and "target" sets, the per-view L1 error of
    every recorded render is compared with the error that view had the last
    time it was rendered; only increases contribute (ReLU), and the previous
    error is detached so gradients flow through the later render only.
    """

    def forward(self, optimizer_output: "OptimizerOutput", batch: dict) -> torch.Tensor:
        """Return the summed stability penalty over both view sets.

        Args:
            optimizer_output: object exposing ``get_render_list(name)``
                (items with a ``.color`` tensor [B, V, C, H, W]) and
                ``get_index_list(name)`` (list of [B, V] index tensors,
                possibly empty).
            batch: mapping with ``batch[name]["image"]`` of shape
                [B, V_all, C, H, W] for ``name`` in {"context", "target"}.

        Returns:
            Scalar tensor; the per-set value is also printed for debugging.
        """
        total_loss = 0
        for input_str in ["context", "target"]:
            render_list = optimizer_output.get_render_list(input_str)
            index_list = optimizer_output.get_index_list(input_str)
            predictions = torch.stack(
                [render.color for render in render_list], dim=0
            )  # [I, B, V, C, H, W]
            gt = batch[input_str]["image"]  # [B, V_all, C, H, W]

            if len(index_list) == 0:
                change_in_loss = self._dense_change(predictions, gt)
            else:
                change_in_loss = self._indexed_change(predictions, gt, index_list)

            # Report this set's own contribution in both branches (the dense
            # branch used to print the running total instead).
            stability = change_in_loss.sum()
            total_loss = total_loss + stability
            print(f" Stability loss ({input_str}): {stability.item():.6f}")
        return total_loss

    @staticmethod
    def _dense_change(predictions: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Loss increase between consecutive iterations when every view is rendered."""
        loss = torch.abs(predictions - gt).mean(dim=[3, 4, 5])  # [I, B, V]
        return torch.relu(loss[1:] - loss[:-1].detach())  # [I-1, B, V]

    @staticmethod
    def _indexed_change(predictions: torch.Tensor, gt: torch.Tensor, index_list) -> torch.Tensor:
        """Loss increase of each view relative to its most recent earlier visit."""
        # The first render has no index entry of its own; reuse the first
        # iteration's indices for it so there are I index tensors.
        index_t = torch.stack([index_list[0], *index_list], dim=0)  # [I, B, V]
        num_iters = predictions.shape[0]
        num_batch, num_views_all = gt.shape[0], gt.shape[1]
        device = gt.device

        batch_idx = torch.arange(num_batch, device=device)[None, :, None]
        gt_indexed = gt[batch_idx, index_t]  # [I, B, V, C, H, W]
        loss = torch.abs(predictions - gt_indexed).mean(dim=[3, 4, 5])  # [I, B, V]

        # Scatter the sampled-view losses into a dense [I, B, V_all] table.
        # new_zeros keeps the dtype/device of `loss` so scatter never sees
        # mismatched dtypes.
        loss_full = loss.new_zeros(num_iters, num_batch, num_views_all).scatter(2, index_t, loss)

        # Track visits with an explicit mask. Inferring them from
        # `loss_full > 0` would treat a visited view with exactly zero error
        # as unvisited and skip the penalty for a later increase.
        visited = (
            torch.zeros(num_iters, num_batch, num_views_all, device=device)
            .scatter(2, index_t, value=1.0)
            > 0
        )

        # For every (iteration, batch, view): index of the latest strictly
        # earlier iteration that rendered this view, or -1 if there is none.
        iter_idx = torch.arange(num_iters, device=device).view(-1, 1, 1)
        visit_ids = torch.where(visited, iter_idx, torch.full_like(iter_idx, -1))
        last_visit = torch.cummax(visit_ids, dim=0).values
        prev_visit = torch.cat(
            [torch.full_like(last_visit[:1], -1), last_visit[:-1]], dim=0
        )
        has_prev = prev_visit >= 0

        # clamp only keeps gather in range; entries without a previous visit
        # are masked out below.
        prev_loss = loss_full.gather(0, prev_visit.clamp(min=0)).detach()
        change_in_loss = torch.relu(loss_full - prev_loss)
        return change_in_loss * (has_prev & visited).to(loss.dtype)
# ─────────────────────────────────────────────
# Other losses (L1 and a fake LPIPS)
# ─────────────────────────────────────────────
def compute_l1_loss(render_color, gt):
    """Return the mean absolute difference between ``render_color`` and ``gt``."""
    residual = render_color - gt
    return residual.abs().mean()
def compute_lpips_loss(render_color, gt):
    """Stand-in for LPIPS: the plain mean squared error of the two inputs.

    No perceptual network and no downsampling is involved here.
    """
    residual = render_color - gt
    return residual.pow(2).mean()
def compute_meta_losses(optimizer_output, batch, num_inner):
    """Accumulate L1 + 0.1 * fake-LPIPS over the refined renders (no stability term).

    Render index 0 (the initial render) is skipped; for each of the
    ``num_inner`` refined steps the context views are scored first, then the
    target views.
    """
    opt_loss = 0
    for step in range(1, num_inner + 1):
        for name in ("context", "target"):
            renders = getattr(optimizer_output, f"{name}_render_list")
            pred = renders[step].color
            gt = batch[name]["image"]
            opt_loss = opt_loss + compute_l1_loss(pred, gt)
            opt_loss = opt_loss + 0.1 * compute_lpips_loss(pred, gt)
    return opt_loss
# ─────────────────────────────────────────────
# Main training loop
# ─────────────────────────────────────────────
def make_batch(B=2, V=3, C=3, H=8, W=8, device="cpu"):
    """Build a random batch with "context" and "target" images of shape [B, V, C, H, W]."""
    image_shape = (B, V, C, H, W)
    # Context is drawn before target, so seeded runs stay reproducible.
    return {
        name: {"image": torch.rand(*image_shape, device=device)}
        for name in ("context", "target")
    }
def run_meta_iteration(net, batch, stability_loss_fn, num_inner=5, use_index_list=False):
    """Run one simulated meta-iteration and return its total loss.

    Random initial renders are refined by ``net`` for ``num_inner`` steps,
    wrapped into an ``OptimizerOutput``, and scored with the meta losses plus
    0.01 times the stability loss.
    """
    reference = batch["context"]["image"]
    B, V, C, H, W = reference.shape
    device = reference.device

    # Random initial renders; they do not require grad (torch.rand default).
    init_context = torch.rand(B, V, C, H, W, device=device)
    init_target = torch.rand(B, V, C, H, W, device=device)

    context_renders = net(init_context, num_inner)
    target_renders = net(init_target, num_inner)
    context_render_list = [RenderOutput(color=img) for img in context_renders]
    target_render_list = [RenderOutput(color=img) for img in target_renders]

    if use_index_list:
        # Partial view sampling: V // 2 (at least one) view ids per inner step.
        # NOTE(review): the renders above keep all V views while each index
        # tensor has only `views_per_step` columns; confirm the losses are
        # meant to be fed that combination.
        views_per_step = max(1, V // 2)

        def draw_indices():
            return [
                torch.randint(0, V, (B, views_per_step), device=device)
                for _ in range(num_inner)
            ]

        context_index_list = draw_indices()
        target_index_list = draw_indices()
    else:
        context_index_list = [None] * num_inner
        target_index_list = [None] * num_inner

    optimizer_output = OptimizerOutput(
        context_render_list=context_render_list,
        target_render_list=target_render_list,
        context_index_list=context_index_list,
        target_index_list=target_index_list,
    )

    # Meta losses (L1 + fake LPIPS across inner iterations) ...
    meta_loss = compute_meta_losses(optimizer_output, batch, num_inner)
    # ... plus the down-weighted stability loss.
    stab_loss = stability_loss_fn(optimizer_output, batch)
    return meta_loss + 0.01 * stab_loss
def main():
    """Run a few meta-training steps in both modes and report each step.

    The same network and Adam optimizer are used for the "no_index_list" and
    "with_index_list" modes. A ``RuntimeError`` raised during forward,
    backward or the optimizer step is printed and ends the current mode.
    """
    torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = TinyOptimizerNet(hidden=16).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    stability_loss_fn = LossStability()
    num_meta_steps = 5
    num_inner = 5  # refinement steps per meta-iteration (initial render excluded)

    for mode in ["no_index_list", "with_index_list"]:
        print(f"\n{'='*50}")
        print(f"Mode: {mode}")
        print(f"{'='*50}")
        use_index = (mode == "with_index_list")
        for step in range(num_meta_steps):
            batch = make_batch(device=device)
            optimizer.zero_grad()
            try:
                total_loss = run_meta_iteration(
                    net, batch, stability_loss_fn,
                    num_inner=num_inner,
                    use_index_list=use_index,
                )
                total_loss.backward()
                optimizer.step()
            except RuntimeError as e:
                print(f" Step {step+1}: ERROR - {e}")
                break
            else:
                # The success marker was a mis-decoded glyph in the source;
                # restored to a check mark.
                print(f" Step {step+1}: total_loss={total_loss.item():.6f} ✓")
| if __name__ == "__main__": | |
| main() |