File size: 6,927 Bytes
46b244e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
训练数据平衡脚本 - 针对旅游资源名称抽取任务

主要功能:
1. 分析当前数据分布
2. 对低频资源进行上采样
3. 生成数据增强样本
4. 输出平衡后的训练集
"""

import json
import random
import copy
from collections import Counter, defaultdict
from typing import List, Dict, Any
import argparse

class DataBalancer:
    def __init__(self, input_file: str):
        self.input_file = input_file
        self.data = self.load_data()
        self.resource_counts = Counter()
        self.resource_samples = defaultdict(list)
        self.analyze_distribution()
    
    def load_data(self) -> List[Dict]:
        """加载训练数据"""
        with open(self.input_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    
    def analyze_distribution(self):
        """分析资源分布"""
        for idx, item in enumerate(self.data):
            if 'output' in item:
                try:
                    output_data = json.loads(item['output'])
                    if 'resource_names' in output_data:
                        resources = output_data['resource_names']
                        for resource in resources:
                            self.resource_counts[resource] += 1
                            self.resource_samples[resource].append(idx)
                except:
                    continue
    
    def get_balance_strategy(self, target_min_samples: int = 5) -> Dict[str, int]:
        """
        计算平衡策略
        
        Args:
            target_min_samples: 目标最小样本数
            
        Returns:
            Dict[资源名称, 需要增加的样本数]
        """
        balance_strategy = {}
        
        print("📊 当前资源分布分析:")
        print("-" * 50)
        
        for resource, count in self.resource_counts.most_common():
            if count < target_min_samples:
                needed = target_min_samples - count
                balance_strategy[resource] = needed
                print(f"❌ {resource}: {count} -> {target_min_samples} (需要+{needed})")
            else:
                print(f"✅ {resource}: {count}")
        
        return balance_strategy
    
    def create_augmented_sample(self, original_idx: int, target_resource: str) -> Dict[str, Any]:
        """
        创建数据增强样本
        
        策略:
        1. 保持原有的instruction不变
        2. 修改input中的关键信息(日期、人数、联系人等)
        3. 保持目标资源在output中
        """
        original = copy.deepcopy(self.data[original_idx])
        
        # 日期变换
        dates = ["7月15日", "7月16日", "7月19日", "7月21日", "7月22日", "7月25日", "7月26日", "8月1日", "8月2日", "8月5日"]
        
        # 人数变换
        people_counts = ["15人", "25人", "35人", "45人", "55人", "8人", "12人", "18人", "22人", "28人"]
        
        # 联系人变换(保持格式)
        phone_endings = ["1234", "5678", "9012", "3456", "7890", "2468", "1357", "9753", "8642", "0246"]
        
        input_text = original['input']
        
        # 简单的文本替换进行数据增强
        for date in ["7月17日", "7月18日", "7月20日", "7月28日", "7月29日", "7月30日", "7月31日"]:
            if date in input_text:
                input_text = input_text.replace(date, random.choice(dates))
                break
        
        # 替换人数
        import re
        people_pattern = r'\d+人'
        matches = re.findall(people_pattern, input_text)
        if matches:
            for match in matches:
                input_text = input_text.replace(match, random.choice(people_counts), 1)
        
        # 替换电话号码后四位
        phone_pattern = r'1[3-9]\d{9}'
        def replace_phone(match):
            phone = match.group()
            return phone[:-4] + random.choice(phone_endings)
        
        input_text = re.sub(phone_pattern, replace_phone, input_text)
        
        # 创建新样本
        new_sample = copy.deepcopy(original)
        new_sample['input'] = input_text
        
        return new_sample
    
    def balance_data(self, target_min_samples: int = 5) -> List[Dict[str, Any]]:
        """
        平衡数据集
        
        Args:
            target_min_samples: 目标最小样本数
            
        Returns:
            平衡后的数据集
        """
        balance_strategy = self.get_balance_strategy(target_min_samples)
        
        if not balance_strategy:
            print("✅ 数据已经平衡,无需调整")
            return self.data
        
        print(f"\n🔄 开始数据平衡,目标最小样本数: {target_min_samples}")
        print("-" * 50)
        
        balanced_data = copy.deepcopy(self.data)
        
        for resource, needed_count in balance_strategy.items():
            print(f"📈 正在增强 '{resource}' 的样本...")
            
            # 获取该资源的原始样本
            original_samples = self.resource_samples[resource]
            
            for i in range(needed_count):
                # 随机选择一个原始样本进行增强
                source_idx = random.choice(original_samples)
                augmented_sample = self.create_augmented_sample(source_idx, resource)
                balanced_data.append(augmented_sample)
            
            print(f"   ✅ 已添加 {needed_count} 个增强样本")
        
        return balanced_data
    
    def save_balanced_data(self, balanced_data: List[Dict], output_file: str):
        """保存平衡后的数据"""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(balanced_data, f, ensure_ascii=False, indent=2)
        
        print(f"\n💾 已保存平衡后的数据到: {output_file}")
        print(f"   原始样本数: {len(self.data)}")
        print(f"   平衡后样本数: {len(balanced_data)}")
        print(f"   新增样本数: {len(balanced_data) - len(self.data)}")

def main():
    parser = argparse.ArgumentParser(description='平衡旅游资源训练数据')
    parser.add_argument('--input', required=True, help='输入的训练数据文件')
    parser.add_argument('--output', required=True, help='输出的平衡数据文件')
    parser.add_argument('--min-samples', type=int, default=5, help='目标最小样本数 (默认: 5)')
    
    args = parser.parse_args()
    
    print("🚀 开始数据平衡流程...")
    print("=" * 60)
    
    # 初始化平衡器
    balancer = DataBalancer(args.input)
    
    # 执行平衡
    balanced_data = balancer.balance_data(args.min_samples)
    
    # 保存结果
    balancer.save_balanced_data(balanced_data, args.output)
    
    print("\n🎉 数据平衡完成!")

if __name__ == "__main__":
    main()