from typing import Any, Dict, Tuple

import numpy as np
from enum import Enum


class ModalityType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    GRAPH = "graph"
    POINT_CLOUD = "point_cloud"
    VOXEL = "voxel"
    LATENT = "latent"
    EMBEDDING = "embedding"


class ModalityConfig:
    """Configuration for different modalities."""

    # Default configurations for each modality
    DEFAULTS = {
        ModalityType.TEXT: {
            'dims': 1,  # sequence dimension
            'attention_pattern': 'causal',
            'position_encoding': 'rotary',
            'block_size': 4096
        },
        ModalityType.IMAGE: {
            'dims': 2,  # height, width
            'attention_pattern': 'local',
            'position_encoding': '2d_relative',
            'block_size': 256  # 16x16 patches
        },
        ModalityType.AUDIO: {
            'dims': 1,  # time dimension
            'attention_pattern': 'local',
            'position_encoding': 'rotary',
            'block_size': 8192  # ~0.5 seconds of raw samples at 16kHz
        },
        ModalityType.VIDEO: {
            'dims': 3,  # time, height, width
            'attention_pattern': 'local3d',
            'position_encoding': '3d_relative',
            'block_size': 512  # 8x8x8 cube
        },
        ModalityType.GRAPH: {
            'dims': None,  # adjacency based
            'attention_pattern': 'graph',
            'position_encoding': 'structure',
            'block_size': None
        },
        ModalityType.POINT_CLOUD: {
            'dims': 3,  # x, y, z
            'attention_pattern': 'knn',
            'position_encoding': '3d_absolute',
            'block_size': 1024  # points per block
        },
        ModalityType.VOXEL: {
            'dims': 3,  # x, y, z
            'attention_pattern': 'local3d',
            'position_encoding': '3d_relative',
            'block_size': 64  # 4x4x4 cube
        },
        ModalityType.LATENT: {
            'dims': 1,  # latent dimension
            'attention_pattern': 'full',
            'position_encoding': None,
            'block_size': None
        },
        ModalityType.EMBEDDING: {
            'dims': 1,  # embedding dimension
            'attention_pattern': 'full',
            'position_encoding': None,
            'block_size': None
        }
    }

    @classmethod
    def get_config(cls, modality: ModalityType) -> Dict[str, Any]:
        """Get a (shallow) copy of the default configuration for a modality."""
        return cls.DEFAULTS[modality].copy()


class ModalityMixer:
    """Handles cross-modal fusion and separation operations."""

    def __init__(self, fusion_type: str = "additive"):
        self.fusion_type = fusion_type

    def fuse(
        self,
        x: np.ndarray,
        y: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> np.ndarray:
        """Fuse tensors from different modalities."""
        if self.fusion_type == "additive":
            return x + y
        elif self.fusion_type == "multiplicative":
            return x * y
        elif self.fusion_type == "concatenative":
            return np.concatenate([x, y], axis=-1)
        elif self.fusion_type == "attention":
            # Cross-attention between modalities: x provides the queries,
            # y provides the keys and values. The random matrices are
            # stand-ins for projections that would be learned parameters
            # in a real model.
            q = x @ np.random.randn(x.shape[-1], x.shape[-1])
            k = y @ np.random.randn(y.shape[-1], x.shape[-1])
            v = y @ np.random.randn(y.shape[-1], x.shape[-1])
            # np.swapaxes transposes the last two axes; NumPy's
            # ndarray.transpose(-2, -1) expects a full axis permutation
            # and is a silent no-op on 2D arrays.
            scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(x.shape[-1])
            # Numerically stable softmax: subtract the row max before exp.
            scores -= scores.max(axis=-1, keepdims=True)
            attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
            return attn @ v
        raise ValueError(f"Unknown fusion type: {self.fusion_type}")

    def unfuse(
        self,
        z: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Separate a fused tensor back into its two modalities."""
        if self.fusion_type in ["additive", "multiplicative"]:
            # Can't perfectly separate; return an equal split.
            return z / 2, z / 2
        elif self.fusion_type == "concatenative":
            # Concatenation is exactly invertible: split at the midpoint.
            split_idx = z.shape[-1] // 2
            return z[..., :split_idx], z[..., split_idx:]
        elif self.fusion_type == "attention":
            # Project back toward the original modalities. As in fuse(),
            # the random matrices stand in for learned projections, so this
            # is a lossy approximation rather than an exact inverse.
            x_proj = z @ np.random.randn(z.shape[-1], z.shape[-1])
            y_proj = z @ np.random.randn(z.shape[-1], z.shape[-1])
            return x_proj, y_proj
        raise ValueError(f"Unknown fusion type: {self.fusion_type}")
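

# --- Minimal usage sketch. Illustrative only: the shapes, seed, and sample
# modalities below are assumptions for the demo, not part of the module's API.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Look up the default attention/blocking config for a modality.
    text_cfg = ModalityConfig.get_config(ModalityType.TEXT)
    print(text_cfg['attention_pattern'])  # -> 'causal'

    # Fuse two same-shaped feature tensors; concatenative fusion is the one
    # mode that unfuse() inverts exactly.
    x = rng.standard_normal((4, 8))  # hypothetical text features
    y = rng.standard_normal((4, 8))  # hypothetical image features
    mixer = ModalityMixer(fusion_type="concatenative")
    z = mixer.fuse(x, y, ModalityType.TEXT, ModalityType.IMAGE)
    print(z.shape)  # (4, 16)
    x2, y2 = mixer.unfuse(z, ModalityType.TEXT, ModalityType.IMAGE)
    assert np.allclose(x, x2) and np.allclose(y, y2)

    # Attention fusion also handles inputs of different sequence lengths:
    # the output has one row per query position in x.
    y_long = rng.standard_normal((6, 8))  # hypothetical audio features
    attn_mixer = ModalityMixer(fusion_type="attention")
    fused = attn_mixer.fuse(x, y_long, ModalityType.TEXT, ModalityType.AUDIO)
    print(fused.shape)  # (4, 8)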