| import numpy as np
|
| from enum import Enum, auto
|
| from typing import List, Tuple, Optional, Dict, Union
|
| from dataclasses import dataclass
|
|
|
class ModalityType(Enum):
    """Supported modality types for tensors.

    The string values act as stable identifiers for each modality and are
    carried in TensorMetadata.modality throughout this module.
    """
    TEXT = "text"            # token sequences (aligned by sequence_length)
    IMAGE = "image"          # spatial data (aligned by spatial_dims)
    VIDEO = "video"          # spatial + temporal data (both spatial and temporal rules apply)
    AUDIO = "audio"          # sampled signals (aligned by sampling_rate)
    VISION = "vision"        # NOTE(review): distinction from IMAGE is not evident from this file — confirm intended use
    LATENT = "latent"        # fallback modality for tensors without metadata (see broadcast_tensors)
    EMBEDDING = "embedding"  # dense vector representations; may bridge text and spatial modalities
    ATTENTION = "attention"  # attention maps; may bridge text and spatial modalities
|
|
|
@dataclass
class TensorMetadata:
    """Metadata for multi-modal tensors.

    Optional fields are populated only when meaningful for the modality
    (e.g. sampling_rate for audio, spatial_dims for image/video,
    sequence_length for text/embedding).
    """
    modality: ModalityType                           # which modality this tensor represents
    shape: Tuple[int, ...]                           # full tensor shape
    dtype: np.dtype                                  # element dtype
    channels: int = 1                                # channel count (defaults to single-channel)
    sampling_rate: Optional[int] = None              # audio sampling rate — presumably Hz; TODO confirm unit
    frame_rate: Optional[int] = None                 # video frame rate — presumably fps; TODO confirm unit
    sequence_length: Optional[int] = None            # length along the sequence axis for text/embedding
    spatial_dims: Optional[Tuple[int, ...]] = None   # spatial extents — assumed (H, W, ...) ordering; verify against driver.resize
|
|
|
class BroadcastState:
    """Allocates and tracks temporary tensors in driver memory.

    Generates unique names from a prefix plus a monotonically increasing
    counter, and keeps a local metadata cache for the temporaries it
    creates.
    """

    def __init__(self, driver, prefix: str):
        # `driver` must expose create_tensor / tensor_exists / delete_tensor;
        # set_tensor_metadata is optional (feature-detected below).
        self.driver = driver
        self.prefix = prefix
        self.counter = 0
        self.metadata_cache: Dict[str, TensorMetadata] = {}

    def get_temp_tensor(
        self,
        data,
        name_suffix: str = "",
        metadata: Optional[TensorMetadata] = None
    ) -> str:
        """Store temporary computation results in driver memory with metadata.

        Returns the generated (unique within this state) tensor name.
        """
        name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)

        if metadata:
            self.metadata_cache[name] = metadata
            # Propagate metadata to the driver when it supports it.
            if hasattr(self.driver, 'set_tensor_metadata'):
                self.driver.set_tensor_metadata(name, metadata)

        return name

    def free_temp_tensor(self, name: str):
        """Clean up a temporary tensor and its cached metadata."""
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)
        # Fix: also drop the cached metadata. Previously the entry stayed in
        # metadata_cache forever, so the cache grew unboundedly as
        # temporaries were created and freed.
        self.metadata_cache.pop(name, None)
