# Aurora / modality_connector.py
# Author: ccloud0525
# feat: "first commit" (commit b40a476)
import os
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Resize
from transformers import ViTImageProcessor, ViTModel, BertModel, ViTConfig, BertConfig
from .configuration_aurora import AuroraConfig
class VisionEncoder(nn.Module):
    """Encode real or pseudo images into a fixed number of distilled tokens.

    A frozen ViT backbone produces patch features; these are projected to the
    model width and distilled into ``config.num_distill`` learnable query
    tokens through a stack of cross-attention (TransformerDecoder) layers.
    """

    # Directory holding the pretrained ViT configuration shipped alongside this module.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.processor = UnifiedImageProcessor(config)
        self.model = ViTModel(ViTConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
        # The ViT backbone stays frozen; only the projection, the query tokens
        # and the cross-attention stack receive gradients.
        for weight in self.model.parameters():
            weight.requires_grad = False
        self.hidden_size = self.model.config.hidden_size
        self.output_dim = config.hidden_size
        self.num_distill = config.num_distill
        self.projection = nn.Linear(self.hidden_size, self.output_dim)
        # Learnable query tokens, shape [num_distill, output_dim].
        self.target_vision_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))
        # Cross-attention stack: the query tokens attend to the projected patch features.
        self.cross_vision = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_vision_cross_layers,
        )

    def extract_vit_features(self, image_tensor):
        """Run the frozen ViT on a batch of images.

        Args:
            image_tensor: Preprocessed pixels, shape [batch_size, 3, H, W].

        Returns:
            cls_feature: [CLS] embedding, shape [batch_size, hidden_size].
            patch_features: Patch embeddings, shape [batch_size, num_patches, hidden_size].
        """
        hidden = self.model(pixel_values=image_tensor).last_hidden_state
        # Position 0 holds the [CLS] token; the remainder are patch embeddings.
        cls_feature = hidden[:, 0, :]
        patch_features = hidden[:, 1:, :]
        return cls_feature, patch_features

    def forward(self, x, type='pseudo'):
        """Distill an image (or a time series rendered as one) into query tokens.

        Args:
            x: Raw input handed to ``UnifiedImageProcessor`` — a time-series
               batch when ``type == 'pseudo'``, otherwise real image data.
            type: 'pseudo' or any other value for real images.

        Returns:
            Distilled tokens, shape [batch_size, num_distill, hidden_size].
        """
        pixels = self.processor(x, type=type)
        _, patches = self.extract_vit_features(pixels)
        patches = self.projection(patches)
        batch = patches.shape[0]
        queries = self.target_vision_tokens.unsqueeze(0).repeat(batch, 1, 1)
        return self.cross_vision(queries, patches)
class UnifiedImageProcessor(nn.Module):
    """Normalize real images and time-series pseudo-images to the ViT input format.

    Real images go through the pretrained ``ViTImageProcessor`` pipeline;
    time series are folded into 2-D single-channel images along their dominant
    FFT period, resized to the ViT resolution, and replicated to 3 channels.
    """

    # Directory holding the pretrained ViT preprocessor configuration.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        # Load the ViT preprocessor to reuse the pretrained normalization
        # parameters and target resolution.
        self.vit_processor = ViTImageProcessor.from_json_file(os.path.join(self.config_path, 'preprocessor_config.json'))
        self.target_size = self.vit_processor.size["height"]  # e.g. 224, the default ViT input size
        # Resizer for pseudo-images so they match the real-image target size.
        self.pseudo_resizer = Resize((self.target_size, self.target_size))
        # Fallback period when no dominant frequency can be found.
        self.token_len = config.token_len

    def process_real_image(self, images):
        """Process real images: resize, crop, and normalize via the ViT pipeline."""
        inputs = self.vit_processor(images=images, return_tensors="pt")
        return inputs["pixel_values"]  # [batch_size, 3, H, W]

    def _period_search(self, x):
        """Return the dominant period of ``x`` (shape [batch, length]) as an int.

        Uses the frequency with the largest batch-averaged amplitude; returns
        0 when the spectrum is degenerate (e.g. a constant signal), letting
        the caller fall back to ``token_len``.
        """
        xf = torch.fft.rfft(x, dim=-1)
        # Average amplitudes over the batch and drop the DC component.
        amplitudes = xf.abs().mean(0)
        amplitudes[0] = 0
        top_freq = int(torch.argmax(amplitudes).item())
        if top_freq == 0:
            # All-zero spectrum: no meaningful period exists.
            return 0
        return x.shape[-1] // top_freq

    def process_pseudo_image(self, x):
        """Render a time-series batch [batch, length] as ViT-sized RGB images."""
        input_length = x.shape[-1]
        period = self._period_search(x)
        # Fall back to the configured token length when the detected period is
        # degenerate or at least as long as the series itself.
        if not 0 < period < input_length:
            period = min(self.token_len, input_length)
        # Left-pad so the length becomes an exact multiple of the period.
        padding_length = (-input_length) % period
        x_pad = F.pad(x, (padding_length, 0))
        # Fold the 1-D series into a [period x num_segments] one-channel image.
        x_2d = einops.rearrange(x_pad, 'b (p f) -> b 1 f p', f=period)
        # Render & align: resize to the ViT resolution, replicate to RGB.
        x_resize = self.pseudo_resizer(x_2d)
        image_input = einops.repeat(x_resize, 'b 1 h w -> b c h w', c=3)
        return image_input

    def forward(self, x, type='pseudo'):
        """Dispatch to the pseudo-image or real-image pipeline."""
        if type == 'pseudo':
            return self.process_pseudo_image(x)
        else:
            return self.process_real_image(x)
