File size: 10,150 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
Debug script mimicking the learned optimizer training loop with stability loss.
Simulates the meta-training loop with inner iterations and the stability loss.
"""
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import List, Optional


# ─────────────────────────────────────────────
# Minimal stubs mirroring your real classes
# ─────────────────────────────────────────────

@dataclass
class RenderOutput:
    """Minimal stub that carries the color tensor of a single render."""

    # Rendered colors, laid out as [B, V, C, H, W].
    color: torch.Tensor  # [B, V, C, H, W]


@dataclass
class OptimizerOutput:
    """Render and index lists for the "context" and "target" splits.

    Both accessors treat any ``input_str`` other than ``"context"`` as a
    request for the target split.
    """

    context_render_list: List[RenderOutput]
    target_render_list: List[RenderOutput]
    context_index_list: List[Optional[torch.Tensor]]  # list of [B, V] or empty
    target_index_list: List[Optional[torch.Tensor]]

    def get_render_list(self, input_str: str) -> List[RenderOutput]:
        """Return the stored render list of the requested split."""
        if input_str == "context":
            return self.context_render_list
        return self.target_render_list

    def get_index_list(self, input_str: str) -> List[torch.Tensor]:
        """Return the split's index entries with ``None`` placeholders removed."""
        if input_str == "context":
            candidates = self.context_index_list
        else:
            candidates = self.target_index_list
        # Identity check on purpose: entries may be tensors, whose truthiness
        # is not a valid "is present" test.
        return [entry for entry in candidates if entry is not None]


# ─────────────────────────────────────────────
# Tiny "optimizer network" that produces renders
# across inner iterations — all connected in graph
# ─────────────────────────────────────────────

class TinyOptimizerNet(nn.Module):
    """
    Simulates a learned optimizer that refines a rendering across I inner steps.
    Each step: render = prev_render + delta(prev_render, params)
    The refined renders stay chained in one autograd graph across iterations.
    """
    def __init__(self, hidden=16):
        """Build the 1x1-conv refinement network (3 -> hidden -> 3 channels)."""
        super().__init__()
        self.refine = nn.Sequential(
            nn.Conv2d(3, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, 3, 1),
        )

    def forward(self, init_render: torch.Tensor, num_inner: int):
        """
        init_render: [B, V, C, H, W] (C must be 3 to match the conv layers)
        num_inner: number of refinement steps to apply
        Returns list of renders of length num_inner+1 (init + refined).
        The first entry is detached from whatever produced init_render; every
        later entry depends on the network parameters.
        """
        curr = init_render.detach()  # init is cut off from any upstream graph
        renders = [curr]
        B, V, C, H, W = curr.shape
        for _ in range(num_inner):
            # reshape (not view): a non-contiguous init_render, e.g. a
            # transposed tensor, cannot be viewed as [B*V, C, H, W].
            flat = curr.reshape(B * V, C, H, W)
            delta = self.refine(flat)
            curr = curr + delta.reshape(B, V, C, H, W)  # connected graph across iters
            renders.append(curr)
        return renders


# ─────────────────────────────────────────────
# Stability loss (copied from your code)
# ─────────────────────────────────────────────

