| |
| |
| """ |
| Time Language Model (TLM) for inference. |
| A multimodal model that combines time-series data with a language model for time-series question answering. |
| """ |
| import os |
| import json |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from safetensors.torch import load_file |
| from models.TimeSeriesEncoder import Model |
| from models.ITFormer import ITFormer |
| from models.QFormerAdapter import QFormerAdapter |
| from accelerate import Accelerator |
|
|
# Module-level Accelerate handle. In this module it is only consulted via
# ``accelerator.is_main_process`` so that log output is limited to the main process.
accelerator = Accelerator()
|
|
# Substrings that identify PEFT LoRA tensors inside state-dict keys. A key under
# ``llm_model.`` that contains any of these is treated as a trainable LoRA weight;
# every other ``llm_model.`` key belongs to the frozen base LLM.
LORA_STATE_MARKERS = tuple(
    f".lora_{suffix}."
    for suffix in ("A", "B", "embedding_A", "embedding_B")
)
|
|
|
|
class TLMConfig(PretrainedConfig):
    """Configuration for the Time Language Model (TLM).

    Stores where and how to load the language model, the number of time-series
    placeholder tokens, and the optional LoRA fine-tuning settings. Any extra
    keyword arguments are handed to ``PretrainedConfig``.
    """

    model_type = "vlm_model"

    def __init__(self, llm_model_path='LLM/Qwen2.5-0.5B-Instruct',
                 freeze_ts_model=True,
                 ts_pad_num=25,
                 llm_attn_implementation=None,
                 llm_torch_dtype=None,
                 use_lora=False,
                 lora_r=16,
                 lora_alpha=32,
                 lora_dropout=0.05,
                 lora_target_modules=None,
                 gradient_checkpointing=False,
                 **kwargs):
        """Create a TLM configuration.

        Args:
            llm_model_path: Path (or hub id) of the causal language model to load.
            freeze_ts_model: If true, the time-series encoder's parameters are frozen.
            ts_pad_num: Number of time-series placeholder tokens.
            llm_attn_implementation: Optional attention implementation name passed to
                the LLM loader (e.g. ``"sdpa"`` or ``"flash_attention_2"``).
            llm_torch_dtype: Optional dtype name for the LLM weights
                (``"float16"``/``"fp16"``, ``"bfloat16"``/``"bf16"``, ``"float32"``/``"fp32"``).
            use_lora: Whether to wrap the LLM with LoRA adapters.
            lora_r: LoRA rank.
            lora_alpha: LoRA scaling factor.
            lora_dropout: Dropout applied inside the LoRA layers.
            lora_target_modules: Module names to adapt, as a list or a comma-separated
                string. Any falsy value (``None``, ``[]``, ``""``) selects the default
                projection names listed below.
            gradient_checkpointing: Enable gradient checkpointing on the LoRA-wrapped LLM.
            **kwargs: Forwarded to ``PretrainedConfig``.
        """
        # Language model and how it is loaded.
        self.llm_model_path = llm_model_path
        self.freeze_ts_model = freeze_ts_model
        self.ts_pad_num = ts_pad_num
        self.llm_attn_implementation = llm_attn_implementation
        self.llm_torch_dtype = llm_torch_dtype

        # LoRA fine-tuning options.
        self.use_lora = use_lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        if lora_target_modules:
            self.lora_target_modules = lora_target_modules
        else:
            # Default targets: the q/k/v/o and gate/up/down projection module names.
            self.lora_target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
        self.gradient_checkpointing = gradient_checkpointing

        super().__init__(**kwargs)