class TextEncoder(nn.Module):
    """Encode tokenized text into a fixed number of distilled tokens.

    A frozen BERT backbone embeds the text; [CLS], [SEP] and padding tokens
    are stripped, the remaining content tokens are packed into a fixed-length
    tensor, projected to the model width, and distilled into
    ``config.num_distill`` learnable query tokens via cross-attention.
    """

    # Directory holding the pretrained BERT configuration shipped alongside this module.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.model = BertModel(BertConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
        # The BERT backbone stays frozen; only the projection, query tokens
        # and cross-attention stack are trained.
        for param in self.model.parameters():
            param.requires_grad = False
        self.hidden_size = self.model.config.hidden_size
        self.output_dim = config.hidden_size
        self.num_distill = config.num_distill
        # Fixed number of content tokens kept per sample after cleaning.
        self.max_length = 125
        self.projection = nn.Linear(self.hidden_size, self.output_dim)
        # Define learnable target tokens (shape: [num_distill_tokens, hidden_size])
        self.target_text_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))
        self.cross_text = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_text_cross_layers,
        )

    def extract_bert_features(self, input_dict):
        """Extract and clean BERT features with a fixed output shape.

        Args:
            input_dict: Tokenizer output with at least ``input_ids`` and
                ``attention_mask`` ([batch_size, seq_len]).

        Returns:
            cls_feature: [CLS] embedding, [batch_size, hidden_size].
            token_features: Raw hidden states, [batch_size, seq_len, hidden_size].
            fixed_features: Content tokens packed/zero-padded to
                [batch_size, max_length, hidden_size].
            valid_counts: Number of content tokens per sample (pre-truncation).
        """
        outputs = self.model(**input_dict)
        last_hidden_state = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        cls_feature = last_hidden_state[:, 0, :]  # [batch_size, hidden_size]
        token_features = last_hidden_state
        # Build a mask excluding [CLS], [SEP], and padding tokens.
        attention_mask = input_dict["attention_mask"]  # [batch_size, seq_len]
        batch_size, seq_len = attention_mask.shape
        # BUG FIX: start from the attention mask (ones_like kept padding
        # positions set to 1, so padded BERT outputs leaked into the result).
        valid_mask = attention_mask.clone()
        valid_mask[:, 0] = 0  # Exclude [CLS]
        # The last attended position in each sample is [SEP].
        sep_positions = attention_mask.sum(dim=1).long() - 1
        valid_mask[torch.arange(batch_size, device=valid_mask.device), sep_positions] = 0
        # Pack valid tokens into a fixed shape [batch_size, max_length, hidden_size].
        fixed_features = torch.zeros(batch_size, self.max_length, self.hidden_size,
                                     device=token_features.device, dtype=token_features.dtype)
        valid_counts = []
        for i in range(batch_size):
            # Select content tokens directly by the mask (more robust than
            # filtering on a zero feature sum).
            valid_tokens = token_features[i][valid_mask[i].bool()]
            valid_count = valid_tokens.shape[0]
            valid_counts.append(valid_count)
            # Truncate if longer than max_length, else pad with zeros.
            kept = min(valid_count, self.max_length)
            fixed_features[i, :kept] = valid_tokens[:self.max_length]
        return cls_feature, token_features, fixed_features, valid_counts

    def forward(self, texts):
        """Distill fixed-shape text features into [batch_size, num_distill, hidden_size]."""
        _, _, fixed_features, _ = self.extract_bert_features(texts)
        fixed_features = self.projection(fixed_features)
        target_text_tokens = self.target_text_tokens.unsqueeze(0).repeat(fixed_features.shape[0], 1, 1)
        output_tokens = self.cross_text(target_text_tokens, fixed_features)
        return output_tokens
class ModalityConnector(nn.Module):
    """Align text and vision token sets with the time-series representation.

    Two independent TransformerDecoder stacks let the time-series tokens
    (as queries) cross-attend to the text features and the vision features,
    producing one aligned token set per modality.
    """

    def __init__(self, config: AuroraConfig):
        """Build the text and vision cross-attention stacks from ``config``."""
        super().__init__()
        self.hidden_size = config.hidden_size

        def _build_decoder(num_layers):
            # One cross-attention stack; all layers share the config hyperparameters.
            layer = nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            )
            return nn.TransformerDecoder(
                layer,
                norm=nn.LayerNorm(config.hidden_size),
                num_layers=num_layers,
            )

        self.connect_text = _build_decoder(config.num_text_connect_layers)
        self.connect_vision = _build_decoder(config.num_vision_connect_layers)

    def forward(self, x, text_features, vision_features):
        """Cross-attend the time-series tokens over each modality.

        Args:
            x: Time-series tokens, [batch_size, n, hidden_size].
            text_features: Text tokens, [batch_size, T, hidden_size], or None
                when no text is available.
            vision_features: Vision tokens, [batch_size, V, hidden_size].

        Returns:
            (from_text, from_vision): each [batch_size, n, hidden_size];
            from_text is None when text_features is None.
        """
        from_text = None
        if text_features is not None:
            from_text = self.connect_text(x, text_features)
        from_vision = self.connect_vision(x, vision_features)
        return from_text, from_vision