|
|
from typing import Dict, List, Optional, Union, Any, Tuple
|
|
|
import numpy as np
|
|
|
from enum import Enum
|
|
|
|
|
|
class ModalityType(Enum):
    """Identifiers for the data modalities handled by this module.

    Used as keys into ``ModalityConfig.DEFAULTS`` and as tags passed to
    ``ModalityMixer``.
    """

    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    GRAPH = "graph"
    POINT_CLOUD = "point_cloud"
    VOXEL = "voxel"
    # NOTE(review): LATENT/EMBEDDING appear to denote pre-encoded vector
    # inputs (full attention, no position encoding in DEFAULTS) -- confirm
    # intended distinction between the two with downstream users.
    LATENT = "latent"
    EMBEDDING = "embedding"
|
|
|
|
|
|
class ModalityConfig:
    """Configuration for different modalities"""

    # Per-modality default settings. Each entry holds:
    #   'dims':              spatial rank of the data (None = irregular
    #                        structure, e.g. graphs)
    #   'attention_pattern': name of the attention scheme ('causal', 'local',
    #                        'local3d', 'graph', 'knn', 'full')
    #   'position_encoding': name of the positional-encoding scheme
    #                        (None = no positional encoding)
    #   'block_size':        processing block length (None = unbounded)
    # NOTE(review): the string values are opaque tags here; their semantics
    # are presumably interpreted by downstream model code -- verify there.
    DEFAULTS = {
        ModalityType.TEXT: {
            'dims': 1,
            'attention_pattern': 'causal',
            'position_encoding': 'rotary',
            'block_size': 4096
        },
        ModalityType.IMAGE: {
            'dims': 2,
            'attention_pattern': 'local',
            'position_encoding': '2d_relative',
            'block_size': 256
        },
        ModalityType.AUDIO: {
            'dims': 1,
            'attention_pattern': 'local',
            'position_encoding': 'rotary',
            'block_size': 8192
        },
        ModalityType.VIDEO: {
            'dims': 3,
            'attention_pattern': 'local3d',
            'position_encoding': '3d_relative',
            'block_size': 512
        },
        ModalityType.GRAPH: {
            'dims': None,
            'attention_pattern': 'graph',
            'position_encoding': 'structure',
            'block_size': None
        },
        ModalityType.POINT_CLOUD: {
            'dims': 3,
            'attention_pattern': 'knn',
            'position_encoding': '3d_absolute',
            'block_size': 1024
        },
        ModalityType.VOXEL: {
            'dims': 3,
            'attention_pattern': 'local3d',
            'position_encoding': '3d_relative',
            'block_size': 64
        },
        ModalityType.LATENT: {
            'dims': 1,
            'attention_pattern': 'full',
            'position_encoding': None,
            'block_size': None
        },
        ModalityType.EMBEDDING: {
            'dims': 1,
            'attention_pattern': 'full',
            'position_encoding': None,
            'block_size': None
        }
    }

    @classmethod
    def get_config(cls, modality: ModalityType) -> Dict[str, Any]:
        """Get configuration for modality

        Returns a fresh copy so callers may mutate the result without
        corrupting DEFAULTS.  The copy is shallow, which is safe here
        because every value is an immutable scalar or None.
        Raises KeyError if *modality* has no entry in DEFAULTS.
        """
        return cls.DEFAULTS[modality].copy()
|
|
|
|
|
|
class ModalityMixer:
    """Handles cross-modal operations.

    Supports four fusion strategies, selected at construction time:
    ``"additive"``, ``"multiplicative"``, ``"concatenative"`` and
    ``"attention"``.  The attention path uses freshly sampled random
    projection matrices on every call (placeholder weights), so it is
    nondeterministic.
    """

    def __init__(self, fusion_type: str = "additive"):
        # One of: "additive", "multiplicative", "concatenative", "attention".
        # Validated lazily in fuse()/unfuse().
        self.fusion_type = fusion_type

    def fuse(
        self,
        x: np.ndarray,
        y: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> np.ndarray:
        """Fuse tensors from different modalities.

        Args:
            x, y: tensors to combine.  For the elementwise strategies their
                shapes must broadcast; for "attention" both need a trailing
                feature axis.
            x_modality, y_modality: modality tags (currently unused by the
                fusion logic itself).

        Returns:
            The fused tensor.

        Raises:
            ValueError: if ``fusion_type`` is not a recognized strategy.
        """
        if self.fusion_type == "additive":
            return x + y
        if self.fusion_type == "multiplicative":
            return x * y
        if self.fusion_type == "concatenative":
            return np.concatenate([x, y], axis=-1)
        if self.fusion_type == "attention":
            # NOTE(review): projection matrices are resampled on every call,
            # so this path is nondeterministic -- placeholder for learned
            # weights.
            d = x.shape[-1]
            q = x @ np.random.randn(d, d)
            k = y @ np.random.randn(y.shape[-1], d)
            v = y @ np.random.randn(y.shape[-1], d)

            # BUGFIX: the original used k.transpose(-2, -1).  For a numpy
            # array that is an axis *permutation*, which on 2-D input is the
            # identity (and errors on >2-D) -- not the intended torch-style
            # swap of the last two axes.  np.swapaxes does the right thing.
            scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(d)

            # Numerically stable softmax: subtract the row max before exp
            # (prevents overflow) and compute exp once instead of twice.
            scores = scores - scores.max(axis=-1, keepdims=True)
            exp_scores = np.exp(scores)
            attn = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
            return attn @ v

        raise ValueError(f"Unknown fusion type: {self.fusion_type}")

    def unfuse(
        self,
        z: np.ndarray,
        x_modality: ModalityType,
        y_modality: ModalityType
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Separate fused tensor back into modalities.

        Only "concatenative" fusion is exactly invertible; the other
        strategies return lossy approximations.

        Raises:
            ValueError: if ``fusion_type`` is not a recognized strategy.
        """
        if self.fusion_type in ("additive", "multiplicative"):
            # Lossy heuristic: split the fused value evenly.
            # NOTE(review): for "multiplicative" an even *additive* split is
            # not a principled inverse of the product -- confirm intent.
            return z / 2, z / 2
        if self.fusion_type == "concatenative":
            # Exact inverse of fuse() when both inputs had equal feature size.
            split_idx = z.shape[-1] // 2
            return z[..., :split_idx], z[..., split_idx:]
        if self.fusion_type == "attention":
            # Placeholder random projections (nondeterministic, see fuse()).
            x_proj = z @ np.random.randn(z.shape[-1], z.shape[-1])
            y_proj = z @ np.random.randn(z.shape[-1], z.shape[-1])
            return x_proj, y_proj

        # BUGFIX: the original fell off the end and implicitly returned None
        # for unknown fusion types; raise instead, consistent with fuse().
        raise ValueError(f"Unknown fusion type: {self.fusion_type}")
|
|
|
|