File size: 3,942 Bytes
46b244e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
执行数据平衡的主脚本

结合你的具体需求:
- 新叶古村-新叶古村门票: 1 -> 5 (+4)
- 大慈岩-大慈岩索道: 2 -> 5 (+3) 
- 其他低频资源也会被相应增强
"""

import json
import sys
import os
from advanced_data_augmentation import AdvancedDataAugmenter

def load_training_data(file_path: str):
    """加载原始训练数据"""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def merge_enhanced_samples(original_data, enhanced_samples):
    """合并原始数据和增强样本"""
    return original_data + enhanced_samples

def analyze_final_distribution(data):
    """分析最终的数据分布"""
    from collections import Counter
    
    resource_counts = Counter()
    
    for item in data:
        if 'output' in item:
            try:
                output_data = json.loads(item['output'])
                if 'resource_names' in output_data:
                    resources = output_data['resource_names']
                    for resource in resources:
                        resource_counts[resource] += 1
            except:
                continue
    
    print("📊 最终数据分布:")
    print("-" * 60)
    
    # 重点关注的低频资源
    focus_resources = [
        "新叶古村-新叶古村门票",
        "大慈岩-大慈岩索道", 
        "灵栖洞-灵栖洞西游魔毯",
        "宿江公司-江清月近人实景演艺门票"
    ]
    
    print("🎯 重点关注的资源:")
    for resource in focus_resources:
        count = resource_counts.get(resource, 0)
        print(f"   {resource}: {count}")
    
    print(f"\n📈 所有资源分布 (总计 {len(resource_counts)} 种资源):")
    for resource, count in resource_counts.most_common():
        status = "✅" if count >= 5 else "⚠️"
        print(f"   {status} {resource}: {count}")

def main():
    # 输入文件路径 - 根据你的实际路径调整
    input_files = [
        "/home/ziqiang/LLaMA-Factory/data/ocr_text_orders_08_14_test_v4.json"
    ]
    
    # 尝试找到存在的训练数据文件
    training_file = None
    for file_path in input_files:
        if os.path.exists(file_path):
            training_file = file_path
            break
    
    if not training_file:
        print("❌ 未找到训练数据文件,请检查路径:")
        for file_path in input_files:
            print(f"   {file_path}")
        return
    
    print(f"📂 使用训练数据文件: {training_file}")
    print("=" * 60)
    
    # 加载原始数据
    print("📥 加载原始训练数据...")
    original_data = load_training_data(training_file)
    print(f"   原始样本数: {len(original_data)}")
    
    # 生成增强样本
    print("\n🔄 生成增强样本...")
    augmenter = AdvancedDataAugmenter()
    enhanced_samples = augmenter.generate_all_samples()
    
    # 合并数据
    print(f"\n🔗 合并原始数据和增强样本...")
    balanced_data = merge_enhanced_samples(original_data, enhanced_samples)
    print(f"   合并后样本数: {len(balanced_data)}")
    print(f"   新增样本数: {len(enhanced_samples)}")
    
    # 保存平衡后的数据
    output_file = "balanced_training_data.json"
    print(f"\n💾 保存平衡后的数据到: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(balanced_data, f, ensure_ascii=False, indent=2)
    
    # 分析最终分布
    print(f"\n📊 分析最终数据分布...")
    analyze_final_distribution(balanced_data)
    
    print(f"\n🎉 数据平衡完成!")
    print("📋 建议的下一步:")
    print("   1. 使用 balanced_training_data.json 重新训练模型")
    print("   2. 在验证集上测试性能改进")
    print("   3. 特别关注新叶古村、大慈岩索道等低频资源的识别效果")

if __name__ == "__main__":
    main()