# nextShakespeare/llm/data/collate.py
# (uploaded by LiManshu via upload-large-folder tool, commit bf6be45)
"""
数据整理函数:padding / batch
将多个样本组合成批次
处理不同长度的序列(padding)
转换为模型需要的张量格式
"""
# 2026-01-23
import torch
def _pad_to(ids, length, pad_token_id):
    """Right-pad a 1-D tensor to `length` with `pad_token_id` (no-op if already that long)."""
    pad_length = length - len(ids)
    if pad_length <= 0:
        return ids
    filler = torch.full((pad_length,), pad_token_id, dtype=ids.dtype)
    return torch.cat([ids, filler])


def collate_fn(batch, pad_token_id=0):
    """Collate (input_ids, target_ids) sample pairs into batched tensors.

    When every sequence (inputs and targets alike) has the same length,
    the samples are stacked directly with no padding. Otherwise each
    sequence is right-padded with ``pad_token_id`` up to the longest
    length seen across BOTH the inputs and the targets, so the two
    returned tensors always share the same width.

    Args:
        batch: list of (input_ids, target_ids) tuples; each element is a
            1-D tensor of shape (seq_len,).
        pad_token_id: token id used for right-padding (default: 0).

    Returns:
        Tuple (input_ids_batch, target_ids_batch), each a tensor of
        shape (batch_size, max_seq_len).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        # Fail early with a clear message instead of an opaque error
        # from max()/torch.stack on empty sequences.
        raise ValueError("batch must contain at least one sample")

    input_ids_list = [item[0] for item in batch]
    target_ids_list = [item[1] for item in batch]

    # One shared maximum over inputs AND targets keeps both outputs the
    # same width (the training loop expects aligned shapes).
    lengths = [len(ids) for ids in input_ids_list] + [len(ids) for ids in target_ids_list]
    max_seq_len = max(lengths)

    if all(length == max_seq_len for length in lengths):
        # Fast path: uniform lengths, stack directly without padding.
        input_ids_batch = torch.stack(input_ids_list, dim=0)    # (batch_size, seq_len)
        target_ids_batch = torch.stack(target_ids_list, dim=0)  # (batch_size, seq_len)
    else:
        # Ragged lengths: pad each sequence, then stack.
        input_ids_batch = torch.stack(
            [_pad_to(ids, max_seq_len, pad_token_id) for ids in input_ids_list], dim=0
        )  # (batch_size, max_seq_len)
        target_ids_batch = torch.stack(
            [_pad_to(ids, max_seq_len, pad_token_id) for ids in target_ids_list], dim=0
        )  # (batch_size, max_seq_len)

    return input_ids_batch, target_ids_batch
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("数据整理函数测试")
    print(banner)

    # Case 1: a batch where every sample has the same sequence length.
    num_samples = 4
    length = 10
    print("\n1. 创建模拟批次数据")
    print(f" 批次大小: {num_samples}")
    print(f" 序列长度: {length}")

    samples = []
    for idx in range(num_samples):
        src = torch.randint(0, 100, (length,))
        tgt = torch.randint(0, 100, (length,))
        samples.append((src, tgt))
        print(f" 样本 {idx}: input_ids 形状={src.shape}, target_ids 形状={tgt.shape}")

    # Case 2: collate the uniform batch — should stack without padding.
    print("\n2. 测试 collate_fn")
    inputs, targets = collate_fn(samples)
    print(f" 输入批次形状: {inputs.shape}")
    print(f" 目标批次形状: {targets.shape}")
    print(f" 期望形状: ({num_samples}, {length})")

    expected_shape = (num_samples, length)
    assert inputs.shape == expected_shape, \
        f"输入批次形状错误: {inputs.shape} != ({num_samples}, {length})"
    assert targets.shape == expected_shape, \
        f"目标批次形状错误: {targets.shape} != ({num_samples}, {length})"
    print(" 形状验证通过")

    # Case 3: each stacked row must match its originating sample exactly.
    print("\n3. 验证数据堆叠")
    for idx, (src, tgt) in enumerate(samples):
        src_ok = torch.equal(inputs[idx], src)
        tgt_ok = torch.equal(targets[idx], tgt)
        print(f" 样本 {idx}: input_ids 匹配={src_ok}, target_ids 匹配={tgt_ok}")
        assert src_ok and tgt_ok, f"样本 {idx} 数据不匹配"
    print(" 数据验证通过")

    # Case 4: ragged lengths — collate_fn must pad everything to 10.
    print("\n4. 测试不同序列长度(需要 padding)")
    ragged = [
        (torch.randint(0, 100, (n,)), torch.randint(0, 100, (n,)))
        for n in (5, 8, 10)
    ]
    print(" 样本长度: [5, 8, 10]")
    padded_inputs, padded_targets = collate_fn(ragged, pad_token_id=0)
    print(f" 输入批次形状: {padded_inputs.shape}")
    print(f" 目标批次形状: {padded_targets.shape}")
    print(f" 期望形状: (3, 10)")
    assert padded_inputs.shape == (3, 10), \
        f"输入批次形状错误: {padded_inputs.shape} != (3, 10)"
    assert padded_targets.shape == (3, 10), \
        f"目标批次形状错误: {padded_targets.shape} != (3, 10)"

    # Case 5: the original prefix survives, and the tail is all pad ids.
    print("\n5. 验证 padding")
    for idx, (src, tgt) in enumerate(ragged):
        n = len(src)
        assert torch.equal(padded_inputs[idx, :n], src), \
            f"样本 {idx} 的 input_ids 数据不匹配"
        assert torch.equal(padded_targets[idx, :n], tgt), \
            f"样本 {idx} 的 target_ids 数据不匹配"
        if n < 10:
            assert torch.all(padded_inputs[idx, n:] == 0), \
                f"样本 {idx} 的 input_ids padding 不正确"
            assert torch.all(padded_targets[idx, n:] == 0), \
                f"样本 {idx} 的 target_ids padding 不正确"
        print(f" 样本 {idx}: 长度={n}, padding 验证通过")
    print(" ✓ Padding 验证通过")

    print("\n" + banner)
    print("所有测试完成!")
    print(banner)