|
|
import os |
|
|
|
|
|
import einops |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision.transforms import Resize |
|
|
from transformers import ViTImageProcessor, ViTModel, BertModel, ViTConfig, BertConfig |
|
|
|
|
|
from .configuration_aurora import AuroraConfig |
|
|
|
|
|
|
|
|
class VisionEncoder(nn.Module):
    """Frozen ViT backbone whose patch features are distilled into a fixed
    set of learned query tokens through a cross-attention decoder."""

    # Directory holding the ViT config JSON, resolved relative to this file.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.processor = UnifiedImageProcessor(config)

        vit_config = ViTConfig.from_json_file(os.path.join(self.config_path, 'config.json'))
        self.model = ViTModel(vit_config)
        self.model.requires_grad_(False)  # keep the ViT backbone frozen

        self.hidden_size = self.model.config.hidden_size  # ViT feature width
        self.output_dim = config.hidden_size              # Aurora feature width
        self.num_distill = config.num_distill             # number of query tokens

        # Map ViT features into the Aurora hidden space.
        self.projection = nn.Linear(self.hidden_size, self.output_dim)

        # Learned query tokens that attend over the projected patch features.
        self.target_vision_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout_rate,
            batch_first=True,
        )
        self.cross_vision = nn.TransformerDecoder(
            decoder_layer,
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_vision_cross_layers,
        )

    def extract_vit_features(self, image_tensor):
        """
        Run the frozen ViT and split its last hidden state.

        Args:
            image_tensor: Preprocessed pixel tensor [batch_size, 3, H, W].

        Returns:
            cls_feature: [CLS] embedding, [batch_size, hidden_size].
            patch_features: per-patch embeddings, [batch_size, num_patches, hidden_size].
        """
        hidden = self.model(pixel_values=image_tensor).last_hidden_state
        # Position 0 is the [CLS] token; the rest are patch tokens.
        return hidden[:, 0, :], hidden[:, 1:, :]

    def forward(self, x, type='pseudo'):
        """Preprocess `x` (time series or image, selected by `type`), encode it
        with the frozen ViT, and cross-attend the learned query tokens onto
        the projected patch features."""
        pixels = self.processor(x, type=type)
        _, patches = self.extract_vit_features(pixels)
        patches = self.projection(patches)
        # Broadcast the shared query tokens across the batch.
        queries = self.target_vision_tokens.unsqueeze(0).repeat(patches.shape[0], 1, 1)
        return self.cross_vision(queries, patches)
|
|
|
|
|
|
|
|
class UnifiedImageProcessor(nn.Module):
    """Turns either real images or 1-D time series ("pseudo-images") into
    pixel tensors suitable for the ViT encoder."""

    # Directory holding the ViT preprocessor JSON, resolved relative to this file.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vit_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.vit_processor = ViTImageProcessor.from_json_file(
            os.path.join(self.config_path, 'preprocessor_config.json'))
        # Square target resolution taken from the ViT preprocessor config.
        self.target_size = self.vit_processor.size["height"]
        self.pseudo_resizer = Resize((self.target_size, self.target_size))
        # Fallback period when FFT-based period detection is degenerate.
        self.token_len = config.token_len

    def process_real_image(self, images):
        """Run the stock ViT preprocessing (resize, crop, normalize) on real images."""
        return self.vit_processor(images=images, return_tensors="pt")["pixel_values"]

    def _period_search(self, x):
        """Estimate the dominant period of `x` [batch, length] from its FFT peak."""
        spectrum = torch.fft.rfft(x, dim=-1)
        amplitude = abs(spectrum).mean(0)
        amplitude[0] = 0  # ignore the DC component
        _, peak = torch.topk(amplitude, 1)
        peak = peak.detach().cpu().numpy()
        return x.shape[1] // peak

    def process_pseudo_image(self, x):
        """Fold a batch of series [batch, length] into 3-channel square images."""
        length = x.shape[-1]

        period = list(self._period_search(x))[0]
        if not (0 < period < length):
            # Degenerate FFT estimate: fall back to the configured token length.
            period = self.token_len
        period = min(period, length)

        # Left-pad so the length is an exact multiple of the period.
        pad = (period - (length % period)) % period
        padded = F.pad(x, (pad, 0))

        # Stack one period per column, then resize to the ViT input size.
        grid = einops.rearrange(padded, 'b (p f) -> b 1 f p', f=period)
        resized = self.pseudo_resizer(grid)

        # Replicate the single channel to RGB.
        return einops.repeat(resized, 'b 1 h w -> b c h w', c=3)

    def forward(self, x, type='pseudo'):
        """Dispatch on `type`: 'pseudo' folds a time series, anything else is a real image."""
        return self.process_pseudo_image(x) if type == 'pseudo' else self.process_real_image(x)
