"""helium/broadcast.py: modality-aware tensor broadcasting for driver-managed tensors."""
import numpy as np
from enum import Enum
from itertools import zip_longest
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
class ModalityType(Enum):
"""Supported modality types for tensors"""
TEXT = "text"
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
VISION = "vision"
LATENT = "latent"
EMBEDDING = "embedding"
ATTENTION = "attention"
@dataclass
class TensorMetadata:
"""Metadata for multi-modal tensors"""
modality: ModalityType
shape: Tuple[int, ...]
dtype: np.dtype
channels: int = 1
sampling_rate: Optional[int] = None # For audio
frame_rate: Optional[int] = None # For video
sequence_length: Optional[int] = None # For text/time series
spatial_dims: Optional[Tuple[int, ...]] = None # For image/video
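# Example: constructing metadata for two modalities. A minimal sketch; the
# Optional fields default to None and only apply to the matching modality.
"""
audio_meta = TensorMetadata(
    modality=ModalityType.AUDIO,
    shape=(1, 16000),
    dtype=np.dtype(np.float32),
    sampling_rate=16000,
)
image_meta = TensorMetadata(
    modality=ModalityType.IMAGE,
    shape=(3, 224, 224),
    dtype=np.dtype(np.float32),
    channels=3,
    spatial_dims=(224, 224),
)
"""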
class BroadcastState:
def __init__(self, driver, prefix: str):
self.driver = driver
self.prefix = prefix
self.counter = 0
self.metadata_cache: Dict[str, TensorMetadata] = {}
def get_temp_tensor(
self,
data,
name_suffix: str = "",
metadata: Optional[TensorMetadata] = None
) -> str:
"""Store temporary computation results in driver memory with metadata"""
name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
self.counter += 1
self.driver.create_tensor(name, data)
if metadata:
self.metadata_cache[name] = metadata
if hasattr(self.driver, 'set_tensor_metadata'):
self.driver.set_tensor_metadata(name, metadata)
return name
    def free_temp_tensor(self, name: str):
        """Clean up temporary tensors and any cached metadata"""
        self.metadata_cache.pop(name, None)
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)
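# Example: the temp-tensor lifecycle. A sketch assuming a hypothetical driver
# object exposing create_tensor/tensor_exists/delete_tensor, as used above.
"""
state = BroadcastState(driver, prefix="demo")
name = state.get_temp_tensor(np.zeros((2, 3)), "scratch")
# name == "demo_temp_0_scratch"; later calls bump the counter in the name
state.free_temp_tensor(name)
"""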
def validate_modality_compatibility(
modalities: List[ModalityType],
shapes: List[Tuple[int, ...]],
metadata_list: List[TensorMetadata]
) -> bool:
"""
Validate if tensors with given modalities can be broadcast together
"""
# Basic modality compatibility rules
text_modalities = {ModalityType.TEXT, ModalityType.EMBEDDING}
spatial_modalities = {ModalityType.IMAGE, ModalityType.VISION, ModalityType.VIDEO}
temporal_modalities = {ModalityType.AUDIO, ModalityType.VIDEO}
unique_modalities = set(modalities)
# Check if mixing text and spatial modalities
if unique_modalities & text_modalities and unique_modalities & spatial_modalities:
# Ensure there's an attention or embedding bridge
if ModalityType.ATTENTION not in unique_modalities and \
ModalityType.EMBEDDING not in unique_modalities:
return False
    # Check sampling-rate compatibility for temporal modalities (audio/video);
    # entries without a sampling_rate (e.g. video described only by frame_rate)
    # are skipped rather than treated as a mismatch
    if len(unique_modalities & temporal_modalities) > 1:
        rates = [m.sampling_rate for m in metadata_list
                 if m.modality in temporal_modalities and m.sampling_rate is not None]
        if rates and any(r != rates[0] for r in rates):
            return False
return True
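# Example: the text/spatial bridge rule. Mixing TEXT and IMAGE tensors is
# rejected unless an ATTENTION or EMBEDDING tensor participates. A sketch
# where only the modality fields matter to the check.
"""
text = TensorMetadata(ModalityType.TEXT, (8, 128), np.dtype(np.float32))
image = TensorMetadata(ModalityType.IMAGE, (8, 3, 224, 224), np.dtype(np.float32))
attn = TensorMetadata(ModalityType.ATTENTION, (8, 128, 196), np.dtype(np.float32))

validate_modality_compatibility(
    [text.modality, image.modality], [text.shape, image.shape], [text, image]
)  # False: no attention/embedding bridge

validate_modality_compatibility(
    [text.modality, image.modality, attn.modality],
    [text.shape, image.shape, attn.shape], [text, image, attn]
)  # True: ATTENTION bridges text and spatial modalities
"""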
def compute_broadcast_shapes_with_modality(
    *shapes: Tuple[int, ...],
    metadata_list: Optional[List[TensorMetadata]] = None
) -> Tuple[Optional[Tuple[int, ...]], Optional[TensorMetadata]]:
    """
    Compute broadcast shapes with modality awareness.
    Returns (broadcast_shape, broadcast_metadata), or (None, None) if the
    tensors are not compatible.
    """
    if metadata_list and len(shapes) != len(metadata_list):
        raise ValueError("Number of shapes must match number of metadata entries")
if metadata_list:
modalities = [m.modality for m in metadata_list]
if not validate_modality_compatibility(modalities, list(shapes), metadata_list):
return None, None
    # Compute basic NumPy-style shape broadcasting. zip_longest pads the
    # shorter shapes with implicit leading 1s; plain zip would silently drop
    # the extra leading dimensions of the longer shapes.
    result = []
    for dims in zip_longest(*[reversed(s) for s in shapes], fillvalue=1):
        dim = max(dims)
        if all(d == 1 or d == dim for d in dims):
            result.append(dim)
        else:
            return None, None
    broadcast_shape = tuple(reversed(result))
    # Merge metadata if provided, taking the highest resolution/quality for
    # each field and collapsing absent Optional fields back to None instead
    # of leaving 0/(0,) sentinels behind
    if metadata_list:
        def _max_or_none(values):
            present = [v for v in values if v is not None]
            return max(present) if present else None

        broadcast_metadata = TensorMetadata(
            modality=metadata_list[0].modality,  # Primary modality
            shape=broadcast_shape,
            dtype=metadata_list[0].dtype,
            channels=max(m.channels for m in metadata_list),
            sampling_rate=_max_or_none(m.sampling_rate for m in metadata_list),
            frame_rate=_max_or_none(m.frame_rate for m in metadata_list),
            sequence_length=_max_or_none(m.sequence_length for m in metadata_list),
            spatial_dims=_max_or_none(m.spatial_dims for m in metadata_list)
        )
        return broadcast_shape, broadcast_metadata
    return broadcast_shape, None
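# Example: plain shape broadcasting follows NumPy rules, padding shorter
# shapes with implicit leading 1s; incompatible shapes yield (None, None).
"""
compute_broadcast_shapes_with_modality((2, 1, 4), (3, 1))
# -> ((2, 3, 4), None)
compute_broadcast_shapes_with_modality((2, 3), (4,))
# -> (None, None): trailing dims 3 and 4 conflict
"""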
def compute_broadcast_shapes(*shapes: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
    """Legacy compatibility wrapper; returns None if the shapes are incompatible"""
    shape, _ = compute_broadcast_shapes_with_modality(*shapes)
    return shape
class BroadcastModule:
def __init__(self, driver):
self.driver = driver
self.metadata_cache: Dict[str, TensorMetadata] = {}
def _align_spatial_dims(
self,
tensor_name: str,
target_metadata: TensorMetadata,
state: BroadcastState
) -> str:
"""Align spatial dimensions for image/video tensors"""
if not hasattr(self.driver, 'resize'):
return tensor_name
tensor = self.driver.get_tensor(tensor_name)
metadata = self._get_tensor_metadata(tensor_name)
if not metadata or not target_metadata.spatial_dims:
return tensor_name
if metadata.spatial_dims != target_metadata.spatial_dims:
resized_name = state.get_temp_tensor(
self.driver.resize(
tensor_name,
target_metadata.spatial_dims,
mode='bilinear'
),
"resized",
target_metadata
)
return resized_name
return tensor_name
def _align_sampling_rate(
self,
tensor_name: str,
target_metadata: TensorMetadata,
state: BroadcastState
) -> str:
"""Align sampling rates for audio tensors"""
if not hasattr(self.driver, 'resample'):
return tensor_name
tensor = self.driver.get_tensor(tensor_name)
metadata = self._get_tensor_metadata(tensor_name)
if not metadata or not target_metadata.sampling_rate:
return tensor_name
if metadata.sampling_rate != target_metadata.sampling_rate:
resampled_name = state.get_temp_tensor(
self.driver.resample(
tensor_name,
metadata.sampling_rate,
target_metadata.sampling_rate
),
"resampled",
target_metadata
)
return resampled_name
return tensor_name
def _align_sequence_length(
self,
tensor_name: str,
target_metadata: TensorMetadata,
state: BroadcastState
) -> str:
"""Align sequence lengths for text/embedding tensors"""
tensor = self.driver.get_tensor(tensor_name)
metadata = self._get_tensor_metadata(tensor_name)
if not metadata or not target_metadata.sequence_length:
return tensor_name
current_length = metadata.sequence_length
target_length = target_metadata.sequence_length
if current_length != target_length:
if current_length > target_length:
# Truncate
sliced_name = state.get_temp_tensor(
self.driver.slice(
tensor_name,
(0, target_length)
),
"truncated",
target_metadata
)
return sliced_name
else:
# Pad
padded_name = state.get_temp_tensor(
self.driver.pad(
tensor_name,
((0, target_length - current_length),),
mode='constant'
),
"padded",
target_metadata
)
return padded_name
return tensor_name
def _expand_dims(self, tensor_name: str, target_dims: int, state: BroadcastState) -> str:
"""Add leading dimensions of size 1 to match target dimensionality"""
tensor = self.driver.get_tensor(tensor_name)
current_dims = len(tensor.shape)
if current_dims < target_dims:
new_shape = (1,) * (target_dims - current_dims) + tensor.shape
metadata = self._get_tensor_metadata(tensor_name)
expanded_name = state.get_temp_tensor(
self.driver.reshape(tensor_name, new_shape),
"expanded",
metadata
)
return expanded_name
return tensor_name
def _broadcast_to(self, tensor_name: str, target_shape: Tuple[int, ...],
state: BroadcastState) -> str:
"""Broadcast tensor to target shape in driver memory"""
tensor = self.driver.get_tensor(tensor_name)
current_shape = tensor.shape
if current_shape == target_shape:
return tensor_name
# Check if broadcasting is possible
for c, t in zip(reversed(current_shape), reversed(target_shape)):
if c != 1 and c != t:
raise ValueError(f"Shape {current_shape} cannot be broadcast to {target_shape}")
broadcast_name = state.get_temp_tensor(
self.driver.broadcast_to(tensor_name, target_shape),
"broadcast"
)
return broadcast_name
def _get_tensor_metadata(self, tensor_name: str) -> Optional[TensorMetadata]:
"""Get metadata for a tensor from driver or cache"""
if tensor_name in self.metadata_cache:
return self.metadata_cache[tensor_name]
if hasattr(self.driver, 'get_tensor_metadata'):
return self.driver.get_tensor_metadata(tensor_name)
return None
def broadcast_tensors(self, *tensor_names: str) -> List[str]:
"""
Broadcast tensors to a common shape in driver memory.
Handles multi-modal tensors with metadata preservation.
Returns list of broadcasted tensor names.
"""
state = BroadcastState(self.driver, "broadcast")
# Get shapes and metadata from driver memory
shapes = []
metadata_list = []
for name in tensor_names:
tensor = self.driver.get_tensor(name)
shapes.append(tensor.shape)
metadata = self._get_tensor_metadata(name)
metadata_list.append(metadata if metadata else TensorMetadata(
modality=ModalityType.LATENT,
shape=tensor.shape,
dtype=tensor.dtype
))
# Compute target shape with modality awareness
target_shape, target_metadata = compute_broadcast_shapes_with_modality(
*shapes,
metadata_list=metadata_list
)
if target_shape is None:
raise ValueError(
f"Tensors with shapes {shapes} and modalities "
f"{[m.modality for m in metadata_list]} cannot be broadcast together"
)
# Handle modality-specific transforms before broadcasting
transformed_names = []
for name, metadata in zip(tensor_names, metadata_list):
# Apply modality-specific preprocessing
if metadata.modality in {ModalityType.IMAGE, ModalityType.VIDEO}:
# Ensure spatial dimensions are properly aligned
name = self._align_spatial_dims(name, target_metadata, state)
elif metadata.modality in {ModalityType.AUDIO}:
# Resample if needed
name = self._align_sampling_rate(name, target_metadata, state)
elif metadata.modality in {ModalityType.TEXT, ModalityType.EMBEDDING}:
# Pad/truncate sequences if needed
name = self._align_sequence_length(name, target_metadata, state)
transformed_names.append(name)
# Expand dimensions to match target dimensionality
target_dims = len(target_shape)
expanded_names = [
self._expand_dims(name, target_dims, state)
for name in transformed_names
]
        # Broadcast each tensor to the target shape
        result_names = []
        for expanded_name in expanded_names:
            broadcast_name = self._broadcast_to(
                expanded_name,
                target_shape,
                state
            )
            # Tag the broadcast result with the merged metadata
            if target_metadata:
                state.metadata_cache[broadcast_name] = target_metadata
            result_names.append(broadcast_name)
        # Clean up intermediate expanded tensors. _broadcast_to returns its
        # input unchanged when the shape already matches, so never free a
        # tensor that is also one of the final results.
        for exp_name, src_name in zip(expanded_names, transformed_names):
            if exp_name != src_name and exp_name not in result_names:
                state.free_temp_tensor(exp_name)
        return result_names
    def binary_op_broadcast(self, a_name: str, b_name: str,
                            op_name: str = "add") -> Tuple[str, str]:
        """
        Broadcast two tensors for a binary operation (op_name is kept for API
        symmetry and is not used here).
        Returns tuple of broadcasted tensor names.
        """
        a_broadcast, b_broadcast = self.broadcast_tensors(a_name, b_name)
        return a_broadcast, b_broadcast
def unary_op_broadcast(self, tensor_name: str, target_shape: Tuple[int, ...]) -> str:
"""
Broadcast a tensor to a target shape for a unary operation.
Returns broadcasted tensor name.
"""
state = BroadcastState(self.driver, f"unary_{tensor_name}")
tensor = self.driver.get_tensor(tensor_name)
# First expand dims if needed
expanded_name = self._expand_dims(tensor_name, len(target_shape), state)
# Then broadcast to target shape
result_name = self._broadcast_to(expanded_name, target_shape, state)
        # Clean up the intermediate expansion, unless it is itself the result
        # (_broadcast_to returns its input when the shape already matches)
        if expanded_name != tensor_name and expanded_name != result_name:
            state.free_temp_tensor(expanded_name)
        return result_name
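# Example: metadata-aware broadcasting. A sketch assuming a hypothetical
# NumPy-backed driver; tensors without cached metadata are treated as LATENT,
# so plain arrays and modality-tagged tensors can mix freely.
"""
driver = YourDriver()
module = BroadcastModule(driver)
driver.create_tensor("emb", np.ones((2, 1, 4)))
module.metadata_cache["emb"] = TensorMetadata(
    modality=ModalityType.EMBEDDING, shape=(2, 1, 4),
    dtype=np.dtype(np.float32), sequence_length=2,
)
driver.create_tensor("scale", np.ones((3, 1)))  # no metadata -> LATENT
emb_b, scale_b = module.broadcast_tensors("emb", "scale")
# both results have shape (2, 3, 4) and carry the merged metadata
"""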
class BroadcastBackward:
def __init__(self, driver):
self.driver = driver
def reduce_gradient(self, grad_name: str, original_shape: Tuple[int, ...]) -> str:
"""
Reduce gradient back to original tensor shape after broadcasting.
All operations done in driver memory.
"""
state = BroadcastState(self.driver, f"reduce_{grad_name}")
grad = self.driver.get_tensor(grad_name)
grad_shape = grad.shape
# Nothing to reduce if shapes match
if grad_shape == original_shape:
return grad_name
# Calculate dimensions to sum over
reduce_dims = []
grad_dims = len(grad_shape)
orig_dims = len(original_shape)
# Handle leading dimensions
if grad_dims > orig_dims:
reduce_dims.extend(range(grad_dims - orig_dims))
        # Handle size-1 dimensions in the original shape, walking the trailing
        # dimensions from the right; the axis index of reversed position i is
        # grad_dims - 1 - i
        for i, (orig_dim, grad_dim) in enumerate(zip(reversed(original_shape),
                                                     reversed(grad_shape[-orig_dims:]))):
            if orig_dim == 1 and grad_dim != 1:
                reduce_dims.append(grad_dims - 1 - i)
# Sum over required dimensions
if reduce_dims:
reduced_name = state.get_temp_tensor(
self.driver.sum(grad_name, axis=tuple(reduce_dims), keepdims=True),
"reduced"
)
# Reshape to original shape
result_name = state.get_temp_tensor(
self.driver.reshape(reduced_name, original_shape),
"reshaped"
)
state.free_temp_tensor(reduced_name)
return result_name
return grad_name
# Example usage:
"""
# Initialize
driver = YourDriver()
broadcast_module = BroadcastModule(driver)
backward_module = BroadcastBackward(driver)
# Forward pass with broadcasting
a_name = "tensor_a" # shape: (2, 1, 4)
b_name = "tensor_b" # shape: (3, 1)
c_name, d_name = broadcast_module.binary_op_broadcast(a_name, b_name)
# c_name and d_name now have shape (2, 3, 4)
# Backward pass
grad_name = "output_grad" # shape: (2, 3, 4)
grad_a = backward_module.reduce_gradient(grad_name, (2, 1, 4))
grad_b = backward_module.reduce_gradient(grad_name, (3, 1))
"""