File size: 23,209 Bytes
c8aad8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Time Language Model (TLM) for inference.
A multimodal model that combines time series data with language model for time series question answering.
"""
import os
import json
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from safetensors.torch import load_file
from models.TimeSeriesEncoder import Model
from models.ITFormer import ITFormer
from models.QFormerAdapter import QFormerAdapter
from accelerate import Accelerator

# Module-level Accelerator instance. The code below consults
# `accelerator.is_main_process` to decide whether to emit log output.
accelerator = Accelerator()

# Substrings identifying LoRA adapter tensors in state-dict keys. Keys under
# "llm_model." that contain one of these markers are kept when saving and
# loading; all other "llm_model." keys are dropped on save and ignored on load.
LORA_STATE_MARKERS = (
    ".lora_A.",
    ".lora_B.",
    ".lora_embedding_A.",
    ".lora_embedding_B.",
)


class TLMConfig(PretrainedConfig):
    """Settings container for the Time Language Model.

    Records which backbone LLM to load and how to load it, whether the
    time-series encoder is frozen, and the optional LoRA /
    gradient-checkpointing setup.
    """

    model_type = "vlm_model"

    def __init__(self, llm_model_path='LLM/Qwen2.5-0.5B-Instruct',
                 freeze_ts_model=True,
                 ts_pad_num=25,
                 llm_attn_implementation=None,
                 llm_torch_dtype=None,
                 use_lora=False,
                 lora_r=16,
                 lora_alpha=32,
                 lora_dropout=0.05,
                 lora_target_modules=None,
                 gradient_checkpointing=False,
                 **kwargs):
        """Store the TLM hyper-parameters on the config instance.

        Args:
            llm_model_path: Path to the language model.
            freeze_ts_model: Whether to freeze time series model parameters.
            ts_pad_num: Number of time series padding tokens.
            llm_attn_implementation: Attention implementation name handed to
                the LLM loader; None leaves it unset.
            llm_torch_dtype: Name of the dtype used to load the LLM (for
                example "bf16" or "float32"); None leaves it unset.
            use_lora: Whether the LLM is wrapped with LoRA adapters.
            lora_r: LoRA rank.
            lora_alpha: LoRA alpha.
            lora_dropout: LoRA dropout probability.
            lora_target_modules: Module names that receive LoRA adapters. A
                missing or empty value selects the default projection list.
            gradient_checkpointing: Whether gradient checkpointing is
                requested for the LLM.
            **kwargs: Additional configuration parameters passed to the
                parent config class.
        """
        # A missing or empty selection falls back to the default projections.
        if not lora_target_modules:
            lora_target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]

        # Backbone LLM loading options.
        self.llm_model_path = llm_model_path
        self.freeze_ts_model = freeze_ts_model
        self.ts_pad_num = ts_pad_num
        self.llm_attn_implementation = llm_attn_implementation
        self.llm_torch_dtype = llm_torch_dtype

        # LoRA options.
        self.use_lora = use_lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules

        self.gradient_checkpointing = gradient_checkpointing
        super().__init__(**kwargs)


class TLM(PreTrainedModel, GenerationMixin):
    """Time Language Model for inference."""
    config_class = TLMConfig

    def state_dict(self, *args, **kwargs):
        """Return the checkpoint weights, leaving out the frozen base LLM.

        Every entry outside ``llm_model.`` is kept. Inside ``llm_model.`` only
        keys containing one of ``LORA_STATE_MARKERS`` survive; the base LLM
        weights are expected to be reloaded from ``config.llm_model_path``.

        Returns:
            dict: Filtered mapping from parameter name to tensor.
        """
        full_state = super().state_dict(*args, **kwargs)
        kept = {}
        for name, tensor in full_state.items():
            belongs_to_llm = name.startswith("llm_model.")
            # Drop base-LLM tensors; LoRA tensors under llm_model.* are kept.
            if belongs_to_llm and not any(tag in name for tag in LORA_STATE_MARKERS):
                continue
            kept[name] = tensor
        return kept
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, **kwargs):
        """Build a TLM and load its non-LLM and LoRA weights from a checkpoint.

        The model is first constructed from ``config``; checkpoint tensors are
        then applied on top with ``strict=False``. Checkpoint keys under
        ``llm_model.`` that are not LoRA tensors are discarded rather than
        loaded.

        Args:
            pretrained_model_name_or_path: Local checkpoint path. It is
                searched for ``config.json`` and for weights in this order:
                ``pytorch_model.bin``, ``model.safetensors``, then sharded
                ``model-*.safetensors`` files.
            config: Optional model configuration. When None it is built from
                ``config.json`` if that file exists, otherwise from the
                TLMConfig defaults.
            **kwargs: Forwarded unchanged to the class constructor, including
                ts_config.

        Returns:
            TLM: Loaded model instance. If no weight file is found the model
            is returned without checkpoint weights and the main process prints
            a warning.

        Raises:
            ValueError: If the path does not exist, if the checkpoint and the
                model disagree about LoRA usage, or if LoRA keys expected by
                the model are missing from the checkpoint.
        """
        if not os.path.exists(pretrained_model_name_or_path):
            raise ValueError(f"Checkpoint path does not exist: {pretrained_model_name_or_path}")

        # Load config.json; an explicitly passed config takes precedence over it.
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            if config is None:
                config = TLMConfig(**config_dict)
        else:
            if config is None:
                config = TLMConfig()

        # Create model instance with potential ts_config from kwargs
        model = cls(config, **kwargs)

        # Load model weights: prefer pytorch_model.bin, then model.safetensors.
        model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        if not os.path.exists(model_path):
            model_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")

        state_dict = None
        # 1. Try normal files
        if os.path.exists(model_path):
            if accelerator.is_main_process:
                print(f"Loading model weights from: {model_path}")
            if model_path.endswith('.safetensors'):
                state_dict = load_file(model_path)
            else:
                # NOTE(review): torch.load unpickles the file; only use it on
                # trusted checkpoints (consider weights_only=True).
                state_dict = torch.load(model_path, map_location='cpu')
        else:
            # 2. Try split safetensors in the same directory
            all_files = os.listdir(pretrained_model_name_or_path)
            safetensors_files = [f for f in all_files if f.startswith('model-') and f.endswith('.safetensors')]
            safetensors_files.sort()  # Ensure order
            if safetensors_files:
                if accelerator.is_main_process:
                    print(f"Loading split safetensors from: {pretrained_model_name_or_path}")
                # Merge all shards into one flat state dict.
                state_dict = {}
                for fname in safetensors_files:
                    fpath = os.path.join(pretrained_model_name_or_path, fname)
                    part = load_file(fpath)
                    state_dict.update(part)
                if accelerator.is_main_process:
                    print(f"Successfully loaded {len(safetensors_files)} split safetensors files.")
        if state_dict is not None:
            # Ignore frozen base-LLM weights but retain LoRA matrices.
            ignored_llm_weights = {}
            other_weights = {}
            for k, v in state_dict.items():
                is_lora_weight = any(marker in k for marker in LORA_STATE_MARKERS)
                if k.startswith('llm_model.') and not is_lora_weight:
                    ignored_llm_weights[k] = v
                else:
                    other_weights[k] = v
            if accelerator.is_main_process:
                lora_count = sum(
                    any(marker in key for marker in LORA_STATE_MARKERS)
                    for key in other_weights
                )
                print(f"Found {len(ignored_llm_weights)} frozen LLM weights (will be ignored)")
                print(f"Found {lora_count} LoRA tensors")
                print(f"Found {len(other_weights) - lora_count} non-LLM tensors")
            # Cross-check LoRA presence between checkpoint and constructed model
            # before loading, so a mismatch fails loudly.
            checkpoint_has_lora = any(
                any(marker in key for marker in LORA_STATE_MARKERS)
                for key in other_weights
            )
            model_has_lora = any(
                "lora_" in name for name, _ in model.llm_model.named_parameters()
            )
            if checkpoint_has_lora and not model_has_lora:
                raise ValueError(
                    "The checkpoint contains LoRA matrices, but the model was "
                    "constructed with use_lora=False."
                )
            if getattr(model.config, "use_lora", False) and not checkpoint_has_lora:
                raise ValueError(
                    "The model was constructed with use_lora=True, but the "
                    "checkpoint does not contain LoRA matrices."
                )
            # strict=False because the base-LLM keys were removed on purpose.
            missing_keys, unexpected_keys = model.load_state_dict(other_weights, strict=False)
            # Filter out LLM-related missing keys since we're not loading LLM weights
            llm_missing_keys = [
                k
                for k in missing_keys
                if k.startswith('llm_model.')
                and not any(marker in k for marker in LORA_STATE_MARKERS)
            ]
            non_llm_missing_keys = [k for k in missing_keys if not k.startswith('llm_model.')]
            missing_lora_keys = [
                k
                for k in missing_keys
                if any(marker in k for marker in LORA_STATE_MARKERS)
            ]
            if llm_missing_keys and accelerator.is_main_process:
                print(f"LLM missing keys (ignored): {len(llm_missing_keys)} keys")
            # Missing LoRA tensors are fatal; other missing keys are only reported.
            if missing_lora_keys:
                raise ValueError(f"Missing LoRA checkpoint keys: {missing_lora_keys}")
            if non_llm_missing_keys and accelerator.is_main_process:
                print(f"Non-LLM missing keys: {non_llm_missing_keys}")
            if unexpected_keys and accelerator.is_main_process:
                print(f"Unexpected keys: {unexpected_keys}")
        else:
            if accelerator.is_main_process:
                print(f"Warning: No model weights found at {model_path} or in split safetensors.")

        return model

    def __init__(self, config, ts_config=None):
        """Initialize TLM model.

        Builds, in order: the backbone LLM and tokenizer, the optional LoRA
        wrapper, the time-series encoder (optionally with pre-trained
        weights), the adapter, and the projection layers; then applies the
        configured freezing.

        Args:
            config: TLM configuration
            ts_config: Optional time series configuration (args). When None, a
                built-in default configuration is used. The object is mutated:
                ``llm_d_model`` is set on it, and ``prefix_num`` /
                ``ts_pad_num`` may be mirrored onto each other.

        Raises:
            ValueError: If ``config.llm_torch_dtype`` or
                ``ts_config.adapter_type`` is not a supported value.
            Exception: Any error raised while loading the LLM or tokenizer is
                reported on the main process and re-raised.
        """
        super().__init__(config)
        self.config = config
        
        if ts_config is None:
            # Create default ts_config if not provided
            class DefaultTSConfig:
                def __init__(self):
                    self.model = 'TimeSeriesEncoder'
                    self.d_model = 512
                    self.n_heads = 8
                    self.e_layers = 4
                    self.patch_len = 60
                    self.stride = 60
                    self.input_len = 600
                    self.dropout = 0.1
                    self.it_d_model = 896
                    self.it_n_heads = 16
                    self.it_layers = 2
                    self.it_dropout = 0.1
                    self.prefix_num = 25
                    self.adapter_type = 'itformer'
            ts_config = DefaultTSConfig()
        
        self.ts_config = ts_config
        
        # Keep the two attribute names aligned: when only one of ts_pad_num /
        # prefix_num is present, copy its value onto the missing one.
        if hasattr(self.ts_config, 'ts_pad_num') and not hasattr(self.ts_config, 'prefix_num'):
            setattr(self.ts_config, 'prefix_num', self.ts_config.ts_pad_num)
        elif hasattr(self.ts_config, 'prefix_num') and not hasattr(self.ts_config, 'ts_pad_num'):
            setattr(self.ts_config, 'ts_pad_num', self.ts_config.prefix_num)
        
        # Initialize LLM model from external path
        try:
            llm_load_kwargs = {}
            attn_impl = getattr(self.config, 'llm_attn_implementation', None)
            dtype_name = getattr(self.config, 'llm_torch_dtype', None)
            # Accepted spellings for llm_torch_dtype (matched case-insensitively).
            dtype_map = {
                "float16": torch.float16,
                "fp16": torch.float16,
                "bfloat16": torch.bfloat16,
                "bf16": torch.bfloat16,
                "float32": torch.float32,
                "fp32": torch.float32,
            }
            if dtype_name:
                normalized_dtype = str(dtype_name).lower()
                if normalized_dtype not in dtype_map:
                    raise ValueError(f"Unsupported llm_torch_dtype: {dtype_name}")
                llm_load_kwargs['torch_dtype'] = dtype_map[normalized_dtype]
            if attn_impl:
                llm_load_kwargs['attn_implementation'] = attn_impl
                # flash_attention_2 / sdpa require fp16/bf16 weights; fp32 errors out.
                # An explicitly configured dtype is never overridden here.
                if attn_impl in ('flash_attention_2', 'sdpa') and 'torch_dtype' not in llm_load_kwargs:
                    llm_load_kwargs['torch_dtype'] = torch.bfloat16
                if accelerator.is_main_process:
                    print(f"⚡ LLM attention implementation: {attn_impl}")
            llm_load_kwargs['low_cpu_mem_usage'] = True
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_path,
                **llm_load_kwargs,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)
            if accelerator.is_main_process:
                print(f"✅ Loaded LLM model from: {self.config.llm_model_path}")
        except Exception as e:
            # Report on the main process, then propagate the original error.
            if accelerator.is_main_process:
                print(f"❌ Failed to load LLM model from {self.config.llm_model_path}: {e}")
            raise e
        
        # Make the LLM config use the tokenizer's pad token id.
        if self.llm_model is not None:
            self.llm_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Wrap the LLM with LoRA adapters when config.use_lora is set.
        self._configure_lora()
        
        # Set LLM hidden layer dimension
        ts_config.llm_d_model = self.llm_model.config.hidden_size
        
        # Initialize components
        self.ts_encoder = Model(ts_config)
        
        # Load pre-trained TS Encoder weights when ts_config.load_ts_encoder
        # points at an existing file.
        load_path = getattr(ts_config, 'load_ts_encoder', None)
        if load_path and os.path.exists(load_path):
            if accelerator.is_main_process:
                from utils.log_util import adaptive_print
                adaptive_print(f"📥 Loading pre-trained TimeSeries Encoder from: {load_path}")
            
            try:
                if load_path.endswith('.safetensors'):
                    from safetensors.torch import load_file
                    ts_state_dict = load_file(load_path)
                else:
                    # NOTE(review): torch.load unpickles the file; only use it
                    # on trusted checkpoints.
                    ts_state_dict = torch.load(load_path, map_location='cpu')
                
                # Compatibility: strip a leading 'model.' prefix from checkpoint keys.
                new_state_dict = {}
                for k, v in ts_state_dict.items():
                    if k.startswith('model.'):
                        new_state_dict[k[6:]] = v
                    else:
                        new_state_dict[k] = v
                
                msg = self.ts_encoder.load_state_dict(new_state_dict, strict=False)
                if accelerator.is_main_process:
                    adaptive_print(f"✅ TS Encoder weights loaded. Missing: {len(msg.missing_keys)}, Unexpected: {len(msg.unexpected_keys)}")
            except Exception as e:
                # NOTE(review): the failure is reported on the main process only
                # and construction continues with whatever weights the encoder
                # currently holds.
                if accelerator.is_main_process:
                    adaptive_print(f"❌ Failed to load TS Encoder weights: {e}")
        elif load_path:
            if accelerator.is_main_process:
                from utils.log_util import adaptive_print
                adaptive_print(f"⚠️ Warning: TS Encoder load path '{load_path}' does not exist. Using random initialization.")

        # Select the adapter; both variants are stored under self.itformer.
        adapter_type = getattr(ts_config, 'adapter_type', 'itformer').lower()
        if adapter_type == 'itformer':
            self.itformer = ITFormer(ts_config)
        elif adapter_type == 'qformer':
            self.itformer = QFormerAdapter(ts_config)
        else:
            raise ValueError(f"Unsupported adapter_type: {adapter_type}")
        if accelerator.is_main_process:
            print(f"🔌 Using adapter: {adapter_type}")
        
        # Projection layers
        self.ts_project = nn.Linear(ts_config.d_model, ts_config.it_d_model)
        self.query_project = nn.Linear(ts_config.llm_d_model, ts_config.it_d_model)
        self.fusion_project = nn.Linear(ts_config.it_d_model, ts_config.llm_d_model)
        
        # Freeze parameters according to the configuration.
        self._freeze_layers()

    def _configure_lora(self):
        if not getattr(self.config, "use_lora", False):
            return

        try:
            from peft import LoraConfig, TaskType, get_peft_model
        except ImportError as exc:
            raise RuntimeError("PEFT is required when use_lora=True.") from exc

        target_modules = getattr(self.config, "lora_target_modules", None)
        if isinstance(target_modules, str):
            target_modules = [
                item.strip() for item in target_modules.split(",") if item.strip()
            ]
        if not target_modules:
            raise ValueError("lora_target_modules must not be empty.")

        lora_config = LoraConfig(
            r=int(self.config.lora_r),
            lora_alpha=int(self.config.lora_alpha),
            lora_dropout=float(self.config.lora_dropout),
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            target_modules=list(target_modules),
        )
        self.llm_model = get_peft_model(self.llm_model, lora_config)
        self.llm_model.config.use_cache = False

        if getattr(self.config, "gradient_checkpointing", False):
            self.llm_model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
            self.llm_model.enable_input_require_grads()

        if accelerator.is_main_process:
            self.llm_model.print_trainable_parameters()

    def _freeze_layers(self):
        """根据配置冻结特定层,保留中间件的可训练性。"""
        # Freeze the base LLM. PEFT has already marked only LoRA matrices as
        # trainable, so preserve those flags when LoRA is enabled.
        if self.llm_model is not None:
            use_lora = bool(getattr(self.config, "use_lora", False))
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = use_lora and "lora_" in name

        # 2. 根据配置冻结 TS Encoder
        if self.config.freeze_ts_model:
            for param in self.ts_encoder.parameters():
                param.requires_grad = False
        else:
            pass

        # 3. 确保中间件是可训练的 (ITFormer 和 Projections)
        # 这些层默认 requires_grad=True,所以不需要额外操作,
        # 除非之前调用了 _setup_inference_mode()

    def _setup_inference_mode(self):
        """Freeze every parameter and switch the model to evaluation mode."""
        for weight in self.parameters():
            weight.requires_grad_(False)
        self.eval()
        # Only the main process reports the mode switch.
        if not accelerator.is_main_process:
            return
        print('🧊 Model set to inference mode - all parameters frozen')

    def eval(self):
        """Set model to evaluation mode."""
        super().eval()
        if self.llm_model is not None:
            self.llm_model.eval()
        if self.ts_encoder is not None:
            self.ts_encoder.eval()
        if self.itformer is not None:
            self.itformer.eval()
        if self.ts_project is not None:
            self.ts_project.eval()
        if self.query_project is not None:
            self.query_project.eval()
        if self.fusion_project is not None:
            self.fusion_project.eval()

    def prepare_inputs_for_generation(self, input_ids, query_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Prepare inputs for text generation.

        Encodes the time series, fuses it with the query embeddings through
        the adapter, and writes the fused features into the token embeddings
        of ``input_ids``.

        Args:
            input_ids: Input token IDs
            query_ids: Query token IDs
            past_key_values: Past key values for caching. Accepted for
                interface compatibility; not used here.
            attention_mask: Attention mask, or None. Moved to the LLM device
                when given.
            **kwargs: Additional arguments. ``ts_values`` (time series tensor)
                and ``stage`` (passed to the adapter) are read from here.

        Returns:
            dict: ``inputs_embeds`` and ``attention_mask``. When ``input_ids``
            or ``ts_values`` is missing or empty, ``inputs_embeds`` is an
            empty ``(0, hidden_size)`` tensor and ``attention_mask`` is
            returned unchanged.

        Raises:
            ValueError: If ``query_ids`` is None while there is input to encode.
        """
        ts_values = kwargs.get("ts_values", None)
        stage = kwargs.get("stage", None)

        device = next(self.llm_model.parameters()).device

        if input_ids is None or input_ids.numel() == 0 or ts_values is None or ts_values.numel() == 0:
            # Nothing to encode. Use the input's device when there is one,
            # otherwise fall back to the LLM device (input_ids may be None).
            empty_device = input_ids.device if input_ids is not None else device
            return {
                "inputs_embeds": torch.empty(0, self.llm_model.config.hidden_size, device=empty_device),
                "attention_mask": attention_mask,
            }

        if query_ids is None:
            raise ValueError("`query_ids` must be provided for generation.")

        # Put every tensor on the LLM device; the mask is optional.
        input_ids = input_ids.to(device)
        query_ids = query_ids.to(device)
        ts_values = ts_values.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Process time series and query
        query_embeds = self.llm_model.get_input_embeddings()(query_ids)
        ts_embeds = self.ts_encoder(ts_values).logits
        ts_embeds = self.ts_project(ts_embeds)
        query_embeds_f = self.query_project(query_embeds)
        it_embeds = self.itformer(query_embeds_f, ts_embeds, stage)
        it_embeds = self.fusion_project(it_embeds)

        # Generate inputs_embeds
        inputs_embeds = self.llm_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.merge_input_ids_with_ts_features(it_embeds, inputs_embeds, input_ids)

        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        }

    def forward(self, input_ids=None, query_ids=None, 
                ts_values=None, inputs_embeds=None, stage=None, index=None,
                attention_mask=None, past_key_values=None, labels=None, **kwargs):
        """Run the fused time-series / text inputs through the LLM.

        When ``inputs_embeds`` is not supplied it is built here: the time
        series is encoded and projected, fused with the projected query
        embeddings by the adapter, projected back, and merged into the token
        embeddings of ``input_ids``.

        Args:
            input_ids: Input token IDs
            query_ids: Query token IDs
            ts_values: Time series values
            inputs_embeds: Pre-computed input embeddings; skips the fusion
                step when given.
            stage: Processing stage, forwarded to the adapter.
            index: Sample index (not used in this method).
            attention_mask: Attention mask
            past_key_values: Accepted but not forwarded to the LLM.
            labels: Accepted but not used; no loss is computed here.
            **kwargs: Additional arguments (ignored).

        Returns:
            CausalLMOutputWithPast: Logits, plus the LLM's past key values
            when the model is not in training mode.
        """
        if inputs_embeds is None:
            token_embedding = self.llm_model.get_input_embeddings()
            # Query embedding
            query_hidden = token_embedding(query_ids)
            # Time series encoding
            ts_hidden = self.ts_project(self.ts_encoder(ts_values).logits)
            # Fuse query and time series, then map back to the LLM width
            fused = self.itformer(self.query_project(query_hidden), ts_hidden, stage)
            fused = self.fusion_project(fused)
            inputs_embeds = self.merge_input_ids_with_ts_features(
                fused, token_embedding(input_ids), input_ids
            )

        # Caching is only requested outside of training.
        caching = not self.training
        llm_out = self.llm_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=caching,
        )

        logits = llm_out.logits
        cache = llm_out.past_key_values if caching else None
        return CausalLMOutputWithPast(logits=logits, past_key_values=cache)

    def merge_input_ids_with_ts_features(self, ts_features, inputs_embeds, input_ids):
        """Write time-series features into the placeholder token positions.

        Positions of ``input_ids`` equal to the first token id of
        ``'<|image_pad|>'`` are located in row-major order and overwritten,
        in that order, with the flattened rows of ``ts_features``.

        Args:
            ts_features: 3-D tensor ``(num_tss, num_ts_patches, embed_dim)``.
            inputs_embeds: 3-D tensor ``(batch_size, seq_len, embed_dim)``.
            input_ids: Token ids used to locate the placeholder positions.

        Returns:
            torch.Tensor: A copy of ``inputs_embeds`` with the placeholder
            positions replaced; the input tensor itself is not modified.

        Raises:
            ValueError: If the embedding widths differ, or if the number of
                placeholder positions does not equal
                ``num_tss * num_ts_patches``.
        """
        _, _, embed_dim = inputs_embeds.shape
        num_tss, num_ts_patches, ts_embed_dim = ts_features.shape
        # Raise instead of assert so the check survives `python -O`.
        if embed_dim != ts_embed_dim:
            raise ValueError(
                f"Embedding dimensions must match: inputs_embeds has {embed_dim}, "
                f"ts_features has {ts_embed_dim}."
            )

        pad_token_id = self.tokenizer('<|image_pad|>')['input_ids'][0]
        batch_indices, seq_indices = torch.where(input_ids == pad_token_id)

        if len(batch_indices) != num_tss * num_ts_patches:
            raise ValueError(f"Mismatch: found {len(batch_indices)} pad positions but got {num_tss * num_ts_patches} ts_features.")
        # reshape (not view) so non-contiguous feature tensors are accepted.
        ts_features_flat = ts_features.reshape(-1, embed_dim).to(
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device
        )
        # Clone so the caller's embedding tensor is left untouched.
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[batch_indices, seq_indices] = ts_features_flat

        return inputs_embeds