|
|
|
def validate_modality_compatibility(
    modalities: List[ModalityType],
    shapes: List[Tuple[int, ...]],
    metadata_list: List[TensorMetadata]
) -> bool:
    """
    Validate if tensors with given modalities can be broadcast together.

    Rules:
    - Text-like and spatial modalities may only mix when a bridging
      modality (ATTENTION or EMBEDDING) is also present.
    - Multiple temporal modalities must agree on their declared sampling
      rates.

    `shapes` is currently unused but kept for interface stability.
    """
    text_modalities = {ModalityType.TEXT, ModalityType.EMBEDDING}
    spatial_modalities = {ModalityType.IMAGE, ModalityType.VISION, ModalityType.VIDEO}
    temporal_modalities = {ModalityType.AUDIO, ModalityType.VIDEO}

    unique_modalities = set(modalities)

    # Text x spatial combinations require a bridging modality.
    if unique_modalities & text_modalities and unique_modalities & spatial_modalities:
        if ModalityType.ATTENTION not in unique_modalities and \
           ModalityType.EMBEDDING not in unique_modalities:
            return False

    # Temporal modalities mixed together must share a sampling rate.
    if len(unique_modalities & temporal_modalities) > 1:
        # Fix: only compare declared (non-None) sampling rates. Video
        # metadata typically carries frame_rate rather than sampling_rate,
        # so the previous code compared e.g. 16000 != None and rejected
        # every audio/video combination with an unset rate.
        rates = [m.sampling_rate for m in metadata_list
                 if m.modality in temporal_modalities
                 and m.sampling_rate is not None]
        if rates and not all(r == rates[0] for r in rates):
            return False

    return True
|
|
|
def compute_broadcast_shapes_with_modality(
    *shapes: Tuple[int, ...],
    metadata_list: Optional[List[TensorMetadata]] = None
) -> Tuple[Optional[Tuple[int, ...]], Optional[TensorMetadata]]:
    """
    Compute broadcast shapes with modality awareness.

    Returns (broadcast_shape, broadcast_metadata), or (None, None) when the
    shapes/modalities are not compatible.

    Raises:
        ValueError: if metadata_list is given but its length does not match
            the number of shapes.
    """
    from itertools import zip_longest

    if metadata_list and len(shapes) != len(metadata_list):
        raise ValueError("Number of shapes must match number of metadata entries")

    if metadata_list:
        modalities = [m.modality for m in metadata_list]
        if not validate_modality_compatibility(modalities, list(shapes), metadata_list):
            return None, None

    # Fix: align shapes on their TRAILING dimensions and treat missing
    # leading dimensions as size 1 (NumPy broadcasting rules). The previous
    # plain zip() truncated to the shortest rank, e.g. (2, 1, 4) x (3, 1)
    # produced (3, 4) instead of the correct (2, 3, 4).
    result = []
    for dims in zip_longest(*[reversed(s) for s in shapes], fillvalue=1):
        dim = max(dims)
        if all(d == 1 or d == dim for d in dims):
            result.append(dim)
        else:
            return None, None

    broadcast_shape = tuple(reversed(result))

    if metadata_list:
        # Merge metadata: first tensor's modality/dtype win; numeric
        # attributes take the maximum across inputs (None treated as 0).
        broadcast_metadata = TensorMetadata(
            modality=metadata_list[0].modality,
            shape=broadcast_shape,
            dtype=metadata_list[0].dtype,
            channels=max(m.channels for m in metadata_list),
            sampling_rate=max((m.sampling_rate or 0) for m in metadata_list),
            frame_rate=max((m.frame_rate or 0) for m in metadata_list),
            sequence_length=max((m.sequence_length or 0) for m in metadata_list),
            spatial_dims=max((m.spatial_dims or (0,)) for m in metadata_list)
        )
        return broadcast_shape, broadcast_metadata

    return broadcast_shape, None