class LossStability(nn.Module):
    """Penalize per-view L1 error that increases from one inner iteration to the next.

    For each of the "context" and "target" splits the per-iteration L1 error
    against the ground-truth images is computed, and only positive changes
    (error going up) contribute to the loss. The earlier error is always
    detached, so gradients flow only through the later iteration.
    """

    def forward(self, optimizer_output: "OptimizerOutput", batch: dict) -> torch.Tensor:
        """Return the summed stability penalty over both splits.

        optimizer_output: object exposing ``get_render_list(split)`` (items with
            a ``.color`` tensor [B, V, C, H, W]) and ``get_index_list(split)``.
        batch: ``batch[split]["image"]`` holds the ground truth [B, V_all, C, H, W].

        Prints one line per split with that split's own contribution.
        """
        total_loss = 0
        for input_str in ["context", "target"]:
            render_list = optimizer_output.get_render_list(input_str)
            index_list = optimizer_output.get_index_list(input_str)

            predictions = torch.stack(
                [render.color for render in render_list], dim=0
            )  # [I, B, V, C, H, W]
            gt = batch[input_str]["image"]  # [B, V_all, C, H, W]

            if len(index_list) == 0:
                # Every view is rendered at every iteration: compare neighbours.
                loss = torch.abs(predictions - gt).mean(dim=[3, 4, 5])  # [I, B, V]
                change_in_loss = torch.relu(loss[1:] - loss[:-1].detach())  # [I-1, B, V]
            else:
                change_in_loss = self._indexed_change(predictions, gt, index_list)

            # Report this split's own contribution in both branches, not the
            # running total, so the two printed lines are comparable.
            split_loss = change_in_loss.sum()
            total_loss = total_loss + split_loss
            print(f"  Stability loss ({input_str}): {split_loss.item():.6f}")

        return total_loss

    @staticmethod
    def _indexed_change(predictions, gt, index_list):
        """Return the [I, B, V_all] error increase of each view since its previous visit."""
        # The init render has no index entry of its own; it reuses the indices
        # of the first inner iteration so that there are I index tensors.
        index_t = torch.stack([index_list[0], *index_list], dim=0)  # [I, B, V]

        num_iters = predictions.shape[0]
        num_batch, num_views_all = gt.shape[0], gt.shape[1]

        batch_idx = torch.arange(num_batch, device=gt.device)[None, :, None]
        gt_indexed = gt[batch_idx, index_t]  # [I, B, V, C, H, W]
        loss = torch.abs(predictions - gt_indexed).mean(dim=[3, 4, 5])  # [I, B, V]

        # Allocate with the loss's dtype/device: scatter needs matching dtypes,
        # which a default float32 buffer breaks for float64 or half inputs.
        # NOTE: if one iteration lists the same view twice, scatter keeps an
        # unspecified one of the duplicate entries.
        loss_full = loss.new_zeros((num_iters, num_batch, num_views_all)).scatter(
            2, index_t, loss
        )

        # Build the visited mask from the indices themselves; testing
        # `loss_full > 0` would treat a visited view with zero error as unvisited.
        visited = torch.zeros_like(loss_full, dtype=torch.bool).scatter(
            2, index_t, torch.ones_like(index_t, dtype=torch.bool)
        )

        iter_idx = torch.arange(num_iters, device=loss_full.device).view(-1, 1, 1)
        visit_ids = torch.where(visited, iter_idx, torch.full_like(iter_idx, -1))
        # Latest iteration (inclusive) at which each view was visited, or -1.
        last_visit = torch.cummax(visit_ids, dim=0).values
        # Shift by one iteration to get the latest visit strictly before now.
        prev_visit = torch.roll(last_visit, shifts=1, dims=0)
        prev_visit[0] = -1  # nothing precedes the first iteration
        prev_loss = loss_full.gather(0, prev_visit.clamp(min=0)).detach()
        has_prev = prev_visit >= 0

        # Unvisited entries hold 0 in loss_full, so relu(0 - prev_loss) is 0
        # and they add nothing; first visits are masked out by has_prev.
        change_in_loss = torch.relu(loss_full - prev_loss)
        return change_in_loss * has_prev.to(loss_full.dtype)


# ─────────────────────────────────────────────
# Other losses (L1 and a fake LPIPS)
# ─────────────────────────────────────────────

def compute_l1_loss(render_color, gt):
    """Return the mean absolute difference between the render and the ground truth."""
    residual = render_color - gt
    return residual.abs().mean()

def compute_lpips_loss(render_color, gt):
    """Stand-in for LPIPS: return the mean squared difference of the two inputs."""
    # Not a perceptual metric; a plain MSE keeps this debug script dependency-free.
    residual = render_color - gt
    return residual.pow(2).mean()

def compute_meta_losses(optimizer_output, batch, num_inner):
    """Sum L1 + 0.1 * fake-LPIPS over the refined renders (stability term excluded).

    Render index 0 (the init render) is skipped; indices 1..num_inner of both
    the context and the target render lists are scored against
    ``batch[split]["image"]``.
    """
    renders_by_split = (
        ("context", optimizer_output.context_render_list),
        ("target", optimizer_output.target_render_list),
    )
    opt_loss = 0
    for step in range(1, num_inner + 1):
        for split, renders in renders_by_split:
            pred = renders[step].color
            gt = batch[split]["image"]
            opt_loss = opt_loss + compute_l1_loss(pred, gt)
            opt_loss = opt_loss + 0.1 * compute_lpips_loss(pred, gt)
    return opt_loss


