File size: 6,056 Bytes
46b244e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据集集成脚本:将身份证数据集成到OCR文本训练数据中
同时改进身份证数据的instruction格式
"""

import json
import os
from typing import List, Dict, Any


def load_json_data(file_path: str) -> List[Dict[str, Any]]:
    """加载JSON数据"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"成功加载 {file_path},包含 {len(data)} 条记录")
        return data
    except Exception as e:
        print(f"加载文件失败 {file_path}: {e}")
        return []


def improve_idcard_instruction(original_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    改进身份证数据的instruction格式,参考OCR文本数据的详细写法
    """
    improved_data = []
    
    # 新的详细instruction模板
    detailed_instruction = """请从OCR文本中抽取旅行订单中的游客身份证信息。

提取字段及说明:
name (姓名): 游客的真实姓名
idcard (身份证号): 18位身份证号码,支持末位为X的格式
gender (性别): 根据身份证号倒数第二位数字判断(奇数为男,偶数为女)
phone (电话号码): 游客联系电话,如果OCR文本中没有则为null

严格按照以下JSON格式输出:
[
  {
    "name": "姓名",
    "idcard": "身份证号",
    "gender": "男/女",
    "phone": "电话号码或null"
  }
]

注意事项:
1. 确保身份证号码格式正确(18位数字,末位可为X)
2. 性别根据身份证号自动判断,不依赖OCR文本中的性别信息
3. 如果姓名缺失或无法识别,name字段设为"无"
4. 电话号码如果在OCR文本中不存在,设置为null
5. 输出必须是有效的JSON数组格式"""

    for item in original_data:
        improved_item = item.copy()
        improved_item['instruction'] = detailed_instruction
        improved_data.append(improved_item)
    
    print(f"改进了 {len(improved_data)} 条身份证数据的instruction格式")
    return improved_data


def integrate_datasets(ocr_text_data: List[Dict[str, Any]], 
                      idcard_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    集成两个数据集
    """
    print("正在集成数据集...")
    
    # 合并数据集
    integrated_data = ocr_text_data + idcard_data
    
    print(f"集成完成:")
    print(f"- OCR文本数据:{len(ocr_text_data)} 条")
    print(f"- 身份证数据:{len(idcard_data)} 条")
    print(f"- 总计:{len(integrated_data)} 条")
    
    return integrated_data


def save_integrated_dataset(data: List[Dict[str, Any]], output_path: str):
    """
    保存集成后的数据集
    """
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"集成数据集已保存到: {output_path}")
    except Exception as e:
        print(f"保存文件失败: {e}")


def validate_dataset(data: List[Dict[str, Any]]) -> bool:
    """
    验证数据集格式正确性
    """
    print("正在验证数据集格式...")
    
    required_fields = ['instruction', 'input', 'output']
    issues = []
    
    for i, item in enumerate(data):
        # 检查必要字段
        for field in required_fields:
            if field not in item:
                issues.append(f"记录 {i}: 缺少字段 '{field}'")
        
        # 检查instruction是否为空
        if 'instruction' in item and not item['instruction'].strip():
            issues.append(f"记录 {i}: instruction字段为空")
        
        # 检查output是否为有效JSON
        if 'output' in item:
            try:
                json.loads(item['output'])
            except json.JSONDecodeError:
                issues.append(f"记录 {i}: output字段不是有效的JSON格式")
    
    if issues:
        print(f"发现 {len(issues)} 个问题:")
        for issue in issues[:10]:  # 只显示前10个问题
            print(f"  - {issue}")
        if len(issues) > 10:
            print(f"  - ... 还有 {len(issues) - 10} 个问题")
        return False
    else:
        print("数据集格式验证通过!")
        return True


def main():
    """
    主函数
    """
    print("开始数据集集成任务...")
    
    # 文件路径
    ocr_text_file = "/home/ziqiang/LLaMA-Factory/data/ocr_text_orders_08_18.json"
    idcard_file = "/home/ziqiang/LLaMA-Factory/data/ocr_idcards_orders.json"
    output_file = "/home/ziqiang/LLaMA-Factory/data/text_idcards_8_18.json"
    
    # 加载数据
    print("\n1. 加载原始数据集...")
    ocr_text_data = load_json_data(ocr_text_file)
    idcard_data = load_json_data(idcard_file)
    
    if not ocr_text_data or not idcard_data:
        print("数据加载失败,终止执行")
        return
    
    # 改进身份证数据的instruction
    print("\n2. 改进身份证数据的instruction格式...")
    improved_idcard_data = improve_idcard_instruction(idcard_data)
    
    # 集成数据集
    print("\n3. 集成数据集...")
    integrated_data = integrate_datasets(ocr_text_data, improved_idcard_data)
    
    # 验证数据集
    print("\n4. 验证数据集格式...")
    if not validate_dataset(integrated_data):
        print("数据集验证失败,请检查数据格式")
        return
    
    # 保存集成后的数据集
    print("\n5. 保存集成数据集...")
    save_integrated_dataset(integrated_data, output_file)
    
    print(f"\n✅ 数据集集成完成!")
    print(f"集成后的数据集保存在: {output_file}")
    print(f"总记录数: {len(integrated_data)}")
    
    # 显示样本统计
    ocr_count = len(ocr_text_data)
    idcard_count = len(improved_idcard_data)
    print(f"\n数据构成:")
    print(f"- OCR文本数据: {ocr_count} 条 ({ocr_count/len(integrated_data)*100:.1f}%)")
    print(f"- 身份证数据: {idcard_count} 条 ({idcard_count/len(integrated_data)*100:.1f}%)")


if __name__ == "__main__":
    main()