|
|
|
def compute_broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Legacy compatibility wrapper around the modality-aware version.

    Delegates to compute_broadcast_shapes_with_modality and discards the
    metadata half of its result.
    """
    broadcast_shape, _unused_metadata = compute_broadcast_shapes_with_modality(*shapes)
    return broadcast_shape
|
|
|
class BroadcastModule:
    """Broadcast named tensors held in driver memory to a common shape.

    Beyond plain shape broadcasting, tensors are first aligned on
    modality-specific attributes: spatial dims for image/video (via
    driver.resize), sampling rate for audio (via driver.resample), and
    sequence length for text/embedding (via driver.slice / driver.pad).
    Optional driver capabilities are feature-detected with hasattr.
    """

    def __init__(self, driver):
        # The driver owns tensor storage; `metadata_cache` is a local
        # fallback used when the driver does not expose metadata itself.
        self.driver = driver
        self.metadata_cache: Dict[str, TensorMetadata] = {}

    def _align_spatial_dims(
        self,
        tensor_name: str,
        target_metadata: TensorMetadata,
        state: BroadcastState
    ) -> str:
        """Resize an image/video tensor to the target spatial dims.

        Returns the (possibly new temporary) tensor name; no-op when the
        driver cannot resize or metadata is unavailable.
        """
        if not hasattr(self.driver, 'resize'):
            return tensor_name

        # Lookup also fails fast if the tensor does not exist; the value
        # itself is not needed here.
        self.driver.get_tensor(tensor_name)
        metadata = self._get_tensor_metadata(tensor_name)

        if not metadata or not target_metadata.spatial_dims:
            return tensor_name

        if metadata.spatial_dims != target_metadata.spatial_dims:
            resized_name = state.get_temp_tensor(
                self.driver.resize(
                    tensor_name,
                    target_metadata.spatial_dims,
                    mode='bilinear'
                ),
                "resized",
                target_metadata
            )
            return resized_name

        return tensor_name

    def _align_sampling_rate(
        self,
        tensor_name: str,
        target_metadata: TensorMetadata,
        state: BroadcastState
    ) -> str:
        """Resample an audio tensor to the target sampling rate.

        No-op when the driver cannot resample or metadata is unavailable.
        """
        if not hasattr(self.driver, 'resample'):
            return tensor_name

        # Existence check only; the tensor value is not used here.
        self.driver.get_tensor(tensor_name)
        metadata = self._get_tensor_metadata(tensor_name)

        if not metadata or not target_metadata.sampling_rate:
            return tensor_name

        if metadata.sampling_rate != target_metadata.sampling_rate:
            resampled_name = state.get_temp_tensor(
                self.driver.resample(
                    tensor_name,
                    metadata.sampling_rate,
                    target_metadata.sampling_rate
                ),
                "resampled",
                target_metadata
            )
            return resampled_name

        return tensor_name

    def _align_sequence_length(
        self,
        tensor_name: str,
        target_metadata: TensorMetadata,
        state: BroadcastState
    ) -> str:
        """Truncate or pad a text/embedding tensor to the target length."""
        # Existence check only; the tensor value is not used here.
        self.driver.get_tensor(tensor_name)
        metadata = self._get_tensor_metadata(tensor_name)

        if not metadata or not target_metadata.sequence_length:
            return tensor_name

        current_length = metadata.sequence_length
        target_length = target_metadata.sequence_length

        if current_length != target_length:
            if current_length > target_length:
                # Too long: keep the leading target_length elements.
                return state.get_temp_tensor(
                    self.driver.slice(
                        tensor_name,
                        (0, target_length)
                    ),
                    "truncated",
                    target_metadata
                )
            # Too short: pad at the end with the driver's constant mode
            # (presumably zeros — confirm against the driver implementation).
            return state.get_temp_tensor(
                self.driver.pad(
                    tensor_name,
                    ((0, target_length - current_length),),
                    mode='constant'
                ),
                "padded",
                target_metadata
            )

        return tensor_name

    def _expand_dims(self, tensor_name: str, target_dims: int, state: BroadcastState) -> str:
        """Add leading dimensions of size 1 to match target dimensionality.

        Returns the original name unchanged when the rank already matches
        (or exceeds) target_dims.
        """
        tensor = self.driver.get_tensor(tensor_name)
        current_dims = len(tensor.shape)
        if current_dims < target_dims:
            new_shape = (1,) * (target_dims - current_dims) + tensor.shape
            metadata = self._get_tensor_metadata(tensor_name)
            expanded_name = state.get_temp_tensor(
                self.driver.reshape(tensor_name, new_shape),
                "expanded",
                metadata
            )
            return expanded_name
        return tensor_name

    def _broadcast_to(self, tensor_name: str, target_shape: Tuple[int, ...],
                      state: BroadcastState) -> str:
        """Broadcast a tensor to target_shape in driver memory.

        Assumes the tensor has already been expanded to the target rank
        (see _expand_dims). Returns the input name unchanged when the shape
        already matches.

        Raises:
            ValueError: when a trailing dimension is neither 1 nor equal to
                the target dimension.
        """
        tensor = self.driver.get_tensor(tensor_name)
        current_shape = tensor.shape

        if current_shape == target_shape:
            return tensor_name

        # Validate trailing dims; leading target dims are handled by the
        # prior rank expansion.
        for c, t in zip(reversed(current_shape), reversed(target_shape)):
            if c != 1 and c != t:
                raise ValueError(f"Shape {current_shape} cannot be broadcast to {target_shape}")

        broadcast_name = state.get_temp_tensor(
            self.driver.broadcast_to(tensor_name, target_shape),
            "broadcast"
        )
        return broadcast_name

    def _get_tensor_metadata(self, tensor_name: str) -> Optional[TensorMetadata]:
        """Get metadata for a tensor from the local cache, then the driver."""
        if tensor_name in self.metadata_cache:
            return self.metadata_cache[tensor_name]

        if hasattr(self.driver, 'get_tensor_metadata'):
            return self.driver.get_tensor_metadata(tensor_name)

        return None

    def broadcast_tensors(self, *tensor_names: str) -> List[str]:
        """
        Broadcast tensors to a common shape in driver memory.

        Handles multi-modal tensors with metadata preservation. Tensors
        without metadata default to the LATENT modality.

        Returns:
            List of broadcasted tensor names (parallel to tensor_names).

        Raises:
            ValueError: when the shapes/modalities cannot be broadcast.
        """
        state = BroadcastState(self.driver, "broadcast")

        # Collect shapes and metadata, defaulting to LATENT when unknown.
        shapes = []
        metadata_list = []
        for name in tensor_names:
            tensor = self.driver.get_tensor(name)
            shapes.append(tensor.shape)
            metadata = self._get_tensor_metadata(name)
            metadata_list.append(metadata if metadata else TensorMetadata(
                modality=ModalityType.LATENT,
                shape=tensor.shape,
                dtype=tensor.dtype
            ))

        target_shape, target_metadata = compute_broadcast_shapes_with_modality(
            *shapes,
            metadata_list=metadata_list
        )

        if target_shape is None:
            raise ValueError(
                f"Tensors with shapes {shapes} and modalities "
                f"{[m.modality for m in metadata_list]} cannot be broadcast together"
            )

        # Modality-specific alignment before plain shape broadcasting.
        transformed_names = []
        for name, metadata in zip(tensor_names, metadata_list):
            if metadata.modality in {ModalityType.IMAGE, ModalityType.VIDEO}:
                name = self._align_spatial_dims(name, target_metadata, state)
            elif metadata.modality in {ModalityType.AUDIO}:
                name = self._align_sampling_rate(name, target_metadata, state)
            elif metadata.modality in {ModalityType.TEXT, ModalityType.EMBEDDING}:
                name = self._align_sequence_length(name, target_metadata, state)
            transformed_names.append(name)

        # Rank-align every tensor, then materialize the broadcast.
        target_dims = len(target_shape)
        expanded_names = [
            self._expand_dims(name, target_dims, state)
            for name in transformed_names
        ]

        result_names = []
        for expanded_name in expanded_names:
            broadcast_name = self._broadcast_to(
                expanded_name,
                target_shape,
                state
            )
            if target_metadata:
                state.metadata_cache[broadcast_name] = target_metadata
            result_names.append(broadcast_name)

        # Free intermediates. Fix: the old cleanup freed every expanded temp
        # whose name differed from the original, but _broadcast_to returns
        # its input unchanged when the shape already matches the target, so
        # a RESULT tensor could be deleted. It also leaked the
        # modality-alignment temporaries (resized/resampled/truncated/
        # padded). Free everything intermediate that is neither an input
        # nor a result.
        protected = set(result_names) | set(tensor_names)
        for temp_name in set(transformed_names) | set(expanded_names):
            if temp_name not in protected:
                state.free_temp_tensor(temp_name)

        return result_names

    def binary_op_broadcast(self, a_name: str, b_name: str,
                            op_name: str = "add") -> Tuple[str, str]:
        """
        Broadcast two tensors for a binary operation.

        `op_name` is currently unused; kept for interface stability.
        Returns tuple of broadcasted tensor names.
        """
        return tuple(self.broadcast_tensors(a_name, b_name))

    def unary_op_broadcast(self, tensor_name: str, target_shape: Tuple[int, ...]) -> str:
        """
        Broadcast a tensor to a target shape for a unary operation.

        Returns the broadcasted tensor name.
        """
        state = BroadcastState(self.driver, f"unary_{tensor_name}")
        # Existence check only; the tensor value is not used here.
        self.driver.get_tensor(tensor_name)

        expanded_name = self._expand_dims(tensor_name, len(target_shape), state)
        result_name = self._broadcast_to(expanded_name, target_shape, state)

        # Fix: only free the expansion temp when it is not itself the
        # result — _broadcast_to returns its input unchanged when the shape
        # already matches, and freeing it would delete the result.
        if expanded_name != tensor_name and expanded_name != result_name:
            state.free_temp_tensor(expanded_name)

        return result_name
|
|
|
class BroadcastBackward:
    """Reduces gradients of broadcast results back to operand shapes."""

    def __init__(self, driver):
        # `driver` must expose get_tensor, sum, reshape plus the tensor
        # bookkeeping used by BroadcastState.
        self.driver = driver

    def reduce_gradient(self, grad_name: str, original_shape: Tuple[int, ...]) -> str:
        """
        Reduce gradient back to original tensor shape after broadcasting.

        Sums over every axis that was introduced or stretched by
        broadcasting, then reshapes to original_shape. All operations are
        done in driver memory; returns grad_name unchanged when no
        reduction is needed.
        """
        state = BroadcastState(self.driver, f"reduce_{grad_name}")
        grad = self.driver.get_tensor(grad_name)
        grad_shape = grad.shape

        if grad_shape == original_shape:
            return grad_name

        reduce_dims = []
        grad_dims = len(grad_shape)
        orig_dims = len(original_shape)

        # Axes prepended by broadcasting are summed away entirely.
        offset = grad_dims - orig_dims
        if offset > 0:
            reduce_dims.extend(range(offset))

        # Fix: walk the overlapping dims front-to-back so the recorded axis
        # index is correct. The previous reversed() loop paired index i
        # (counted from the END of the shape) with `offset + i` (counted
        # from the FRONT), which reduced the wrong axis for asymmetric
        # shapes — e.g. original (1, 3, 4) vs grad (2, 3, 4) reduced axis 2
        # instead of axis 0. Using grad_shape[offset:] (not a negative
        # slice) also behaves correctly for rank-0 originals.
        for axis, (orig_dim, grad_dim) in enumerate(
                zip(original_shape, grad_shape[offset:])):
            if orig_dim == 1 and grad_dim != 1:
                reduce_dims.append(offset + axis)

        if reduce_dims:
            reduced_name = state.get_temp_tensor(
                self.driver.sum(grad_name, axis=tuple(reduce_dims), keepdims=True),
                "reduced"
            )
            # keepdims=True leaves size-1 axes in place; the reshape drops
            # the prepended ones so the result matches original_shape.
            result_name = state.get_temp_tensor(
                self.driver.reshape(reduced_name, original_shape),
                "reshaped"
            )
            state.free_temp_tensor(reduced_name)
            return result_name

        return grad_name
|
|
|
|
|
| """
|
| # Initialize
|
| driver = YourDriver()
|
| broadcast_module = BroadcastModule(driver)
|
| backward_module = BroadcastBackward(driver)
|
|
|
| # Forward pass with broadcasting
|
| a_name = "tensor_a" # shape: (2, 1, 4)
|
| b_name = "tensor_b" # shape: (3, 1)
|
| c_name, d_name = broadcast_module.binary_op_broadcast(a_name, b_name)
|
| # c_name and d_name now have shape (2, 3, 4)
|
|
|
| # Backward pass
|
| grad_name = "output_grad" # shape: (2, 3, 4)
|
| grad_a = backward_module.reduce_gradient(grad_name, (2, 1, 4))
|
| grad_b = backward_module.reduce_gradient(grad_name, (3, 1))
|
| """
|
|
|