|
|
|
|
|
|
|
|
class TextEncoder(nn.Module):
    """Frozen BERT backbone whose valid token features are repacked to a fixed
    length and distilled into learned query tokens via cross-attention."""

    # Directory holding the BERT config JSON, resolved relative to this file.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config')

    def __init__(self, config: AuroraConfig):
        super().__init__()
        self.model = BertModel(BertConfig.from_json_file(os.path.join(self.config_path, 'config.json')))
        # The BERT backbone is frozen; only projection/queries/decoder train.
        for param in self.model.parameters():
            param.requires_grad = False
        self.hidden_size = self.model.config.hidden_size  # BERT feature width
        self.output_dim = config.hidden_size              # Aurora feature width
        self.num_distill = config.num_distill             # number of query tokens
        # Fixed length the cleaned token features are padded/truncated to.
        self.max_length = 125

        # Map BERT features into the Aurora hidden space.
        self.projection = nn.Linear(self.hidden_size, self.output_dim)

        # Learned query tokens that attend over the projected text features.
        self.target_text_tokens = nn.Parameter(torch.randn(self.num_distill, self.output_dim))

        self.cross_text = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=config.num_text_cross_layers,
        )

    def extract_bert_features(self, input_dict):
        """Extract BERT features and repack the valid tokens into a fixed shape.

        "Valid" tokens exclude [CLS], the final [SEP], and padding positions.

        Args:
            input_dict: tokenizer output containing at least ``attention_mask``
                (and whatever ``self.model`` needs, e.g. ``input_ids``).

        Returns:
            cls_feature: [batch, hidden_size] — the [CLS] embedding.
            token_features: [batch, seq_len, hidden_size] — raw last hidden states.
            fixed_features: [batch, max_length, hidden_size] — valid tokens
                left-aligned, zero-padded (or truncated) to ``max_length``.
            valid_counts: list[int] — valid-token count per example (before truncation).
        """
        outputs = self.model(**input_dict)
        last_hidden_state = outputs.last_hidden_state
        cls_feature = last_hidden_state[:, 0, :]
        token_features = last_hidden_state

        attention_mask = input_dict["attention_mask"]
        batch_size, _ = attention_mask.shape
        # BUGFIX: start from the attention mask so PADDING positions are
        # excluded. The previous ones_like(...) mask kept padding embeddings,
        # which are generally nonzero and leaked into the "valid" tokens.
        valid_mask = attention_mask.clone()
        valid_mask[:, 0] = 0  # drop [CLS]
        for i in range(batch_size):
            # The last attended position is the final [SEP] token.
            sep_pos = torch.where(attention_mask[i] == 1)[0][-1]
            valid_mask[i, sep_pos] = 0

        fixed_features = torch.zeros(batch_size, self.max_length, self.hidden_size,
                                     device=token_features.device)
        valid_counts = []
        for i in range(batch_size):
            # Select valid tokens directly by mask instead of testing for
            # all-zero feature rows, which is fragile for real embeddings.
            valid_tokens = token_features[i][valid_mask[i].bool()]
            valid_count = valid_tokens.shape[0]
            valid_counts.append(valid_count)
            keep = min(valid_count, self.max_length)
            fixed_features[i, :keep] = valid_tokens[:keep]

        return cls_feature, token_features, fixed_features, valid_counts

    def forward(self, texts):
        """Encode tokenized text and cross-attend the learned query tokens
        onto the projected, fixed-shape token features.

        Args:
            texts: tokenizer output dict (see ``extract_bert_features``).

        Returns:
            Distilled tokens, [batch_size, num_distill, hidden_size].
        """
        _, _, fixed_features, _ = self.extract_bert_features(texts)
        fixed_features = self.projection(fixed_features)

        # Broadcast the shared query tokens across the batch.
        target_text_tokens = self.target_text_tokens.unsqueeze(0).repeat(fixed_features.shape[0], 1, 1)

        output_tokens = self.cross_text(target_text_tokens, fixed_features)
        return output_tokens
|
|
|
|
|
|
|
|
class ModalityConnector(nn.Module):
    """Fuses text and vision context into time-series tokens via cross-attention.

    The time-series tokens act as queries over each modality's features, so
    each output keeps the time-series token count.
    """

    def __init__(self, config: AuroraConfig):
        """
        Args:
            config: AuroraConfig supplying hidden_size, num_attention_heads,
                intermediate_size, dropout_rate, and the per-modality
                connect-layer counts.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # One cross-attention stack per modality (text first to preserve the
        # original parameter-initialization order).
        self.connect_text = self._build_decoder(config, config.num_text_connect_layers)
        self.connect_vision = self._build_decoder(config, config.num_vision_connect_layers)

    @staticmethod
    def _build_decoder(config, num_layers):
        """Build one TransformerDecoder cross-attention stack."""
        return nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout_rate,
                batch_first=True,
            ),
            norm=nn.LayerNorm(config.hidden_size),
            num_layers=num_layers,
        )

    def forward(self, x, text_features, vision_features):
        """
        Cross-attend the time-series tokens over each modality's features.

        Args:
            x: Time-series tokens, [batch_size, n, hidden_size] (the queries).
            text_features: Text features, [batch_size, T, hidden_size], or
                None when the text modality is absent.
            vision_features: Vision features, [batch_size, V, hidden_size], or
                None when the vision modality is absent.

        Returns:
            (from_text, from_vision): each [batch_size, n, hidden_size] — the
            output keeps x's token count — or None for an absent modality.
        """
        if text_features is not None:
            from_text = self.connect_text(x, text_features)
        else:
            from_text = None

        # BUGFIX/generalization: mirror the text path — a missing vision
        # modality previously crashed inside the decoder instead of being
        # skipped like missing text.
        if vision_features is not None:
            from_vision = self.connect_vision(x, vision_features)
        else:
            from_vision = None

        return from_text, from_vision
|
|
|