File size: 6,927 Bytes
46b244e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | #!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
训练数据平衡脚本 - 针对旅游资源名称抽取任务
主要功能:
1. 分析当前数据分布
2. 对低频资源进行上采样
3. 生成数据增强样本
4. 输出平衡后的训练集
"""
import json
import random
import copy
from collections import Counter, defaultdict
from typing import List, Dict, Any
import argparse
class DataBalancer:
def __init__(self, input_file: str):
self.input_file = input_file
self.data = self.load_data()
self.resource_counts = Counter()
self.resource_samples = defaultdict(list)
self.analyze_distribution()
def load_data(self) -> List[Dict]:
"""加载训练数据"""
with open(self.input_file, 'r', encoding='utf-8') as f:
return json.load(f)
def analyze_distribution(self):
"""分析资源分布"""
for idx, item in enumerate(self.data):
if 'output' in item:
try:
output_data = json.loads(item['output'])
if 'resource_names' in output_data:
resources = output_data['resource_names']
for resource in resources:
self.resource_counts[resource] += 1
self.resource_samples[resource].append(idx)
except:
continue
def get_balance_strategy(self, target_min_samples: int = 5) -> Dict[str, int]:
"""
计算平衡策略
Args:
target_min_samples: 目标最小样本数
Returns:
Dict[资源名称, 需要增加的样本数]
"""
balance_strategy = {}
print("📊 当前资源分布分析:")
print("-" * 50)
for resource, count in self.resource_counts.most_common():
if count < target_min_samples:
needed = target_min_samples - count
balance_strategy[resource] = needed
print(f"❌ {resource}: {count} -> {target_min_samples} (需要+{needed})")
else:
print(f"✅ {resource}: {count}")
return balance_strategy
def create_augmented_sample(self, original_idx: int, target_resource: str) -> Dict[str, Any]:
"""
创建数据增强样本
策略:
1. 保持原有的instruction不变
2. 修改input中的关键信息(日期、人数、联系人等)
3. 保持目标资源在output中
"""
original = copy.deepcopy(self.data[original_idx])
# 日期变换
dates = ["7月15日", "7月16日", "7月19日", "7月21日", "7月22日", "7月25日", "7月26日", "8月1日", "8月2日", "8月5日"]
# 人数变换
people_counts = ["15人", "25人", "35人", "45人", "55人", "8人", "12人", "18人", "22人", "28人"]
# 联系人变换(保持格式)
phone_endings = ["1234", "5678", "9012", "3456", "7890", "2468", "1357", "9753", "8642", "0246"]
input_text = original['input']
# 简单的文本替换进行数据增强
for date in ["7月17日", "7月18日", "7月20日", "7月28日", "7月29日", "7月30日", "7月31日"]:
if date in input_text:
input_text = input_text.replace(date, random.choice(dates))
break
# 替换人数
import re
people_pattern = r'\d+人'
matches = re.findall(people_pattern, input_text)
if matches:
for match in matches:
input_text = input_text.replace(match, random.choice(people_counts), 1)
# 替换电话号码后四位
phone_pattern = r'1[3-9]\d{9}'
def replace_phone(match):
phone = match.group()
return phone[:-4] + random.choice(phone_endings)
input_text = re.sub(phone_pattern, replace_phone, input_text)
# 创建新样本
new_sample = copy.deepcopy(original)
new_sample['input'] = input_text
return new_sample
def balance_data(self, target_min_samples: int = 5) -> List[Dict[str, Any]]:
"""
平衡数据集
Args:
target_min_samples: 目标最小样本数
Returns:
平衡后的数据集
"""
balance_strategy = self.get_balance_strategy(target_min_samples)
if not balance_strategy:
print("✅ 数据已经平衡,无需调整")
return self.data
print(f"\n🔄 开始数据平衡,目标最小样本数: {target_min_samples}")
print("-" * 50)
balanced_data = copy.deepcopy(self.data)
for resource, needed_count in balance_strategy.items():
print(f"📈 正在增强 '{resource}' 的样本...")
# 获取该资源的原始样本
original_samples = self.resource_samples[resource]
for i in range(needed_count):
# 随机选择一个原始样本进行增强
source_idx = random.choice(original_samples)
augmented_sample = self.create_augmented_sample(source_idx, resource)
balanced_data.append(augmented_sample)
print(f" ✅ 已添加 {needed_count} 个增强样本")
return balanced_data
def save_balanced_data(self, balanced_data: List[Dict], output_file: str):
"""保存平衡后的数据"""
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(balanced_data, f, ensure_ascii=False, indent=2)
print(f"\n💾 已保存平衡后的数据到: {output_file}")
print(f" 原始样本数: {len(self.data)}")
print(f" 平衡后样本数: {len(balanced_data)}")
print(f" 新增样本数: {len(balanced_data) - len(self.data)}")
def main():
parser = argparse.ArgumentParser(description='平衡旅游资源训练数据')
parser.add_argument('--input', required=True, help='输入的训练数据文件')
parser.add_argument('--output', required=True, help='输出的平衡数据文件')
parser.add_argument('--min-samples', type=int, default=5, help='目标最小样本数 (默认: 5)')
args = parser.parse_args()
print("🚀 开始数据平衡流程...")
print("=" * 60)
# 初始化平衡器
balancer = DataBalancer(args.input)
# 执行平衡
balanced_data = balancer.balance_data(args.min_samples)
# 保存结果
balancer.save_balanced_data(balanced_data, args.output)
print("\n🎉 数据平衡完成!")
if __name__ == "__main__":
main()
|