|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
from .modeling_clipPT import CLIPVisionTransformer |
|
|
from transformers import CLIPImageProcessor |
|
|
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from .modeling_qwen2 import Qwen2Model |
|
|
|
|
|
|
|
|
from .modeling_timer import TimerForPrediction |
|
|
|
|
|
class MulTiCastTimerConfig(PretrainedConfig):
    """Configuration for the MulTiCast Timer model.

    Args:
        forecasting_length: Forecast horizon. When computing the loss, the
            model compares this many leading output steps against the labels.
            Stored exactly as given (``None`` allowed).
        vision_model_name: Name of the vision encoder (the model recognises
            ``'CLIP'``), or ``None`` to disable the vision branch.
        text_model_name: Name of the text encoder (the model recognises
            ``'Qwen'``), or ``None`` to disable the text branch.
        vision_model_prompt_len: Prompt length for the vision encoder;
            falls back to 10 when ``None``.
        text_model_prompt_len: Prompt length for the text encoder;
            falls back to 4 when ``None``.
        timer_prompt_len: Prompt length for the Timer backbone;
            falls back to 4 when ``None``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Horizon and backbone names are stored verbatim.
        self.forecasting_length = forecasting_length
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name

        # Prompt lengths: an explicit value wins, otherwise use the built-in default.
        self.vision_model_prompt_len = 10 if vision_model_prompt_len is None else vision_model_prompt_len
        self.text_model_prompt_len = 4 if text_model_prompt_len is None else text_model_prompt_len
        self.timer_prompt_len = 4 if timer_prompt_len is None else timer_prompt_len
|
|
|
|
|
class MulTiCastTimerModel(PreTrainedModel):
    """Timer forecaster optionally conditioned on image and/or text context.

    A Timer backbone produces the forecast. When enabled in the config, a CLIP
    vision encoder and/or a Qwen2 text encoder each produce a per-sample
    embedding. That embedding is projected to the Timer hidden size and passed
    to the Timer together with the series. In every backbone only parameters
    whose name contains the prompt marker remain trainable. The projection
    layers are left trainable.
    """

    config_class = MulTiCastTimerConfig

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # Construction order (vision, text, timer, projections) is kept fixed so
        # random initialisation consumes the RNG in the same sequence as before.
        self._build_vision_model(config)
        self._build_text_model(config)
        self._build_timer(config)

        # Linear projections from each encoder's hidden size to the Timer's.
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(
                self.vision_model.config.hidden_size, self.timer.config.hidden_size
            )
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(
                self.text_model.config.hidden_size, self.timer.config.hidden_size
            )

    # ------------------------------------------------------------------ #
    # Construction helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _train_only(module, marker):
        """Freeze every parameter of ``module`` except those whose name contains ``marker``."""
        for name, param in module.named_parameters():
            param.requires_grad = marker in name

    def _build_vision_model(self, config):
        """Create the CLIP vision encoder and image processor when configured.

        Raises:
            ValueError: if ``config.vision_model_name`` is neither ``None`` nor ``'CLIP'``.
        """
        if config.vision_model_name is None:
            return
        if config.vision_model_name != 'CLIP':
            raise ValueError(
                f"Unsupported vision_model_name {config.vision_model_name!r}; expected 'CLIP' or None."
            )
        from transformers import AutoModel

        pretrained = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
        # NOTE(review): donor weights are rounded to bfloat16 before loading; kept
        # as-is so existing checkpoints reproduce the same numbers.
        state_dict = {k: v.to(torch.bfloat16) for k, v in pretrained.state_dict().items()}
        self.vision_model = CLIPVisionTransformer(pretrained.config, config.vision_model_prompt_len)
        # strict=False: presumably the prompt parameters have no donor counterpart.
        self.vision_model.load_state_dict(state_dict, strict=False)
        self.processor = CLIPImageProcessor()
        self._train_only(self.vision_model, "encoder.prompts")

    def _build_text_model(self, config):
        """Create the Qwen2 text encoder and tokenizer when configured.

        Raises:
            ValueError: if ``config.text_model_name`` is neither ``None`` nor ``'Qwen'``.
        """
        if config.text_model_name is None:
            return
        if config.text_model_name != 'Qwen':
            raise ValueError(
                f"Unsupported text_model_name {config.text_model_name!r}; expected 'Qwen' or None."
            )
        from transformers import AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        # Only the decoder body (.model) is used; the LM head is discarded.
        pretrained = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2-1.5B-Instruct",
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            attn_implementation="sdpa",
        ).model
        self.text_model = Qwen2Model(pretrained.config, config.text_model_prompt_len)
        self.text_model.load_state_dict(pretrained.state_dict(), strict=False)
        self._train_only(self.text_model, "prompts")

    def _build_timer(self, config):
        """Create the Timer forecasting backbone from the pretrained Hub weights."""
        from transformers import AutoModelForCausalLM

        # NOTE(review): trust_remote_code=True executes code from the Hub repo.
        pretrained = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        self.timer = TimerForPrediction(pretrained.config, config.timer_prompt_len)
        self.timer.load_state_dict(pretrained.state_dict(), strict=False)
        self._train_only(self.timer, "model.prompts")

    # ------------------------------------------------------------------ #
    # Encoding helpers
    # ------------------------------------------------------------------ #

    def _encode_images(self, images, return_details=False):
        """Encode ``images`` into a projected embedding.

        Returns:
            ``(embedding, attentions)``. Both are ``None`` when no vision model
            is configured or ``images`` is ``None``. ``attentions`` is only
            populated when ``return_details`` is true.
        """
        if self.config.vision_model_name is None or images is None:
            return None, None
        extra = {"output_attentions": True} if return_details else {}
        vision_output = self.vision_model(images, **extra)
        embedding = self.vision_interaction_layer(vision_output.pooler_output)
        return embedding, (vision_output.attentions if return_details else None)

    def _encode_texts(self, texts, return_details=False):
        """Encode ``texts`` into a projected embedding.

        Returns:
            ``(embedding, tokens_of_first_text, attentions)``. All are ``None``
            when no text model is configured or ``texts`` is ``None``/all-``None``.
            The last two are only populated when ``return_details`` is true.
        """
        if (
            self.config.text_model_name is None
            or texts is None
            or all(x is None for x in texts)
        ):
            return None, None, None
        tokenized_texts = self.tokenizer(texts, return_tensors="pt")
        text_tokens = None
        if return_details:
            text_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_texts["input_ids"][0])
        # The tokenizer produces CPU tensors; move them next to the model weights.
        tokenized_texts = tokenized_texts.to(self.device)
        extra = {"output_attentions": True} if return_details else {}
        text_output = self.text_model(**tokenized_texts, **extra)
        # NOTE(review): the hidden state at position 0 is used as the text
        # embedding; in a causal decoder that position presumably cannot see the
        # rest of the text — confirm this is intended.
        embedding = self.text_interaction_layer(text_output.last_hidden_state[:, 0, :])
        return embedding, text_tokens, (text_output.attentions if return_details else None)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def predict(self, input_ids = None, images = None, texts = None):
        """Run one inference pass and return logits plus attention maps.

        Args:
            input_ids: Time-series input passed straight to the Timer backbone
                (shape as expected by ``TimerForPrediction`` — not checked here).
            images: Raw image(s) accepted by ``CLIPImageProcessor``. Only the first
                processed image is used, as a batch of one. Ignored when no
                vision model is configured.
            texts: Text(s) for the tokenizer. Ignored when no text model is
                configured, or when ``None`` / all entries are ``None``.

        Returns:
            dict with ``logits``, ``vision_attentions``, ``text_tokens`` (tokens of
            the first text), ``text_attentions`` and ``time_series_attentions``.
            Entries for unused modalities are ``None``.
        """
        if self.config.vision_model_name is not None and images is not None:
            pixel_values = self.processor.preprocess(images)['pixel_values'][0]
            images = torch.tensor(pixel_values).unsqueeze(0).to(self.device)

        vision_embedding, vision_attentions = self._encode_images(images, return_details=True)
        text_embedding, text_tokens, text_attentions = self._encode_texts(texts, return_details=True)

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)

        return {
            "logits": out.logits,
            "vision_attentions": vision_attentions,
            "text_tokens": text_tokens,
            "text_attentions": text_attentions,
            "time_series_attentions": out.attentions,
        }

    def forward(self, input_ids = None, images = None, texts = None, labels = None):
        """Forecast and, when ``labels`` are given, compute an MSE loss.

        Args:
            input_ids: Time-series input for the Timer backbone.
            images: Preprocessed pixel tensor for the vision encoder. Ignored when
                no vision model is configured or when ``None``.
            texts: Text(s) for the tokenizer. Ignored when no text model is
                configured, or when ``None`` / all entries are ``None``.
            labels: Target values. When the output is wider than
                ``config.forecasting_length``, only the first
                ``forecasting_length`` output columns are compared.

        Returns:
            dict with ``loss`` (``None`` without labels) and the full ``logits``.
        """
        vision_embedding, _ = self._encode_images(images)
        text_embedding, _, _ = self._encode_texts(texts)

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        out = out["logits"]

        loss = None
        if labels is not None:
            if self.config.forecasting_length == out.shape[-1]:
                prediction = out
            else:
                prediction = out[:, :self.config.forecasting_length]
            loss = torch.mean(torch.square(prediction - labels))

        return {
            "loss": loss,
            "logits": out,
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the model from its config and overlay weights from ``model.safetensors``.

        The backbones are first initialised from their own pretrained sources
        (in ``__init__``). The saved state dict is then loaded non-strictly on
        top. Unexpected keys are logged rather than silently dropped.

        ``kwargs`` is accepted for signature compatibility and is currently ignored.
        """
        import logging

        from transformers.utils import cached_file

        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
        state_dict = load_file(resolved_file)
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.unexpected_keys:
            logging.getLogger(__name__).warning(
                "Ignored %d unexpected keys while loading %s: %s",
                len(incompatible.unexpected_keys),
                pretrained_model_name_or_path,
                incompatible.unexpected_keys,
            )

        return model