# ─────────────────────────────────────────────
# Main training loop
# ─────────────────────────────────────────────

def make_batch(B=2, V=3, C=3, H=8, W=8, device="cpu"):
    """Build a batch of uniform-random images for the context and target splits.

    Returns ``{"context": {"image": t}, "target": {"image": t}}`` where each
    tensor has shape [B, V, C, H, W] and lives on ``device``.
    """
    image_shape = (B, V, C, H, W)
    batch = {}
    # Context is sampled before target so the RNG stream order stays fixed.
    for split in ("context", "target"):
        batch[split] = {"image": torch.rand(image_shape, device=device)}
    return batch


def run_meta_iteration(net, batch, stability_loss_fn, num_inner=5, use_index_list=False):
    """Run one simulated meta-iteration and return its scalar training loss.

    Random init renders shaped like ``batch["context"]["image"]`` are refined
    by ``net`` for ``num_inner`` steps, wrapped into an ``OptimizerOutput``,
    and scored as ``meta_loss + 0.01 * stability_loss``.

    With ``use_index_list`` set, each inner iteration additionally gets a
    random [B, max(1, V // 2)] view-index tensor per split; otherwise the index
    lists contain only ``None`` placeholders.
    """
    reference = batch["context"]["image"]
    B, V, C, H, W = reference.shape
    device = reference.device

    # Simulate init renders (not produced by the network, like a 3DGS init).
    # Both are sampled before the network runs; context first, then target.
    init_context, init_target = (
        torch.rand(B, V, C, H, W, requires_grad=False, device=device)
        for _ in range(2)
    )

    context_render_list, target_render_list = (
        [RenderOutput(color=frame) for frame in net(init, num_inner)]
        for init in (init_context, init_target)
    )

    if use_index_list:
        # Simulate partial view sampling: pick V//2 views each inner iter
        views_per_iter = max(1, V // 2)

        def sample_indices():
            return [
                torch.randint(0, V, (B, views_per_iter), device=device)
                for _ in range(num_inner)
            ]

        context_index_list = sample_indices()
        target_index_list = sample_indices()
    else:
        context_index_list = [None] * num_inner
        target_index_list = [None] * num_inner

    optimizer_output = OptimizerOutput(
        context_render_list=context_render_list,
        target_render_list=target_render_list,
        context_index_list=context_index_list,
        target_index_list=target_index_list,
    )

    # ── Meta losses (L1 + LPIPS across inner iters) ──
    meta_loss = compute_meta_losses(optimizer_output, batch, num_inner)

    # ── Stability loss ──
    stab_loss = stability_loss_fn(optimizer_output, batch)

    return meta_loss + 0.01 * stab_loss


def main():
    """Run the debug training loop with and without per-iteration index lists.

    Seeds torch with 42, builds one ``TinyOptimizerNet`` and one Adam optimizer
    (lr=1e-3) on CUDA when available, else CPU, and runs ``NUM_META_STEPS``
    meta steps per mode. The network and optimizer are shared by both modes,
    so the second mode starts from the weights the first one left behind.
    A ``RuntimeError`` in a step is printed and ends the current mode only.
    """
    torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = TinyOptimizerNet(hidden=16).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    stability_loss_fn = LossStability()

    NUM_META_STEPS = 5
    NUM_INNER = 5  # inner iterations, not counting the init render

    for mode in ["no_index_list", "with_index_list"]:
        print(f"\n{'='*50}")
        print(f"Mode: {mode}")
        print(f"{'='*50}")
        use_index = (mode == "with_index_list")

        for step in range(NUM_META_STEPS):
            batch = make_batch(device=device)
            optimizer.zero_grad()

            try:
                total_loss = run_meta_iteration(
                    net, batch, stability_loss_fn,
                    num_inner=NUM_INNER,
                    use_index_list=use_index
                )
                total_loss.backward()
                optimizer.step()
                # The success marker was stored as mojibake; print a real check mark.
                print(f"  Step {step+1}: total_loss={total_loss.item():.6f} ✓")
            except RuntimeError as e:
                # Autograd/shape failures are what this script is meant to
                # surface: report the message and stop this mode.
                print(f"  Step {step+1}: ERROR - {e}")
                break


# Run the debug loop only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()