|
|
from typing import Dict, List, Optional, Union, Any, Tuple
|
|
|
import numpy as np
|
|
|
from .modality import ModalityType, ModalityConfig, ModalityMixer
|
|
|
|
|
|
class TensorOps:
    """Hardware-accelerated tensor operations with modality support.

    Each public operation dispatches on an optional ``ModalityType``. The
    convolution, pooling and local-attention kernels are placeholders that
    currently raise ``NotImplementedError``.
    """

    def __init__(self, device: Optional[str] = None):
        """Initialise the operation set.

        Args:
            device: Optional device identifier. It is stored as ``self.device``
                but no method in this class reads it yet.
        """
        self.device = device
        # Combines operands whose modalities differ (see ``matmul``).
        self.modality_mixer = ModalityMixer()

    def matmul(
        self,
        x: np.ndarray,
        y: np.ndarray,
        x_modality: Optional[ModalityType] = None,
        y_modality: Optional[ModalityType] = None,
    ) -> np.ndarray:
        """Matrix multiplication with modality handling.

        Args:
            x: Left operand.
            y: Right operand.
            x_modality: Modality of ``x``, or ``None``.
            y_modality: Modality of ``y``, or ``None``.

        Returns:
            ``x @ y`` when both modalities are equal (including both ``None``);
            otherwise the result of
            ``self.modality_mixer.fuse(x, y, x_modality, y_modality)``.
        """
        # ``None == None`` is True, so this single test also covers the
        # "no modality given on either side" case.
        if x_modality == y_modality:
            return x @ y
        return self.modality_mixer.fuse(x, y, x_modality, y_modality)

    def conv(
        self,
        x: np.ndarray,
        weight: np.ndarray,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 0,
        modality: Optional[ModalityType] = None,
    ) -> np.ndarray:
        """Convolution whose dimensionality is chosen by modality.

        AUDIO uses the 1-D kernel, VIDEO the 3-D kernel, and IMAGE, any other
        modality, or ``None`` the 2-D kernel.

        Raises:
            NotImplementedError: all three kernels are still stubs.
        """
        if modality == ModalityType.AUDIO:
            return self._conv1d(x, weight, stride, padding)
        if modality == ModalityType.VIDEO:
            return self._conv3d(x, weight, stride, padding)
        return self._conv2d(x, weight, stride, padding)

    def attention(
        self,
        q: np.ndarray,
        k: np.ndarray,
        v: np.ndarray,
        modality: Optional[ModalityType] = None,
        mask: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Scaled dot-product attention with a modality-specific pattern.

        The pattern is ``ModalityConfig.get_config(modality)['attention_pattern']``,
        or ``'full'`` when no modality (or an empty config) is given:

        * ``'causal'``: each query attends only to keys at the same or earlier
          positions. A caller-supplied ``mask`` replaces the causal mask; its
          truthy entries mark positions to EXCLUDE.
        * ``'local'`` / ``'local3d'``: delegated to the windowed helpers with
          the config's ``block_size`` as window size.
        * anything else: full attention. A caller-supplied ``mask`` excludes
          positions where it equals 0.

        NOTE(review): the two mask conventions above are opposite; they are
        preserved from the original behaviour — confirm with callers.

        Args:
            q, k, v: Query, key and value arrays with the sequence axis second
                to last and the feature axis last; leading batch/head axes
                broadcast through ``@``.
            modality: Selects the attention pattern via ``ModalityConfig``.
            mask: Optional mask broadcastable to ``(..., q_len, k_len)``.

        Returns:
            ``softmax(q kᵀ / sqrt(d_k)) @ v`` for the dense patterns, or the
            helper's result for the local patterns.
        """
        config = ModalityConfig.get_config(modality) if modality is not None else None
        pattern = config['attention_pattern'] if config else 'full'

        if pattern == 'causal':
            if mask is None:
                # Upper triangle above the diagonal = future keys to exclude.
                mask = np.triu(np.ones((q.shape[-2], k.shape[-2]), dtype=bool), k=1)
            return self._scaled_dot_product(q, k, v, np.asarray(mask, dtype=bool))

        if pattern == 'local':
            window_size = config['block_size'] if config else 8
            return self._local_attention(q, k, v, window_size)

        if pattern == 'local3d':
            window_size = config['block_size'] if config else 4
            return self._local_attention_3d(q, k, v, window_size)

        blocked = None if mask is None else (np.asarray(mask) == 0)
        return self._scaled_dot_product(q, k, v, blocked)

    def _scaled_dot_product(
        self,
        q: np.ndarray,
        k: np.ndarray,
        v: np.ndarray,
        blocked: Optional[np.ndarray],
    ) -> np.ndarray:
        """Return ``softmax(q kᵀ / sqrt(d_k)) @ v``, skipping ``blocked`` positions."""
        # swapaxes transposes the last two axes for any rank; numpy's
        # ``ndarray.transpose(-2, -1)`` is a no-op on 2-D arrays and raises
        # on arrays with 3+ dimensions.
        scores = (q @ np.swapaxes(k, -1, -2)) / np.sqrt(k.shape[-1])
        if blocked is not None:
            # Broadcasts a (q_len, k_len) mask across leading batch/head axes.
            scores = np.where(blocked, -np.inf, scores)
        return self._softmax(scores) @ v

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis.

        ``-inf`` entries receive zero weight. A row that is entirely ``-inf``
        yields NaN (0/0), matching the plain ``exp / sum`` formula.
        """
        # ``initial`` keeps the reduction valid when the key axis is empty.
        row_max = np.max(scores, axis=-1, keepdims=True, initial=-np.inf)
        # Shift fully-masked rows by 0 rather than computing (-inf) - (-inf).
        row_max = np.where(np.isfinite(row_max), row_max, 0.0)
        weights = np.exp(scores - row_max)
        return weights / weights.sum(axis=-1, keepdims=True)

    def pool(
        self,
        x: np.ndarray,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Optional[Union[int, Tuple[int, ...]]] = None,
        mode: str = 'max',
        modality: Optional[ModalityType] = None,
    ) -> np.ndarray:
        """Pooling whose dimensionality is chosen by modality.

        AUDIO uses the 1-D kernel, VIDEO the 3-D kernel, and IMAGE, any other
        modality, or ``None`` the 2-D kernel.

        Raises:
            NotImplementedError: all three kernels are still stubs.
        """
        if modality == ModalityType.AUDIO:
            return self._pool1d(x, kernel_size, stride, mode)
        if modality == ModalityType.VIDEO:
            return self._pool3d(x, kernel_size, stride, mode)
        return self._pool2d(x, kernel_size, stride, mode)

    def normalize(
        self,
        x: np.ndarray,
        modality: Optional[ModalityType] = None,
        eps: float = 1e-5,
    ) -> np.ndarray:
        """Zero-mean / unit-std normalisation over modality-specific axes.

        * IMAGE and VIDEO: axes ``(2, 3)``.
        * AUDIO: axis ``2``.
        * TEXT, any other modality, or ``None``: the last axis.

        NOTE(review): presumably IMAGE is (batch, channel, H, W); for a
        5-D VIDEO tensor, axes (2, 3) would leave the last axis out — confirm
        the expected video layout.

        Args:
            x: Input array.
            modality: Selects the reduction axes.
            eps: Added to the standard deviation to avoid division by zero.
        """
        if modality in (ModalityType.IMAGE, ModalityType.VIDEO):
            axes = (2, 3)
        elif modality == ModalityType.AUDIO:
            axes = 2
        else:
            axes = -1
        mean = x.mean(axis=axes, keepdims=True)
        std = x.std(axis=axes, keepdims=True)
        return (x - mean) / (std + eps)

    def _conv1d(self, x: np.ndarray, weight: np.ndarray, stride: int = 1, padding: int = 0) -> np.ndarray:
        """1D convolution implementation (not yet implemented)."""
        raise NotImplementedError

    def _conv2d(self, x: np.ndarray, weight: np.ndarray, stride: Union[int, Tuple[int, int]] = 1,
                padding: Union[int, Tuple[int, int]] = 0) -> np.ndarray:
        """2D convolution implementation (not yet implemented)."""
        raise NotImplementedError

    def _conv3d(self, x: np.ndarray, weight: np.ndarray, stride: Union[int, Tuple[int, int, int]] = 1,
                padding: Union[int, Tuple[int, int, int]] = 0) -> np.ndarray:
        """3D convolution implementation (not yet implemented)."""
        raise NotImplementedError

    def _pool1d(self, x: np.ndarray, kernel_size: int, stride: Optional[int] = None,
                mode: str = 'max') -> np.ndarray:
        """1D pooling implementation (not yet implemented)."""
        raise NotImplementedError

    def _pool2d(self, x: np.ndarray, kernel_size: Union[int, Tuple[int, int]],
                stride: Optional[Union[int, Tuple[int, int]]] = None, mode: str = 'max') -> np.ndarray:
        """2D pooling implementation (not yet implemented)."""
        raise NotImplementedError

    def _pool3d(self, x: np.ndarray, kernel_size: Union[int, Tuple[int, int, int]],
                stride: Optional[Union[int, Tuple[int, int, int]]] = None, mode: str = 'max') -> np.ndarray:
        """3D pooling implementation (not yet implemented)."""
        raise NotImplementedError

    def _local_attention(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, window_size: int) -> np.ndarray:
        """Local (windowed) attention implementation (not yet implemented)."""
        raise NotImplementedError

    def _local_attention_3d(self, q: np.ndarray, k: np.ndarray, v: np.ndarray, window_size: int) -> np.ndarray:
        """3D local attention implementation (not yet implemented)."""
        raise NotImplementedError
|
|
|
|