# Path: AICME-runtime/sim_priors_pk/models/utils/target_axis_ops.py
# Commit 5686f5b (cesarali): manual runtime bundle push from load_and_push.ipynb
"""Utilities for reshaping target-individual axes in FlowPK tensors.
These helpers centralize the repeated reshape patterns used when decoding
multiple target individuals independently.
"""
from __future__ import annotations
from typing import Tuple
import torch
from torchtyping import TensorType
def flatten_target_axis_4d(
    x: TensorType["B", "It", "T", "C"]
) -> TensorType["BIt", 1, "T", "C"]:
    """Merge the batch and target axes of a 4D tensor.

    Maps ``[B, It, T, C] -> [B*It, 1, T, C]`` so that each target individual
    becomes its own leading-axis entry with a singleton target axis.

    Args:
        x: Rank-4 tensor read as ``[B, It, T, C]``.

    Returns:
        Tensor of shape ``[B*It, 1, T, C]``. Row ``b * It + i`` holds ``x[b, i]``.

    Raises:
        ValueError: If ``x`` is not 4-dimensional.
    """
    if x.ndim != 4:
        raise ValueError(f"Expected a 4D tensor shaped [B, It, T, C], got ndim={x.ndim}.")
    merged = x.flatten(start_dim=0, end_dim=1)  # [B*It, T, C]
    return merged.unsqueeze(1)  # [B*It, 1, T, C]
def flatten_target_axis_3d(x: TensorType["B", "It", "T"]) -> TensorType["BIt", 1, "T"]:
    """Merge the batch and target axes of a 3D tensor.

    Maps ``[B, It, T] -> [B*It, 1, T]`` so that each target individual
    becomes its own leading-axis entry with a singleton target axis.

    Args:
        x: Rank-3 tensor read as ``[B, It, T]``.

    Returns:
        Tensor of shape ``[B*It, 1, T]``. Row ``b * It + i`` holds ``x[b, i]``.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
    """
    rank = x.ndim
    if rank != 3:
        raise ValueError(f"Expected a 3D tensor shaped [B, It, T], got ndim={rank}.")
    merged = x.flatten(start_dim=0, end_dim=1)  # [B*It, T]
    return merged.unsqueeze(1)  # [B*It, 1, T]
def unflatten_target_axis_4d(
    x: TensorType["BIt", 1, "T", "C"], batch_size: int, num_targets: int
) -> TensorType["B", "It", "T", "C"]:
    """Split a merged batch/target axis back into separate axes.

    Inverse of ``[B, It, T, C] -> [B*It, 1, T, C]``: maps
    ``[B*It, 1, T, C] -> [B, It, T, C]``.

    Args:
        x: Rank-4 tensor read as ``[B*It, 1, T, C]``.
        batch_size: Target size ``B`` of the restored batch axis.
        num_targets: Target size ``It`` of the restored target axis.

    Returns:
        Tensor of shape ``[batch_size, num_targets, T, C]``.

    Raises:
        ValueError: If ``x`` is not 4-dimensional, if its second axis is not
            of size 1, or if its leading axis differs from
            ``batch_size * num_targets``.
    """
    if x.ndim != 4:
        raise ValueError(f"Expected a 4D tensor shaped [B*It, 1, T, C], got ndim={x.ndim}.")
    leading, target_axis, *trailing = x.shape
    # The product is computed before any shape checks, matching the
    # original evaluation order.
    wanted = batch_size * num_targets
    if target_axis != 1:
        raise ValueError(f"Expected singleton target axis of size 1, got {target_axis}.")
    if leading != wanted:
        raise ValueError(
            f"Expected leading axis {wanted} (= {batch_size} * {num_targets}), "
            f"got {leading}."
        )
    return x.reshape(batch_size, num_targets, *trailing)
def flatten_many_target_axis_4d(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Apply ``flatten_target_axis_4d`` to every positional tensor.

    Returns a tuple with one flattened tensor per input, in input order.
    The first input that is not 4-dimensional raises the ``ValueError``
    from ``flatten_target_axis_4d``.
    """
    return tuple(map(flatten_target_axis_4d, tensors))
def flatten_many_target_axis_3d(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Apply ``flatten_target_axis_3d`` to every positional tensor.

    Returns a tuple with one flattened tensor per input, in input order.
    The first input that is not 3-dimensional raises the ``ValueError``
    from ``flatten_target_axis_3d``.
    """
    return tuple(map(flatten_target_axis_3d, tensors))