liangyi_LLaMA_Factory / data /preprocess_data /execute_data_balance.py
Mickey25's picture
Upload folder using huggingface_hub
46b244e verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
执行数据平衡的主脚本
结合你的具体需求:
- 新叶古村-新叶古村门票: 1 -> 5 (+4)
- 大慈岩-大慈岩索道: 2 -> 5 (+3)
- 其他低频资源也会被相应增强
"""
import json
import sys
import os
from advanced_data_augmentation import AdvancedDataAugmenter
def load_training_data(file_path: str):
"""加载原始训练数据"""
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
def merge_enhanced_samples(original_data, enhanced_samples):
"""合并原始数据和增强样本"""
return original_data + enhanced_samples
def analyze_final_distribution(data):
"""分析最终的数据分布"""
from collections import Counter
resource_counts = Counter()
for item in data:
if 'output' in item:
try:
output_data = json.loads(item['output'])
if 'resource_names' in output_data:
resources = output_data['resource_names']
for resource in resources:
resource_counts[resource] += 1
except:
continue
print("📊 最终数据分布:")
print("-" * 60)
# 重点关注的低频资源
focus_resources = [
"新叶古村-新叶古村门票",
"大慈岩-大慈岩索道",
"灵栖洞-灵栖洞西游魔毯",
"宿江公司-江清月近人实景演艺门票"
]
print("🎯 重点关注的资源:")
for resource in focus_resources:
count = resource_counts.get(resource, 0)
print(f" {resource}: {count}")
print(f"\n📈 所有资源分布 (总计 {len(resource_counts)} 种资源):")
for resource, count in resource_counts.most_common():
status = "✅" if count >= 5 else "⚠️"
print(f" {status} {resource}: {count}")
def main():
# 输入文件路径 - 根据你的实际路径调整
input_files = [
"/home/ziqiang/LLaMA-Factory/data/ocr_text_orders_08_14_test_v4.json"
]
# 尝试找到存在的训练数据文件
training_file = None
for file_path in input_files:
if os.path.exists(file_path):
training_file = file_path
break
if not training_file:
print("❌ 未找到训练数据文件,请检查路径:")
for file_path in input_files:
print(f" {file_path}")
return
print(f"📂 使用训练数据文件: {training_file}")
print("=" * 60)
# 加载原始数据
print("📥 加载原始训练数据...")
original_data = load_training_data(training_file)
print(f" 原始样本数: {len(original_data)}")
# 生成增强样本
print("\n🔄 生成增强样本...")
augmenter = AdvancedDataAugmenter()
enhanced_samples = augmenter.generate_all_samples()
# 合并数据
print(f"\n🔗 合并原始数据和增强样本...")
balanced_data = merge_enhanced_samples(original_data, enhanced_samples)
print(f" 合并后样本数: {len(balanced_data)}")
print(f" 新增样本数: {len(enhanced_samples)}")
# 保存平衡后的数据
output_file = "balanced_training_data.json"
print(f"\n💾 保存平衡后的数据到: {output_file}")
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(balanced_data, f, ensure_ascii=False, indent=2)
# 分析最终分布
print(f"\n📊 分析最终数据分布...")
analyze_final_distribution(balanced_data)
print(f"\n🎉 数据平衡完成!")
print("📋 建议的下一步:")
print(" 1. 使用 balanced_training_data.json 重新训练模型")
print(" 2. 在验证集上测试性能改进")
print(" 3. 特别关注新叶古村、大慈岩索道等低频资源的识别效果")
if __name__ == "__main__":
main()