File size: 3,563 Bytes
0453c63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import pickle as pkl

# Root directory of the dataset: the *.pkl index files are listed from here, and
# './data/...' paths stored inside those files are re-rooted under it.
# NOTE(review): machine-specific absolute path -- adjust for your environment.
DATA_DIR = '/gemini/space/wrz/AffordanceNet/data'

# Path-fixing helper: rewrites relative dataset paths as absolute paths.
def resolve_path(path):
    """Map a relative dataset path onto its real location on disk.

    * ``./data/<rest>`` becomes ``<DATA_DIR>/<rest>``.
    * Any other ``./<rest>`` becomes ``<parent of DATA_DIR>/<rest>``.
    * Every other path is returned unchanged.
    """
    data_prefix = './data/'
    cwd_prefix = './'

    # Anything that does not start with './' is returned as-is.
    if not path.startswith(cwd_prefix):
        return path

    if path.startswith(data_prefix):
        # Drop the leading './data/' and re-root under the real data directory.
        base_dir = DATA_DIR
        remainder = path[len(data_prefix):]
    else:
        # Other './'-relative paths are anchored one level above DATA_DIR.
        base_dir = os.path.dirname(DATA_DIR)
        remainder = path[len(cwd_prefix):]

    return os.path.join(base_dir, remainder)


def get_data_paths():
    """Collect the pickle index files found directly inside ``DATA_DIR``.

    Files are classified by filename suffix:

    * ``*train.pkl``          -> training files
    * ``*reasoning_val.pkl``  -> reasoning validation files
    * other ``*val.pkl``      -> non-reasoning validation files

    Returns:
        A tuple ``(train_paths, reasoning_paths, non_reasoning_paths)`` of
        lists of paths, each in the order reported by ``os.listdir``.
    """
    train_paths = []
    reasoning_paths = []
    non_reasoning_paths = []

    for filename in os.listdir(DATA_DIR):
        full_path = os.path.join(DATA_DIR, filename)

        if filename.endswith('train.pkl'):
            train_paths.append(full_path)

        if filename.endswith('val.pkl'):
            # Reasoning files are the subset of val files with the longer suffix.
            if filename.endswith('reasoning_val.pkl'):
                reasoning_paths.append(full_path)
            else:
                non_reasoning_paths.append(full_path)

    return train_paths, reasoning_paths, non_reasoning_paths


def check_file_exists(file_path, description=""):
    """Verify that ``file_path`` exists on disk.

    Args:
        file_path: Path to test with ``os.path.exists``.
        description: Short label placed at the start of the error message.

    Raises:
        AssertionError: If the path does not exist. The exception is raised
            explicitly rather than through an ``assert`` statement so the
            check still runs when Python is started with ``-O`` (which strips
            ``assert`` statements).
    """
    if not os.path.exists(file_path):
        raise AssertionError(f"{description} does not exist: {file_path}")


def check_train_data(train_path):
    """Verify the frame and mask files of every sample in a training pickle.

    Args:
        train_path: Path to a pickle file containing an iterable of samples,
            each indexable by ``"frame_path"`` and ``"mask_path"``.

    Raises:
        AssertionError: Propagated from ``check_file_exists`` when a
            referenced file is missing.
    """
    print(f"[Train] Checking: {train_path}")

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(train_path, "rb") as handle:
        samples = pkl.load(handle)

    for sample in samples:
        # Resolve both stored paths to real locations before checking either.
        frame_file, mask_file = [
            resolve_path(sample[key]) for key in ("frame_path", "mask_path")
        ]
        check_file_exists(frame_file, "Frame path")
        check_file_exists(mask_file, "Mask path")

    print(f"[Train] ✅ Checked {train_path}. Samples: {len(samples)}")


def check_val_data(val_path, reasoning=False):
    """Verify the files referenced by a validation pickle.

    Args:
        val_path: Path to the validation pickle file.
        reasoning: When truthy, the pickle is read as an iterable of samples
            indexable by ``"frame_path"`` and ``"mask_path"``. Otherwise it is
            read as a mapping whose optional ``'images'`` and ``'labels'``
            entries map a key to a list of paths.

    Raises:
        AssertionError: Propagated from ``check_file_exists`` when a
            referenced file is missing.
    """
    tag = "Non-Reasoning Val"
    if reasoning:
        tag = "Reasoning Val"
    print(f"[{tag}] Checking: {val_path}")

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(val_path, "rb") as handle:
        content = pkl.load(handle)

    if not reasoning:
        total_images = 0
        for image_list in content.get('images', {}).values():
            for image_path in image_list:
                # Resolve the stored path to its real location before checking.
                check_file_exists(resolve_path(image_path), "Image path")
            total_images += len(image_list)

        for label_list in content.get('labels', {}).values():
            for label_path in label_list:
                check_file_exists(resolve_path(label_path), "Label path")

        # Only images are counted as samples; labels are checked but not counted.
        print(f"[{tag}] ✅ Checked {val_path}. Samples: {total_images}")
        return

    for sample in content:
        # Resolve both stored paths to real locations before checking either.
        frame_file, mask_file = [
            resolve_path(sample[key]) for key in ("frame_path", "mask_path")
        ]
        check_file_exists(frame_file, "Frame path")
        check_file_exists(mask_file, "Mask path")

    print(f"[{tag}] ✅ Checked {val_path}. Samples: {len(content)}")


def main():
    """Check every dataset pickle found in the data directory.

    Training files are checked first, then non-reasoning validation files,
    then reasoning validation files.
    """
    train_paths, reasoning_paths, non_reasoning_paths = get_data_paths()

    for pkl_path in train_paths:
        check_train_data(pkl_path)

    # Each validation group is paired with the reasoning flag it is checked with.
    validation_groups = (
        (non_reasoning_paths, False),
        (reasoning_paths, True),
    )
    for group_paths, is_reasoning in validation_groups:
        for pkl_path in group_paths:
            check_val_data(pkl_path, reasoning=is_reasoning)


# Run the dataset integrity check only when this file is executed as a script.
if __name__ == "__main__":
    main()