| """
|
| 数据整理函数:padding / batch
|
| 将多个样本组合成批次
|
| 处理不同长度的序列(padding)
|
| 转换为模型需要的张量格式
|
| """
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
|
|
|
|
def collate_fn(batch, pad_token_id=0):
    """Collate a list of (input_ids, target_ids) pairs into batched tensors.

    If every sequence in the batch (inputs and targets alike) already has the
    same length, the tensors are stacked directly.  Otherwise each sequence is
    right-padded with ``pad_token_id`` up to the longest length found in the
    batch, so inputs and targets stay aligned at one common ``max_seq_len``.

    Args:
        batch: list of ``(input_ids, target_ids)`` tuples, each element a
            1-D tensor of shape ``(seq_len,)``.
        pad_token_id: token id used to fill padded positions (default: 0).

    Returns:
        ``(input_ids_batch, target_ids_batch)`` — two tensors of shape
        ``(batch_size, max_seq_len)``.
    """
    input_ids_list = [item[0] for item in batch]
    target_ids_list = [item[1] for item in batch]

    input_lengths = [len(ids) for ids in input_ids_list]
    target_lengths = [len(ids) for ids in target_ids_list]

    # Fast path: a single shared length across both streams -> plain stack,
    # no padding work needed.
    all_same_length = (
        len(set(input_lengths)) == 1
        and len(set(target_lengths)) == 1
        and input_lengths[0] == target_lengths[0]
    )
    if all_same_length:
        return (
            torch.stack(input_ids_list, dim=0),
            torch.stack(target_ids_list, dim=0),
        )

    # Pad both streams to one common max so the two batches share a shape.
    max_seq_len = max(max(input_lengths), max(target_lengths))
    input_ids_batch = torch.stack(
        [_pad_to(ids, max_seq_len, pad_token_id) for ids in input_ids_list],
        dim=0,
    )
    target_ids_batch = torch.stack(
        [_pad_to(ids, max_seq_len, pad_token_id) for ids in target_ids_list],
        dim=0,
    )
    return input_ids_batch, target_ids_batch


def _pad_to(ids, length, pad_token_id):
    """Right-pad 1-D tensor ``ids`` with ``pad_token_id`` up to ``length``.

    Returns ``ids`` unchanged when it is already ``length`` long (or longer);
    the filler keeps the original dtype.
    """
    pad_length = length - len(ids)
    if pad_length <= 0:
        return ids
    filler = torch.full((pad_length,), pad_token_id, dtype=ids.dtype)
    return torch.cat([ids, filler])
|
|
|
|
|
| if __name__ == "__main__":
|
| print("=" * 60)
|
| print("数据整理函数测试")
|
| print("=" * 60)
|
|
|
|
|
| batch_size = 4
|
| seq_len = 10
|
|
|
| print("\n1. 创建模拟批次数据")
|
| print(f" 批次大小: {batch_size}")
|
| print(f" 序列长度: {seq_len}")
|
|
|
|
|
| batch = []
|
| for i in range(batch_size):
|
| input_ids = torch.randint(0, 100, (seq_len,))
|
| target_ids = torch.randint(0, 100, (seq_len,))
|
| batch.append((input_ids, target_ids))
|
| print(f" 样本 {i}: input_ids 形状={input_ids.shape}, target_ids 形状={target_ids.shape}")
|
|
|
|
|
| print("\n2. 测试 collate_fn")
|
| input_ids_batch, target_ids_batch = collate_fn(batch)
|
|
|
| print(f" 输入批次形状: {input_ids_batch.shape}")
|
| print(f" 目标批次形状: {target_ids_batch.shape}")
|
| print(f" 期望形状: ({batch_size}, {seq_len})")
|
|
|
|
|
| assert input_ids_batch.shape == (batch_size, seq_len), \
|
| f"输入批次形状错误: {input_ids_batch.shape} != ({batch_size}, {seq_len})"
|
| assert target_ids_batch.shape == (batch_size, seq_len), \
|
| f"目标批次形状错误: {target_ids_batch.shape} != ({batch_size}, {seq_len})"
|
|
|
| print(" 形状验证通过")
|
|
|
|
|
| print("\n3. 验证数据堆叠")
|
| for i in range(batch_size):
|
| input_match = torch.equal(input_ids_batch[i], batch[i][0])
|
| target_match = torch.equal(target_ids_batch[i], batch[i][1])
|
| print(f" 样本 {i}: input_ids 匹配={input_match}, target_ids 匹配={target_match}")
|
| assert input_match and target_match, f"样本 {i} 数据不匹配"
|
|
|
| print(" 数据验证通过")
|
|
|
|
|
| print("\n4. 测试不同序列长度(需要 padding)")
|
| batch_variable = [
|
| (torch.randint(0, 100, (5,)), torch.randint(0, 100, (5,))),
|
| (torch.randint(0, 100, (8,)), torch.randint(0, 100, (8,))),
|
| (torch.randint(0, 100, (10,)), torch.randint(0, 100, (10,))),
|
| ]
|
|
|
| print(" 样本长度: [5, 8, 10]")
|
| input_batch_var, target_batch_var = collate_fn(batch_variable, pad_token_id=0)
|
|
|
| print(f" 输入批次形状: {input_batch_var.shape}")
|
| print(f" 目标批次形状: {target_batch_var.shape}")
|
| print(f" 期望形状: (3, 10)")
|
|
|
| assert input_batch_var.shape == (3, 10), \
|
| f"输入批次形状错误: {input_batch_var.shape} != (3, 10)"
|
| assert target_batch_var.shape == (3, 10), \
|
| f"目标批次形状错误: {target_batch_var.shape} != (3, 10)"
|
|
|
|
|
| print("\n5. 验证 padding")
|
| for i, (orig_input, orig_target) in enumerate(batch_variable):
|
| orig_len = len(orig_input)
|
|
|
| assert torch.equal(input_batch_var[i, :orig_len], orig_input), \
|
| f"样本 {i} 的 input_ids 数据不匹配"
|
| assert torch.equal(target_batch_var[i, :orig_len], orig_target), \
|
| f"样本 {i} 的 target_ids 数据不匹配"
|
|
|
| if orig_len < 10:
|
| assert torch.all(input_batch_var[i, orig_len:] == 0), \
|
| f"样本 {i} 的 input_ids padding 不正确"
|
| assert torch.all(target_batch_var[i, orig_len:] == 0), \
|
| f"样本 {i} 的 target_ids padding 不正确"
|
| print(f" 样本 {i}: 长度={orig_len}, padding 验证通过")
|
|
|
| print(" ✓ Padding 验证通过")
|
|
|
| print("\n" + "=" * 60)
|
| print("所有测试完成!")
|
| print("=" * 60)
|
|
|