| | |
| | |
| | """ |
| | 执行数据平衡的主脚本 |
| | |
| | 结合你的具体需求: |
| | - 新叶古村-新叶古村门票: 1 -> 5 (+4) |
| | - 大慈岩-大慈岩索道: 2 -> 5 (+3) |
| | - 其他低频资源也会被相应增强 |
| | """ |
| |
|
| | import json |
| | import sys |
| | import os |
| | from advanced_data_augmentation import AdvancedDataAugmenter |
| |
|
| | def load_training_data(file_path: str): |
| | """加载原始训练数据""" |
| | with open(file_path, 'r', encoding='utf-8') as f: |
| | return json.load(f) |
| |
|
| | def merge_enhanced_samples(original_data, enhanced_samples): |
| | """合并原始数据和增强样本""" |
| | return original_data + enhanced_samples |
| |
|
| | def analyze_final_distribution(data): |
| | """分析最终的数据分布""" |
| | from collections import Counter |
| | |
| | resource_counts = Counter() |
| | |
| | for item in data: |
| | if 'output' in item: |
| | try: |
| | output_data = json.loads(item['output']) |
| | if 'resource_names' in output_data: |
| | resources = output_data['resource_names'] |
| | for resource in resources: |
| | resource_counts[resource] += 1 |
| | except: |
| | continue |
| | |
| | print("📊 最终数据分布:") |
| | print("-" * 60) |
| | |
| | |
| | focus_resources = [ |
| | "新叶古村-新叶古村门票", |
| | "大慈岩-大慈岩索道", |
| | "灵栖洞-灵栖洞西游魔毯", |
| | "宿江公司-江清月近人实景演艺门票" |
| | ] |
| | |
| | print("🎯 重点关注的资源:") |
| | for resource in focus_resources: |
| | count = resource_counts.get(resource, 0) |
| | print(f" {resource}: {count}") |
| | |
| | print(f"\n📈 所有资源分布 (总计 {len(resource_counts)} 种资源):") |
| | for resource, count in resource_counts.most_common(): |
| | status = "✅" if count >= 5 else "⚠️" |
| | print(f" {status} {resource}: {count}") |
| |
|
| | def main(): |
| | |
| | input_files = [ |
| | "/home/ziqiang/LLaMA-Factory/data/ocr_text_orders_08_14_test_v4.json" |
| | ] |
| | |
| | |
| | training_file = None |
| | for file_path in input_files: |
| | if os.path.exists(file_path): |
| | training_file = file_path |
| | break |
| | |
| | if not training_file: |
| | print("❌ 未找到训练数据文件,请检查路径:") |
| | for file_path in input_files: |
| | print(f" {file_path}") |
| | return |
| | |
| | print(f"📂 使用训练数据文件: {training_file}") |
| | print("=" * 60) |
| | |
| | |
| | print("📥 加载原始训练数据...") |
| | original_data = load_training_data(training_file) |
| | print(f" 原始样本数: {len(original_data)}") |
| | |
| | |
| | print("\n🔄 生成增强样本...") |
| | augmenter = AdvancedDataAugmenter() |
| | enhanced_samples = augmenter.generate_all_samples() |
| | |
| | |
| | print(f"\n🔗 合并原始数据和增强样本...") |
| | balanced_data = merge_enhanced_samples(original_data, enhanced_samples) |
| | print(f" 合并后样本数: {len(balanced_data)}") |
| | print(f" 新增样本数: {len(enhanced_samples)}") |
| | |
| | |
| | output_file = "balanced_training_data.json" |
| | print(f"\n💾 保存平衡后的数据到: {output_file}") |
| | with open(output_file, 'w', encoding='utf-8') as f: |
| | json.dump(balanced_data, f, ensure_ascii=False, indent=2) |
| | |
| | |
| | print(f"\n📊 分析最终数据分布...") |
| | analyze_final_distribution(balanced_data) |
| | |
| | print(f"\n🎉 数据平衡完成!") |
| | print("📋 建议的下一步:") |
| | print(" 1. 使用 balanced_training_data.json 重新训练模型") |
| | print(" 2. 在验证集上测试性能改进") |
| | print(" 3. 特别关注新叶古村、大慈岩索道等低频资源的识别效果") |
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|