# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections import Counter, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Optional

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
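    """Dataset processor for supervised fine-tuning (instrumented with file-based debug logging)."""
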
    def _encode_data_example(
        self,
        prompt: list[dict[str, str]],
        response: list[dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
    ) -> tuple[list[int], list[int]]:
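        """Encode one prompt/response example into ``(input_ids, labels)``.

        Prompt tokens are labeled with IGNORE_INDEX (unless ``train_on_prompt``),
        each turn is truncated via ``infer_seqlen`` so the total length stays within
        ``cutoff_len``, and with ``mask_history`` only the last turn keeps trainable labels.
        """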
        messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
        input_ids, labels = self.template.mm_plugin.process_token_ids(
            [], [], images, videos, audios, self.tokenizer, self.processor
        )
        encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
        total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
        
        # detailed debug logging
        def log_debug(msg):
            """Append a timestamped message to the debug log file (file only, no console output)."""
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            log_entry = f"{timestamp} | INFO | {msg}\n"

            log_file = "/home/ziqiang/LLaMA-Factory/sharegpt_pair_debug.log"
            try:
                with open(log_file, "a", encoding="utf-8") as f:
                    f.write(log_entry)
                    f.flush()  # flush to disk immediately
            except OSError:
                pass  # ignore file write errors
        log_debug("\n" + "🔧 " + "=" * 78)
        log_debug("🔧 ShareGPT data processor - _encode_data_example started")
        log_debug("🔧 " + "=" * 78)

        log_debug("📊 Start processing a data sample")
        log_debug(f"📊 Raw conversation length: {len(prompt + response)} messages")
        log_debug(f"📊 Number of encoded pairs: {len(encoded_pairs)}")
        log_debug(f"📊 Initial total_length: {total_length}")
        log_debug(f"📊 cutoff_len: {self.data_args.cutoff_len}")
        log_debug(f"📊 mask_history: {self.data_args.mask_history}")
        log_debug(f"📊 train_on_prompt: {self.data_args.train_on_prompt}")
        
        if self.data_args.mask_history:
            encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
            log_debug("🔄 mask_history enabled; pair order reversed")

        log_debug("\n" + "📋 " + "-" * 76)
        log_debug("📋 Processing each pair")
        log_debug("📋 " + "-" * 76)

        for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
            original_source_len = len(source_ids)
            original_target_len = len(target_ids)
            remaining_budget = self.data_args.cutoff_len - total_length
            
            log_debug(f"\n🔄 === Processing pair {turn_idx + 1} ===")
            log_debug(f"📏 Original lengths: source={original_source_len}, target={original_target_len}")
            log_debug(f"💰 Remaining budget: {remaining_budget}")

            if total_length >= self.data_args.cutoff_len:
                log_debug("❌ Budget exhausted; dropping the remaining pairs")
                break

            # infer_seqlen splits the remaining budget between source and target,
            # keeping the split roughly proportional to their original lengths
            source_len, target_len = infer_seqlen(
                original_source_len, original_target_len, remaining_budget
            )
            
            log_debug(f"✂️ Lengths after truncation: source={original_source_len}->{source_len}, target={original_target_len}->{target_len}")

            if source_len < original_source_len:
                log_debug(f"⚠️ source truncated by {original_source_len - source_len} tokens")
            if target_len < original_target_len:
                log_debug(f"⚠️ target truncated by {original_target_len - target_len} tokens")
            
            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            total_length += source_len + target_len
            
            log_debug(f"📈 Cumulative length: {total_length}/{self.data_args.cutoff_len} ({total_length / self.data_args.cutoff_len * 100:.1f}%)")

            # build the labels
            if self.data_args.train_on_prompt:
                source_label = source_ids
                log_debug("🏷️ train_on_prompt=True; source_label uses the original tokens")
                log_debug(f"   📊 source_label length: {len(source_label)} tokens")
                if len(source_label) > 0:
                    source_label_preview = self.tokenizer.decode(source_label[:20], skip_special_tokens=False)
                    source_label_clean = source_label_preview.replace("\n", "\\n")
                    log_debug(f"   📄 source_label preview: {source_label_clean[:100]}...")
            elif self.template.efficient_eos:
                source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
                log_debug(f"🏷️ efficient_eos=True; source_label=[eos_token, {source_len - 1}*IGNORE_INDEX]")
                log_debug(f"   📊 source_label length: {len(source_label)} tokens")
                log_debug(f"   🔍 eos_token_id: {self.tokenizer.eos_token_id}")
                log_debug(f"   🔍 IGNORE_INDEX: {IGNORE_INDEX}")
            else:
                source_label = [IGNORE_INDEX] * source_len
                log_debug(f"🏷️ source_label={source_len}*IGNORE_INDEX")
                log_debug(f"   📊 source_label length: {len(source_label)} tokens")
                log_debug(f"   🔍 IGNORE_INDEX: {IGNORE_INDEX}")

            if self.data_args.mask_history and turn_idx != 0:  # train on the last turn only
                target_label = [IGNORE_INDEX] * target_len
                log_debug(f"🏷️ mask_history=True and turn_idx!=0; target_label={target_len}*IGNORE_INDEX")
                log_debug(f"   📊 target_label length: {len(target_label)} tokens")
                log_debug(f"   🔍 IGNORE_INDEX: {IGNORE_INDEX}")
            else:
                target_label = target_ids
                log_debug("🏷️ target_label uses the original tokens")
                log_debug(f"   📊 target_label length: {len(target_label)} tokens")
                if len(target_label) > 0:
                    target_label_preview = self.tokenizer.decode(target_label[:20], skip_special_tokens=False)
                    target_label_clean = target_label_preview.replace("\n", "\\n")
                    log_debug(f"   📄 target_label preview: {target_label_clean[:100]}...")

            if self.data_args.mask_history:  # reversed sequences
                input_ids = source_ids + target_ids + input_ids
                labels = source_label + target_label + labels
                log_debug("🔄 mask_history=True; pair prepended (sequences are in reversed order)")
                log_debug(f"   📊 input_ids length after concat: {len(input_ids)}")
                log_debug(f"   📊 labels length after concat: {len(labels)}")
            else:
                input_ids += source_ids + target_ids
                labels += source_label + target_label
                log_debug("➡️ Pair appended in normal order")
                log_debug(f"   📊 input_ids length after concat: {len(input_ids)}")
                log_debug(f"   📊 labels length after concat: {len(labels)}")

            # report the trainable (non-ignored) label count so far
            valid_labels_count = sum(1 for label in labels if label != IGNORE_INDEX)
            total_labels_count = len(labels)
            valid_percentage = (valid_labels_count / total_labels_count * 100) if total_labels_count > 0 else 0
            log_debug(f"   📊 Trainable labels so far: {valid_labels_count}/{total_labels_count} ({valid_percentage:.1f}%)")
            
            # breakdown of label values (only the first 5 entries are shown)
            if len(labels) > 0:
                label_stats = {}
                for label, count in Counter(labels).items():  # single pass instead of per-value .count()
                    if label == IGNORE_INDEX:
                        label_stats[f"IGNORE_INDEX({label})"] = count
                    elif label == self.tokenizer.eos_token_id:
                        label_stats[f"EOS_TOKEN({label})"] = count
                    else:
                        label_stats[f"TOKEN_{label}"] = count

                log_debug(f"   📊 Label composition: {dict(list(label_stats.items())[:5])}")

        if self.template.efficient_eos:
            input_ids += [self.tokenizer.eos_token_id]
            labels += [self.tokenizer.eos_token_id]
            total_length += 1
            log_debug(f"🔚 Appended eos_token; total_length={total_length}")

        log_debug("\n" + "🎯 " + "=" * 76)
        log_debug("🎯 Final statistics")
        log_debug("🎯 " + "=" * 76)
        log_debug(f"📊 Final input_ids length: {len(input_ids)}")
        log_debug(f"📊 Final labels length: {len(labels)}")
        log_debug(f"📊 Final total_length: {total_length}")
        log_debug(f"📊 Budget usage: {total_length}/{self.data_args.cutoff_len} ({total_length / self.data_args.cutoff_len * 100:.1f}%)")

        # count the trainable labels (guard against empty labels to avoid division by zero)
        valid_labels = [label for label in labels if label != IGNORE_INDEX]
        valid_ratio = len(valid_labels) / len(labels) * 100 if labels else 0.0
        log_debug(f"📊 Trainable label count: {len(valid_labels)}/{len(labels)} ({valid_ratio:.1f}%)")

        log_debug("🔧 " + "=" * 78)
        log_debug("🔧 _encode_data_example finished")
        log_debug("🔧 " + "=" * 78)

        return input_ids, labels

    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
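        # schematically, a single-turn example with prompt tokens x1 x2 and response y1 y2:
        #   input_ids = [bos, x1, x2, y1, y2, eos]
        #   labels    = [IGNORE, IGNORE, IGNORE, y1, y2, eos]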
        
        # debug logging (same file-based logger as in _encode_data_example)
        def log_debug(msg):
            """Append a timestamped message to the debug log file (file only, no console output)."""
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            log_entry = f"{timestamp} | INFO | {msg}\n"

            log_file = "/home/ziqiang/LLaMA-Factory/sharegpt_pair_debug.log"
            try:
                with open(log_file, "a", encoding="utf-8") as f:
                    f.write(log_entry)
                    f.flush()  # flush to disk immediately
            except OSError:
                pass  # ignore file write errors

        log_debug("\n" + "🚀 " + "=" * 78)
        log_debug("🚀 SupervisedDatasetProcessor.preprocess_dataset started")
        log_debug("🚀 " + "=" * 78)
        log_debug(f"📊 Number of samples to process: {len(examples['_prompt'])}")
        
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            log_debug(f"\n🔄 Processing sample {i + 1}/{len(examples['_prompt'])}")

            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                log_debug(f"❌ Sample {i + 1} has an invalid format; skipping")
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            log_debug(f"✅ Sample {i + 1} has a valid format; encoding")
            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            log_debug(f"✅ Sample {i + 1} encoded; input_ids length: {len(input_ids)}, labels length: {len(labels)}")

            # apply the user_id mask (disabled; user_id is now provided via the system prompt)
            # masked_labels = self._mask_user_id_tokens(input_ids, labels)

            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)  # use the original labels; user_id is no longer masked
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])

        log_debug("\n" + "🎯 " + "=" * 76)
        log_debug("🎯 preprocess_dataset finished")
        log_debug("🎯 " + "=" * 76)
        log_debug(f"📊 Final number of processed samples: {len(model_inputs['input_ids'])}")
        log_debug("🚀 " + "=" * 78)

        return model_inputs

    def print_data_example(self, example: dict[str, list[int]]) -> None:
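        """Print one encoded example: token IDs, decoded inputs, and the decoded trainable labels."""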
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print(f"input_ids:\n{example['input_ids']}")
        print(f"inputs:\n{self.tokenizer.decode(example['input_ids'], skip_special_tokens=False)}")
        print(f"label_ids:\n{example['labels']}")
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")

    def _mask_user_id_tokens(self, input_ids: list[int], labels: list[int]) -> list[int]:
        """Mask the token positions corresponding to user_id in the labels.

        Args:
            input_ids: the input token ID list.
            labels: the label list.

        Returns:
            list[int]: the masked labels.
        """
        masked_labels = labels.copy()

        # decode input_ids back to text
        text = self.tokenizer.decode(input_ids, skip_special_tokens=False)

        # patterns that match a user_id field
        user_id_patterns = [
            r'"user_id"\s*:\s*\d+',  # "user_id": 136451106
            r'"user_id"\s*:\s*"(\d+)"',  # "user_id": "136451106"
        ]

        # locate the user_id token positions
        user_id_positions = []
        for pattern in user_id_patterns:
            matches = list(re.finditer(pattern, text))
            for match in matches:
                start_char, end_char = match.span()

                # use a more precise method to find the token positions
                try:
                    # the user_id substring
                    user_id_text = text[start_char:end_char]

                    # search input_ids for the token sequence of this substring;
                    # try an exact match first
                    user_id_tokens = self.tokenizer.encode(user_id_text, add_special_tokens=False)

                    for i in range(len(input_ids) - len(user_id_tokens) + 1):
                        if input_ids[i : i + len(user_id_tokens)] == user_id_tokens:
                            user_id_positions.extend(range(i, i + len(user_id_tokens)))
                            print(f"🔒 Found user_id token positions: {i}-{i + len(user_id_tokens) - 1}")
                            print(f"   user_id text: {user_id_text}")
                            print(f"   user_id tokens: {user_id_tokens}")
                            break

                    # if the exact match fails, fall back to a looser match
                    if not user_id_positions:
                        # extract the numeric part
                        numbers = re.findall(r"\d+", user_id_text)
                        for num in numbers:
                            num_tokens = self.tokenizer.encode(num, add_special_tokens=False)
                            for i in range(len(input_ids) - len(num_tokens) + 1):
                                if input_ids[i : i + len(num_tokens)] == num_tokens:
                                    user_id_positions.extend(range(i, i + len(num_tokens)))
                                    print(f"🔒 Found numeric token positions: {i}-{i + len(num_tokens) - 1}")
                                    print(f"   number: {num}")
                                    print(f"   number tokens: {num_tokens}")
                                    break
                            if user_id_positions:
                                break

                except Exception as e:
                    print(f"⚠️ user_id masking failed: {e}")
                    continue

        # set the labels at the user_id positions to IGNORE_INDEX
        for pos in user_id_positions:
            if 0 <= pos < len(masked_labels):
                masked_labels[pos] = IGNORE_INDEX

        # report masking statistics
        original_trainable = sum(1 for label in labels if label != IGNORE_INDEX)
        masked_trainable = sum(1 for label in masked_labels if label != IGNORE_INDEX)
        masked_count = original_trainable - masked_trainable

        if masked_count > 0:
            print(f"🔒 Masked {masked_count} user_id-related tokens")
            print(f"   trainable tokens before masking: {original_trainable}")
            print(f"   trainable tokens after masking: {masked_trainable}")
        else:
            print("⚠️ No user_id tokens found to mask")
            print(f"   text content: {text[:200]}...")

        return masked_labels


@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
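    """Supervised processor that packs multiple short examples into fixed-length sequences via greedy knapsack."""
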
    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        # TODO: use `position_ids` to achieve packing
        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
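        # schematically, two short examples packed into one fixed-length sequence:
        #   input_ids = [bos, x1, y1, eos, bos, x2, y2, eos, pad, ...]
        #   labels    = [IGNORE, IGNORE, y1, eos, IGNORE, IGNORE, y2, eos, IGNORE, ...]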
        valid_num = 0
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        length2indexes = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            length = len(input_ids)
            if length > self.data_args.cutoff_len:
                logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
            else:
                # apply the user_id mask (disabled; user_id is now provided via the system prompt)
                # masked_labels = self._mask_user_id_tokens(input_ids, labels)
                
                lengths.append(length)
                length2indexes[length].append(valid_num)
                batch_input_ids.append(input_ids)
                batch_labels.append(labels)  # use the original labels; user_id is no longer masked
                batch_images.append(examples["_images"][i] or [])
                batch_videos.append(examples["_videos"][i] or [])
                batch_audios.append(examples["_audios"][i] or [])
                valid_num += 1

        model_inputs = defaultdict(list)
        knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
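        # each knapsack is a list of example lengths summing to at most cutoff_len,
        # so the corresponding examples can be concatenated into one packed sequence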
        for knapsack in knapsacks:
            packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
            packed_images, packed_videos, packed_audios = [], [], []
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_position_ids += list(range(len(batch_input_ids[index])))  # NOTE: pad_to_multiple_of ignore this
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                packed_audios += batch_audios[index]
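                # neat_packing assigns each packed example a distinct attention-mask value
                # so attention cannot cross example boundaries; otherwise an all-ones mask is used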
                if self.data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])

            if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
                packed_position_ids += [0] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if self.data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn

            if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                raise ValueError("The length of packed example should be identical to the cutoff length.")

            model_inputs["input_ids"].append(packed_input_ids)
            model_inputs["attention_mask"].append(packed_attention_masks)
            model_inputs["position_ids"].append(packed_position_ids)
            model_inputs["labels"].append(packed_labels)
            model_inputs["images"].append(packed_images or None)
            model_inputs["videos"].append(packed_videos or None)
            model_inputs["audios"].append(packed_audios or None)

        return model_inputs