# helium/modality.py
from enum import Enum
from typing import Any, Dict, Tuple

import numpy as np


class ModalityType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    GRAPH = "graph"
    POINT_CLOUD = "point_cloud"
    VOXEL = "voxel"
    LATENT = "latent"
    EMBEDDING = "embedding"


class ModalityConfig:
    """Configuration for different modalities"""

    # Default configurations for each modality
    DEFAULTS = {
        ModalityType.TEXT: {
            'dims': 1,  # sequence dimension
            'attention_pattern': 'causal',
            'position_encoding': 'rotary',
            'block_size': 4096
        },
        ModalityType.IMAGE: {
            'dims': 2,  # height, width
            'attention_pattern': 'local',
            'position_encoding': '2d_relative',
            'block_size': 256  # 16x16 patches
        },
        ModalityType.AUDIO: {
            'dims': 1,  # time dimension
            'attention_pattern': 'local',
            'position_encoding': 'rotary',
            'block_size': 8192  # ~0.5 seconds of raw samples at 16 kHz
        },
        ModalityType.VIDEO: {
            'dims': 3,  # time, height, width
            'attention_pattern': 'local3d',
            'position_encoding': '3d_relative',
            'block_size': 512  # 8x8x8 cube
        },
        ModalityType.GRAPH: {
            'dims': None,  # adjacency based
            'attention_pattern': 'graph',
            'position_encoding': 'structure',
            'block_size': None
        },
        ModalityType.POINT_CLOUD: {
            'dims': 3,  # x, y, z
            'attention_pattern': 'knn',
            'position_encoding': '3d_absolute',
            'block_size': 1024  # points per block
        },
        ModalityType.VOXEL: {
            'dims': 3,  # x, y, z
            'attention_pattern': 'local3d',
            'position_encoding': '3d_relative',
            'block_size': 64  # 4x4x4 cube
        },
        ModalityType.LATENT: {
            'dims': 1,  # latent dimension
            'attention_pattern': 'full',
            'position_encoding': None,
            'block_size': None
        },
        ModalityType.EMBEDDING: {
            'dims': 1,  # embedding dimension
            'attention_pattern': 'full',
            'position_encoding': None,
            'block_size': None
        }
    }
    @classmethod
    def get_config(cls, modality: ModalityType) -> Dict[str, Any]:
        """Return a copy of the default configuration for a modality."""
        return cls.DEFAULTS[modality].copy()
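
# A minimal usage sketch (illustrative, not part of the helium API): fetch the
# defaults for one modality and override a field. Because `get_config` returns
# a copy, the override does not mutate `ModalityConfig.DEFAULTS`.
#
#     cfg = ModalityConfig.get_config(ModalityType.IMAGE)
#     cfg['block_size'] = 1024  # hypothetical override, e.g. 32x32 patches
#     assert ModalityConfig.DEFAULTS[ModalityType.IMAGE]['block_size'] == 256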


class ModalityMixer:
    """Handles cross-modal operations"""

    def __init__(self, fusion_type: str = "additive"):
        self.fusion_type = fusion_type

    def fuse(
        self,
        x: np.ndarray,
        y: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> np.ndarray:
        """Fuse tensors from two modalities.

        The modality arguments are accepted for interface symmetry with
        `unfuse`; the simple fusion rules below do not depend on them.
        """
        if self.fusion_type == "additive":
            return x + y
        elif self.fusion_type == "multiplicative":
            return x * y
        elif self.fusion_type == "concatenative":
            return np.concatenate([x, y], axis=-1)
        elif self.fusion_type == "attention":
            # Cross-attention between modalities. The projections are drawn
            # at random here as stand-ins for learned weight matrices.
            q = x @ np.random.randn(x.shape[-1], x.shape[-1])
            k = y @ np.random.randn(y.shape[-1], x.shape[-1])
            v = y @ np.random.randn(y.shape[-1], x.shape[-1])
            scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(x.shape[-1])
            # Numerically stable softmax over the key axis
            scores -= scores.max(axis=-1, keepdims=True)
            attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
            return attn @ v
        raise ValueError(f"Unknown fusion type: {self.fusion_type}")
    def unfuse(
        self,
        z: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Separate a fused tensor back into its two modalities."""
        if self.fusion_type in ["additive", "multiplicative"]:
            # The operands can't be recovered exactly; return an equal split.
            return z / 2, z / 2
        elif self.fusion_type == "concatenative":
            # Exact inverse of fusion, assuming x and y had equal feature width
            split_idx = z.shape[-1] // 2
            return z[..., :split_idx], z[..., split_idx:]
        elif self.fusion_type == "attention":
            # Project back toward each modality; as in `fuse`, random matrices
            # stand in for learned projections.
            x_proj = z @ np.random.randn(z.shape[-1], z.shape[-1])
            y_proj = z @ np.random.randn(z.shape[-1], z.shape[-1])
            return x_proj, y_proj
        raise ValueError(f"Unknown fusion type: {self.fusion_type}")
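

# A minimal smoke test, assuming the module is run directly. The shapes and
# the "concatenative" fusion choice are arbitrary examples; concatenation is
# the only fusion that `unfuse` inverts exactly, as noted above.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 8))  # e.g. 4 text tokens, feature dim 8
    y = rng.standard_normal((4, 8))  # e.g. 4 image patches, feature dim 8

    mixer = ModalityMixer(fusion_type="concatenative")
    z = mixer.fuse(x, y, ModalityType.TEXT, ModalityType.IMAGE)
    x_rec, y_rec = mixer.unfuse(z, ModalityType.TEXT, ModalityType.IMAGE)

    assert z.shape == (4, 16)
    assert np.allclose(x, x_rec) and np.allclose(y, y_rec)
    print("concatenative fuse/unfuse round-trip OK")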