File size: 3,125 Bytes
1c980b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
from collections import defaultdict

def transform_json_structure(input_path, output_path, max_questions_per_image=15,
                             media_dir="./data/RSVQA"):
    """
    Convert the raw RSVQA JSON format into a list of per-image records.

    The input must be a JSON object whose ``merged_data`` key holds a list of
    QA dicts, each carrying an ``img_id`` and optional ``question``/``answer``.
    Entries are grouped by ``img_id`` in first-seen order, and each group
    becomes one record whose ``question`` and ``answer`` fields are parallel
    lists.

    An entry is skipped when it is not a dict, lacks ``img_id``, or its
    ``img_id`` is not a non-negative integer (booleans are rejected as well).

    :param input_path: path of the input JSON file
    :param output_path: path of the output JSON file
    :param max_questions_per_image: maximum number of QA pairs kept per image;
        further pairs for an image that is already full are dropped.
        ``None`` disables the cap. Defaults to 15.
    :param media_dir: directory prefix used to build each record's
        ``media_paths`` value as ``<media_dir>/<img_id>.png``.
        Defaults to ``./data/RSVQA``.
    :raises RuntimeError: if the input cannot be read/parsed or the output
        cannot be written; the underlying exception is chained as the cause.
    """
    # Read the source data
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            original_data = json.load(f)
        # `or []` also tolerates an explicit `"merged_data": null`.
        question_list = original_data.get('merged_data') or []
    except Exception as e:
        raise RuntimeError(f"读取输入文件失败: {e}") from e

    # Group QA pairs by image id. Dicts keep insertion order, so the output
    # records follow the order in which each image id first appears.
    image_groups = defaultdict(list)
    for qa_pair in question_list:
        if not isinstance(qa_pair, dict):
            print(f"跳过非字典类型的数据: {type(qa_pair).__name__}")
            continue
        try:
            img_id = qa_pair['img_id']
        except KeyError as ke:
            print(f"跳过缺少关键字段的数据: {ke}")
            continue

        # Filter invalid ids. bool is a subclass of int (and True == 1 hashes
        # like 1), so it must be excluded explicitly or it would be merged
        # into image 1's group.
        if isinstance(img_id, bool) or not isinstance(img_id, int) or img_id < 0:
            continue

        # Enforce the per-image cap; `.get` avoids creating an empty group
        # when the cap is 0.
        if (max_questions_per_image is not None
                and len(image_groups.get(img_id, ())) >= max_questions_per_image):
            continue
        image_groups[img_id].append({
            'question': qa_pair.get('question', ''),
            'answer': qa_pair.get('answer', ''),
        })

    # Build the output records
    base_dir = media_dir.rstrip('/')
    transformed_data = [
        {
            "index": img_id,
            "media_type": "image",
            "media_paths": f"{base_dir}/{img_id}.png",
            "description": "",
            "task_type": "Vision-Question-Answer",
            "question": [pair['question'] for pair in qa_pairs],
            "question_type": "free-form",
            "annotations": [],
            "options": [],
            "answer": [pair['answer'] for pair in qa_pairs],
            "source": "RSVQA",
            "domain": "Satellite-Remote-Sensing",
        }
        for img_id, qa_pairs in image_groups.items()
    ]

    # Save the transformed data
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(transformed_data, f, indent=2, ensure_ascii=False)
    except Exception as e:
        raise RuntimeError(f"写入输出文件失败: {e}") from e

if __name__ == "__main__":
    # Command-line entry point: parse input/output paths and run the conversion.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='RSVQA数据格式转换工具')
    parser.add_argument('-i', '--input', type=str, required=True, help='输入JSON文件路径')
    parser.add_argument('-o', '--output', type=str, default='transformed.json',
                        help='输出JSON文件路径 (默认: transformed.json)')

    args = parser.parse_args()

    try:
        transform_json_structure(args.input, args.output)
    except Exception as e:
        # Report failures on stderr and exit non-zero. sys.exit is used because
        # the bare `exit` builtin is injected by the `site` module and is not
        # guaranteed to exist (e.g. under `python -S` or in frozen builds).
        print(f"处理过程中发生错误: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print(f"转换成功!输出文件已保存至: {args.output}")