|
|
|
|
class TLM(PreTrainedModel, GenerationMixin):
    """Time Language Model (TLM): a causal LLM conditioned on time-series features.

    ``_build_inputs_embeds`` implements the fusion:

    - the time-series encoder output is projected to ``it_d_model`` by ``ts_project``;
    - the query-token embeddings are projected to ``it_d_model`` by ``query_project``;
    - the adapter (ITFormer or Q-Former) fuses the two;
    - ``fusion_project`` maps the result to the LLM hidden size;
    - the fused vectors overwrite the embeddings of the ``<|image_pad|>`` placeholder
      tokens in ``input_ids``.

    Checkpoints hold only the non-LLM weights plus optional LoRA matrices. The base
    LLM is always reloaded from ``config.llm_model_path``.
    """

    config_class = TLMConfig

    # Placeholder token whose embedding slots receive the fused time-series features.
    TS_PAD_TOKEN = '<|image_pad|>'

    # Accepted spellings of ``config.llm_torch_dtype`` (matched case-insensitively).
    _LLM_DTYPES = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }

    @staticmethod
    def _is_lora_key(key):
        """Return True if state-dict ``key`` names a LoRA tensor (see ``LORA_STATE_MARKERS``)."""
        return any(marker in key for marker in LORA_STATE_MARKERS)

    def state_dict(self, *args, **kwargs):
        """Return checkpoint weights without the frozen base LLM.

        The frozen base LLM weights are reloaded from ``config.llm_model_path``, so
        under ``llm_model.*`` only the trainable LoRA matrices are kept. All other
        keys are returned unchanged.
        """
        full_state = super().state_dict(*args, **kwargs)
        return {
            key: value
            for key, value in full_state.items()
            if not key.startswith("llm_model.") or self._is_lora_key(key)
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, **kwargs):
        """Build a TLM and load its trainable weights from a local checkpoint directory.

        Args:
            pretrained_model_name_or_path: Local checkpoint directory.
            config: Optional configuration. When omitted, ``config.json`` in the
                checkpoint directory is used if present, otherwise a default config.
            **kwargs: Forwarded to ``__init__`` (e.g. ``ts_config``).

        Returns:
            TLM: The constructed model. If no weight file is found, a warning is printed
            and the freshly initialised model is returned.

        Raises:
            ValueError: If the path does not exist, or if LoRA tensors in the checkpoint
                and in the model do not match.
        """
        if not os.path.exists(pretrained_model_name_or_path):
            raise ValueError(f"Checkpoint path does not exist: {pretrained_model_name_or_path}")

        # An explicitly passed config takes precedence over config.json.
        if config is None:
            config_path = os.path.join(pretrained_model_name_or_path, "config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = cls.config_class(**json.load(f))
            else:
                config = cls.config_class()

        model = cls(config, **kwargs)

        state_dict, model_path = cls._read_checkpoint_state_dict(pretrained_model_name_or_path)
        if state_dict is None:
            if accelerator.is_main_process:
                print(f"Warning: No model weights found at {model_path} or in split safetensors.")
            return model

        model._load_trainable_state_dict(state_dict)
        return model

    @staticmethod
    def _read_checkpoint_state_dict(checkpoint_dir):
        """Read the raw state dict stored in ``checkpoint_dir``.

        Files are probed in this order:

        1. ``pytorch_model.bin``;
        2. ``model.safetensors``;
        3. sharded ``model-*.safetensors`` files, merged in sorted filename order.

        Returns:
            tuple: ``(state_dict, model_path)``. ``state_dict`` is None when nothing was
            found. ``model_path`` is the last single-file path probed, which the caller
            uses in its warning message.
        """
        model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if not os.path.exists(model_path):
            model_path = os.path.join(checkpoint_dir, "model.safetensors")

        if os.path.exists(model_path):
            if accelerator.is_main_process:
                print(f"Loading model weights from: {model_path}")
            if model_path.endswith('.safetensors'):
                return load_file(model_path), model_path
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            return torch.load(model_path, map_location='cpu'), model_path

        shard_files = sorted(
            name for name in os.listdir(checkpoint_dir)
            if name.startswith('model-') and name.endswith('.safetensors')
        )
        if not shard_files:
            return None, model_path

        if accelerator.is_main_process:
            print(f"Loading split safetensors from: {checkpoint_dir}")
        merged = {}
        for name in shard_files:
            merged.update(load_file(os.path.join(checkpoint_dir, name)))
        if accelerator.is_main_process:
            print(f"Successfully loaded {len(shard_files)} split safetensors files.")
        return merged, model_path

    def _load_trainable_state_dict(self, state_dict):
        """Load every non-base-LLM tensor of ``state_dict`` into this model.

        Keys under ``llm_model.`` that are not LoRA tensors are discarded, because the
        base LLM was already loaded in ``__init__``. The rest is loaded with
        ``strict=False``.

        Raises:
            ValueError: If the checkpoint has LoRA tensors but the model has none, if
                ``config.use_lora`` is set but the checkpoint has no LoRA tensors, or if
                any LoRA parameter of the model is missing from the checkpoint.
        """
        ignored_llm_count = 0
        other_weights = {}
        for key, value in state_dict.items():
            if key.startswith('llm_model.') and not self._is_lora_key(key):
                ignored_llm_count += 1
            else:
                other_weights[key] = value

        lora_count = sum(1 for key in other_weights if self._is_lora_key(key))
        if accelerator.is_main_process:
            print(f"Found {ignored_llm_count} frozen LLM weights (will be ignored)")
            print(f"Found {lora_count} LoRA tensors")
            print(f"Found {len(other_weights) - lora_count} non-LLM tensors")

        checkpoint_has_lora = lora_count > 0
        model_has_lora = any("lora_" in name for name, _ in self.llm_model.named_parameters())
        if checkpoint_has_lora and not model_has_lora:
            raise ValueError(
                "The checkpoint contains LoRA matrices, but the model was "
                "constructed with use_lora=False."
            )
        if getattr(self.config, "use_lora", False) and not checkpoint_has_lora:
            raise ValueError(
                "The model was constructed with use_lora=True, but the "
                "checkpoint does not contain LoRA matrices."
            )

        missing_keys, unexpected_keys = self.load_state_dict(other_weights, strict=False)

        llm_missing_keys = [
            k for k in missing_keys
            if k.startswith('llm_model.') and not self._is_lora_key(k)
        ]
        non_llm_missing_keys = [k for k in missing_keys if not k.startswith('llm_model.')]
        missing_lora_keys = [k for k in missing_keys if self._is_lora_key(k)]

        if llm_missing_keys and accelerator.is_main_process:
            print(f"LLM missing keys (ignored): {len(llm_missing_keys)} keys")
        if missing_lora_keys:
            raise ValueError(f"Missing LoRA checkpoint keys: {missing_lora_keys}")
        if non_llm_missing_keys and accelerator.is_main_process:
            print(f"Non-LLM missing keys: {non_llm_missing_keys}")
        if unexpected_keys and accelerator.is_main_process:
            print(f"Unexpected keys: {unexpected_keys}")

    def __init__(self, config, ts_config=None):
        """Build the LLM, time-series encoder, adapter and projection layers.

        Args:
            config: TLM configuration.
            ts_config: Optional attribute-style time-series configuration (e.g. parsed
                args). A built-in default is used when omitted.

                NOTE: this object is mutated:

                - ``prefix_num`` and ``ts_pad_num`` are synchronised;
                - ``llm_d_model`` is set to the LLM hidden size.

        Raises:
            ValueError: For an unsupported ``llm_torch_dtype`` or ``adapter_type``.
            Exception: Any error from loading the LLM or tokenizer is logged and re-raised.
        """
        super().__init__(config)
        self.config = config

        if ts_config is None:

            class DefaultTSConfig:
                def __init__(self):
                    self.model = 'TimeSeriesEncoder'
                    self.d_model = 512
                    self.n_heads = 8
                    self.e_layers = 4
                    self.patch_len = 60
                    self.stride = 60
                    self.input_len = 600
                    self.dropout = 0.1
                    self.it_d_model = 896
                    self.it_n_heads = 16
                    self.it_layers = 2
                    self.it_dropout = 0.1
                    self.prefix_num = 25
                    self.adapter_type = 'itformer'

            ts_config = DefaultTSConfig()

        self.ts_config = ts_config

        # ``ts_pad_num`` and ``prefix_num`` are two names for the same quantity; fill in
        # whichever one is missing from the other.
        if hasattr(ts_config, 'ts_pad_num') and not hasattr(ts_config, 'prefix_num'):
            ts_config.prefix_num = ts_config.ts_pad_num
        elif hasattr(ts_config, 'prefix_num') and not hasattr(ts_config, 'ts_pad_num'):
            ts_config.ts_pad_num = ts_config.prefix_num

        # Validate loader options before the try so config errors are not reported
        # as LLM loading failures.
        llm_load_kwargs = self._llm_load_kwargs()
        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_path,
                **llm_load_kwargs,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)
        except Exception as e:
            if accelerator.is_main_process:
                print(f"❌ Failed to load LLM model from {self.config.llm_model_path}: {e}")
            raise
        if accelerator.is_main_process:
            print(f"✅ Loaded LLM model from: {self.config.llm_model_path}")

        self.llm_model.config.pad_token_id = self.tokenizer.pad_token_id

        self._configure_lora()

        ts_config.llm_d_model = self.llm_model.config.hidden_size

        self.ts_encoder = Model(ts_config)
        self._load_ts_encoder_weights(getattr(ts_config, 'load_ts_encoder', None))

        # A missing or None adapter_type falls back to the ITFormer adapter.
        adapter_type = (getattr(ts_config, 'adapter_type', None) or 'itformer').lower()
        if adapter_type == 'itformer':
            self.itformer = ITFormer(ts_config)
        elif adapter_type == 'qformer':
            self.itformer = QFormerAdapter(ts_config)
        else:
            raise ValueError(f"Unsupported adapter_type: {adapter_type}")
        if accelerator.is_main_process:
            print(f"🔌 Using adapter: {adapter_type}")

        self.ts_project = nn.Linear(ts_config.d_model, ts_config.it_d_model)
        self.query_project = nn.Linear(ts_config.llm_d_model, ts_config.it_d_model)
        self.fusion_project = nn.Linear(ts_config.it_d_model, ts_config.llm_d_model)

        self._freeze_layers()

    def _llm_load_kwargs(self):
        """Build keyword arguments for ``AutoModelForCausalLM.from_pretrained`` from the config.

        Raises:
            ValueError: If ``config.llm_torch_dtype`` is set to an unsupported name.
        """
        load_kwargs = {}
        attn_impl = getattr(self.config, 'llm_attn_implementation', None)
        dtype_name = getattr(self.config, 'llm_torch_dtype', None)

        if dtype_name:
            normalized_dtype = str(dtype_name).lower()
            if normalized_dtype not in self._LLM_DTYPES:
                raise ValueError(f"Unsupported llm_torch_dtype: {dtype_name}")
            load_kwargs['torch_dtype'] = self._LLM_DTYPES[normalized_dtype]

        if attn_impl:
            load_kwargs['attn_implementation'] = attn_impl
            # When one of these attention kernels is requested without an explicit
            # dtype, load the LLM in bfloat16.
            if attn_impl in ('flash_attention_2', 'sdpa') and 'torch_dtype' not in load_kwargs:
                load_kwargs['torch_dtype'] = torch.bfloat16
            if accelerator.is_main_process:
                print(f"⚡ LLM attention implementation: {attn_impl}")

        load_kwargs['low_cpu_mem_usage'] = True
        return load_kwargs

    def _load_ts_encoder_weights(self, load_path):
        """Best-effort load of pre-trained weights into ``self.ts_encoder``.

        A leading ``model.`` prefix is stripped from every key, and the weights are
        loaded with ``strict=False``. A missing path or a loading error is reported on
        the main process, and the encoder keeps its random initialisation.

        Args:
            load_path: A ``.safetensors`` file or a ``torch.load``-able file. A falsy
                value skips loading.
        """
        if not load_path:
            return
        is_main = accelerator.is_main_process
        if is_main:
            from utils.log_util import adaptive_print

        if not os.path.exists(load_path):
            if is_main:
                adaptive_print(f"⚠️ Warning: TS Encoder load path '{load_path}' does not exist. Using random initialization.")
            return

        if is_main:
            adaptive_print(f"📥 Loading pre-trained TimeSeries Encoder from: {load_path}")

        prefix = 'model.'
        try:
            if load_path.endswith('.safetensors'):
                raw_state_dict = load_file(load_path)
            else:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                raw_state_dict = torch.load(load_path, map_location='cpu')
            encoder_state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in raw_state_dict.items()
            }
            msg = self.ts_encoder.load_state_dict(encoder_state_dict, strict=False)
            if is_main:
                adaptive_print(f"✅ TS Encoder weights loaded. Missing: {len(msg.missing_keys)}, Unexpected: {len(msg.unexpected_keys)}")
        except Exception as e:
            # Deliberately best-effort: report the failure and keep random initialisation.
            if is_main:
                adaptive_print(f"❌ Failed to load TS Encoder weights: {e}")

    def _configure_lora(self):
        """Wrap ``self.llm_model`` with PEFT LoRA adapters when ``config.use_lora`` is set.

        Also disables the LLM's KV cache in its config. When
        ``config.gradient_checkpointing`` is set, it enables non-reentrant gradient
        checkpointing and input gradients.

        Raises:
            RuntimeError: If ``peft`` is not installed.
            ValueError: If ``config.lora_target_modules`` is empty.
        """
        if not getattr(self.config, "use_lora", False):
            return

        try:
            from peft import LoraConfig, TaskType, get_peft_model
        except ImportError as exc:
            raise RuntimeError("PEFT is required when use_lora=True.") from exc

        target_modules = getattr(self.config, "lora_target_modules", None)
        # Accept a comma-separated string, e.g. from a command-line flag.
        if isinstance(target_modules, str):
            target_modules = [
                item.strip() for item in target_modules.split(",") if item.strip()
            ]
        if not target_modules:
            raise ValueError("lora_target_modules must not be empty.")

        lora_config = LoraConfig(
            r=int(self.config.lora_r),
            lora_alpha=int(self.config.lora_alpha),
            lora_dropout=float(self.config.lora_dropout),
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            target_modules=list(target_modules),
        )
        self.llm_model = get_peft_model(self.llm_model, lora_config)
        self.llm_model.config.use_cache = False

        if getattr(self.config, "gradient_checkpointing", False):
            self.llm_model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
            # Needed so gradients reach the LoRA weights through checkpointed layers
            # while the embedding layer is frozen.
            self.llm_model.enable_input_require_grads()

        if accelerator.is_main_process:
            self.llm_model.print_trainable_parameters()

    def _freeze_layers(self):
        """Set ``requires_grad`` according to the config, keeping the adapter trainable.

        - LLM: every parameter is frozen, except parameters whose name contains
          ``"lora_"`` when ``config.use_lora`` is set.
        - TS encoder: frozen when ``config.freeze_ts_model`` is true, otherwise left as is.
        - Adapter and projection layers are not modified here.
        """
        if self.llm_model is not None:
            use_lora = bool(getattr(self.config, "use_lora", False))
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = use_lora and "lora_" in name

        if self.config.freeze_ts_model:
            for param in self.ts_encoder.parameters():
                param.requires_grad = False

    def _setup_inference_mode(self):
        """Set inference mode, freeze all parameters."""
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        if accelerator.is_main_process:
            print('🧊 Model set to inference mode - all parameters frozen')

    def eval(self):
        """Set the model and its submodules to evaluation mode.

        Returns:
            TLM: ``self``, matching ``torch.nn.Module.eval`` so ``model = model.eval()`` works.
        """
        super().eval()
        # Also call ``eval()`` on each component explicitly, in case a component overrides it.
        for attr in ('llm_model', 'ts_encoder', 'itformer', 'ts_project', 'query_project', 'fusion_project'):
            submodule = getattr(self, attr, None)
            if submodule is not None:
                submodule.eval()
        return self

    def _build_inputs_embeds(self, input_ids, query_ids, ts_values, stage=None):
        """Embed ``input_ids`` and splice in time-series features fused with the query.

        Args:
            input_ids: Prompt token ids containing ``TS_PAD_TOKEN`` placeholders.
            query_ids: Query token ids fed to the adapter.
            ts_values: Raw time-series input for ``self.ts_encoder``.
            stage: Passed through to the adapter unchanged.

        Returns:
            Token embeddings of ``input_ids`` with the placeholder slots replaced.
        """
        embed_tokens = self.llm_model.get_input_embeddings()
        query_embeds = embed_tokens(query_ids)
        ts_embeds = self.ts_project(self.ts_encoder(ts_values).logits)
        query_embeds = self.query_project(query_embeds)
        it_embeds = self.fusion_project(self.itformer(query_embeds, ts_embeds, stage))
        inputs_embeds = embed_tokens(input_ids)
        return self.merge_input_ids_with_ts_features(it_embeds, inputs_embeds, input_ids)

    def prepare_inputs_for_generation(self, input_ids, query_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Prepare inputs for text generation.

        The full sequence is re-embedded on every step. ``past_key_values`` is accepted
        for ``GenerationMixin`` compatibility but is not used.

        Args:
            input_ids: Input token IDs
            query_ids: Query token IDs
            past_key_values: Ignored (see above)
            attention_mask: Attention mask (may be None)
            **kwargs: May contain ``ts_values`` and ``stage``

        Returns:
            dict: ``inputs_embeds`` and ``attention_mask``. When ``input_ids`` or
            ``ts_values`` is missing or empty, ``inputs_embeds`` is an empty tensor.
        """
        ts_values = kwargs.get("ts_values", None)
        stage = kwargs.get("stage", None)
        device = next(self.llm_model.parameters()).device

        if input_ids is None or input_ids.numel() == 0 or ts_values is None or ts_values.numel() == 0:
            empty_device = input_ids.device if input_ids is not None else device
            return {
                "inputs_embeds": torch.empty(0, self.llm_model.config.hidden_size, device=empty_device),
                "attention_mask": attention_mask,
            }

        input_ids = input_ids.to(device)
        ts_values = ts_values.to(device)
        if query_ids is not None:
            query_ids = query_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        inputs_embeds = self._build_inputs_embeds(input_ids, query_ids, ts_values, stage)
        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        }

    def forward(self, input_ids=None, query_ids=None,
                ts_values=None, inputs_embeds=None, stage=None, index=None,
                attention_mask=None, past_key_values=None, labels=None, **kwargs):
        """Forward pass of the model.

        Args:
            input_ids: Input token IDs (used only when ``inputs_embeds`` is None)
            query_ids: Query token IDs
            ts_values: Time series values
            inputs_embeds: Pre-computed input embeddings; skips the time-series fusion
            stage: Processing stage, passed to the adapter
            index: Sample index (unused)
            attention_mask: Attention mask
            past_key_values: Not forwarded to the LLM (the full sequence is always embedded)
            labels: Optional targets passed to the underlying LLM, which then returns a
                loss. They must be aligned with ``input_ids`` (presumably using -100 for
                ignored positions, per Hugging Face convention — confirm with the trainer).
            **kwargs: Additional arguments (unused)

        Returns:
            CausalLMOutputWithPast with:

            - ``logits``;
            - ``loss``, only when ``labels`` is given;
            - ``past_key_values``, only outside training mode.
        """
        if inputs_embeds is None:
            inputs_embeds = self._build_inputs_embeds(input_ids, query_ids, ts_values, stage)

        # The KV cache is only useful (and only requested) outside training.
        use_cache = not self.training
        llm_kwargs = {"labels": labels} if labels is not None else {}
        outputs = self.llm_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=use_cache,
            **llm_kwargs,
        )

        return CausalLMOutputWithPast(
            loss=getattr(outputs, "loss", None) if labels is not None else None,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values if use_cache else None,
        )

    def merge_input_ids_with_ts_features(self, ts_features, inputs_embeds, input_ids):
        """Write time-series features into the embedding slots of ``TS_PAD_TOKEN``.

        Args:
            ts_features: Tensor of shape (num_ts, num_patches, embed_dim).
            inputs_embeds: Tensor of shape (batch, seq_len, embed_dim).
            input_ids: Token ids indexing the first two dims of ``inputs_embeds``.

        Returns:
            A copy of ``inputs_embeds`` in which the placeholder positions are filled.
            The fill uses the flattened ``ts_features`` in the order returned by
            ``torch.where`` (batch index first, then sequence position).

        Raises:
            ValueError: If the embedding dimensions differ, or if the number of
                placeholder positions is not ``num_ts * num_patches``.
        """
        _, _, embed_dim = inputs_embeds.shape
        num_tss, num_ts_patches, ts_embed_dim = ts_features.shape
        if embed_dim != ts_embed_dim:
            raise ValueError(
                f"Embedding dimensions must match: inputs_embeds has {embed_dim}, "
                f"ts_features has {ts_embed_dim}."
            )

        # add_special_tokens=False so a BOS-adding tokenizer cannot return the BOS id here.
        pad_token_id = self.tokenizer(self.TS_PAD_TOKEN, add_special_tokens=False)['input_ids'][0]
        batch_indices, seq_indices = torch.where(input_ids == pad_token_id)

        expected = num_tss * num_ts_patches
        if batch_indices.numel() != expected:
            raise ValueError(f"Mismatch: found {batch_indices.numel()} pad positions but got {expected} ts_features.")

        # reshape (not view) so non-contiguous adapter outputs are accepted.
        ts_features_flat = ts_features.reshape(-1, embed_dim).to(
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[batch_indices, seq_indices] = ts_features_flat
        return inputs_embeds
|
|