| """Utilities for reshaping target-individual axes in FlowPK tensors. |
| |
| These helpers centralize the repeated reshape patterns used when decoding |
| multiple target individuals independently. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Tuple |
|
|
| import torch |
| from torchtyping import TensorType |
|
|
|
|
def flatten_target_axis_4d(
    x: TensorType["B", "It", "T", "C"]
) -> TensorType["BIt", 1, "T", "C"]:
    """Merge the batch and target axes of a 4D tensor.

    Shape transform: ``[B, It, T, C] -> [B*It, 1, T, C]``. Each target
    individual becomes its own leading-axis entry, and a singleton target
    axis is kept in position 1.

    Args:
        x: A rank-4 tensor interpreted as ``[B, It, T, C]``.

    Returns:
        ``x`` reshaped (via ``Tensor.reshape``) to ``[B*It, 1, T, C]``.

    Raises:
        ValueError: If ``x`` is not 4-dimensional.
    """
    if x.ndim == 4:
        lead, targets, *trailing = x.shape
        # Explicit product rather than -1 so zero-sized dims reshape cleanly.
        return x.reshape(lead * targets, 1, *trailing)
    raise ValueError(f"Expected a 4D tensor shaped [B, It, T, C], got ndim={x.ndim}.")
|
|
|
|
def flatten_target_axis_3d(x: TensorType["B", "It", "T"]) -> TensorType["BIt", 1, "T"]:
    """Merge the batch and target axes of a 3D tensor.

    Shape transform: ``[B, It, T] -> [B*It, 1, T]``. Each target individual
    becomes its own leading-axis entry, and a singleton target axis is kept
    in position 1.

    Args:
        x: A rank-3 tensor interpreted as ``[B, It, T]``.

    Returns:
        ``x`` reshaped (via ``Tensor.reshape``) to ``[B*It, 1, T]``.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
    """
    if x.ndim == 3:
        lead, targets, *trailing = x.shape
        # Explicit product rather than -1 so zero-sized dims reshape cleanly.
        return x.reshape(lead * targets, 1, *trailing)
    raise ValueError(f"Expected a 3D tensor shaped [B, It, T], got ndim={x.ndim}.")
|
|
|
|
def unflatten_target_axis_4d(
    x: TensorType["BIt", 1, "T", "C"], batch_size: int, num_targets: int
) -> TensorType["B", "It", "T", "C"]:
    """Unflatten target axis: ``[B*It, 1, T, C] -> [B, It, T, C]``.

    Reverses the reshape performed by ``flatten_target_axis_4d``: the leading
    ``B*It`` axis is split back into ``batch_size`` and ``num_targets``, and
    the singleton axis is dropped.

    Args:
        x: A rank-4 tensor shaped ``[B*It, 1, T, C]``.
        batch_size: ``B``, the size of the restored batch axis (>= 0).
        num_targets: ``It``, the size of the restored target axis (>= 0).

    Returns:
        ``x`` reshaped (via ``Tensor.reshape``) to
        ``[batch_size, num_targets, T, C]``.

    Raises:
        ValueError: If ``x`` is not 4D, if ``batch_size`` or ``num_targets``
            is negative, if axis 1 is not of size 1, or if the leading axis
            does not equal ``batch_size * num_targets``.
    """
    if x.ndim != 4:
        raise ValueError(f"Expected a 4D tensor shaped [B*It, 1, T, C], got ndim={x.ndim}.")
    # Negative sizes can satisfy the product check below (e.g. -1 * -1 == 1,
    # or -1 * 0 == 0) and then fail inside reshape with a RuntimeError. Reject
    # them here so every bad-argument path raises ValueError.
    if batch_size < 0 or num_targets < 0:
        raise ValueError(
            "Expected non-negative batch_size and num_targets, got "
            f"batch_size={batch_size}, num_targets={num_targets}."
        )
    flat_batch, singleton, time_steps, channels = x.shape
    expected_flat_batch = batch_size * num_targets
    if singleton != 1:
        raise ValueError(f"Expected singleton target axis of size 1, got {singleton}.")
    if flat_batch != expected_flat_batch:
        raise ValueError(
            f"Expected leading axis {expected_flat_batch} (= {batch_size} * {num_targets}), "
            f"got {flat_batch}."
        )
    return x.reshape(batch_size, num_targets, time_steps, channels)
|
|
|
|
def flatten_many_target_axis_4d(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Apply ``flatten_target_axis_4d`` to every positional tensor.

    Returns:
        A tuple of the flattened tensors, in the same order as the inputs
        (an empty tuple when called with no arguments).

    Raises:
        ValueError: Propagated from ``flatten_target_axis_4d`` for the first
            input that is not 4D.
    """
    return tuple(map(flatten_target_axis_4d, tensors))
|
|
|
|
def flatten_many_target_axis_3d(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Apply ``flatten_target_axis_3d`` to every positional tensor.

    Returns:
        A tuple of the flattened tensors, in the same order as the inputs
        (an empty tuple when called with no arguments).

    Raises:
        ValueError: Propagated from ``flatten_target_axis_3d`` for the first
            input that is not 3D.
    """
    return tuple(map(flatten_target_axis_3d, tensors))
|
|