import os
import pickle as pkl

DATA_DIR = '/gemini/space/wrz/AffordanceNet/data'

# Prefix carried by paths that were stored relative to the data directory.
_DATA_PREFIX = './data/'


def resolve_path(path):
    """Convert a relative path stored in a pkl file into a usable path.

    Args:
        path: Path string as stored in the dataset pickle.

    Returns:
        * ``./data/<rest>`` -> ``<DATA_DIR>/<rest>``
        * ``./<rest>``      -> ``<parent of DATA_DIR>/<rest>``
        * anything else     -> ``path`` unchanged
    """
    if path.startswith(_DATA_PREFIX):
        # Drop the leading './data/' and re-anchor under the real DATA_DIR.
        return os.path.join(DATA_DIR, path[len(_DATA_PREFIX):])
    if path.startswith('./'):
        # Other './'-relative paths are anchored at DATA_DIR's parent.
        return os.path.join(os.path.dirname(DATA_DIR), path[2:])
    return path


def get_data_paths(data_dir=None):
    """Collect the train / reasoning-val / non-reasoning-val pkl files.

    Args:
        data_dir: Directory to scan. Defaults to ``DATA_DIR`` when ``None``.

    Returns:
        A tuple ``(train_paths, reasoning_paths, non_reasoning_paths)`` of
        lists of full paths. Files are classified by filename suffix:
        ``train.pkl``, ``reasoning_val.pkl``, and ``val.pkl`` that is not
        ``reasoning_val.pkl``.

    Raises:
        OSError: If the directory cannot be listed.
    """
    if data_dir is None:
        data_dir = DATA_DIR

    # os.listdir returns entries in arbitrary order; sort so that the order
    # of the checks (and of the printed report) is reproducible.
    names = sorted(os.listdir(data_dir))

    train_paths = [
        os.path.join(data_dir, name)
        for name in names
        if name.endswith('train.pkl')
    ]
    reasoning_paths = [
        os.path.join(data_dir, name)
        for name in names
        if name.endswith('reasoning_val.pkl')
    ]
    # Every 'reasoning_val.pkl' also ends with 'val.pkl', so exclude those
    # explicitly to obtain the plain validation files.
    non_reasoning_paths = [
        os.path.join(data_dir, name)
        for name in names
        if name.endswith('val.pkl') and not name.endswith('reasoning_val.pkl')
    ]
    return train_paths, reasoning_paths, non_reasoning_paths


def check_file_exists(file_path, description=""):
    """Raise if ``file_path`` does not exist.

    Args:
        file_path: Path to test with ``os.path.exists``.
        description: Human-readable label used as the error message prefix.

    Raises:
        AssertionError: If the path does not exist. Raised explicitly rather
            than via an ``assert`` statement so the check still runs when
            Python is started with ``-O``.
    """
    if not os.path.exists(file_path):
        raise AssertionError(f"{description} does not exist: {file_path}")


def _load_pickle(pkl_path):
    """Load and return the object stored in the pickle file ``pkl_path``."""
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # script at dataset files from a trusted source.
    with open(pkl_path, "rb") as f:
        return pkl.load(f)


def _check_frame_mask_items(items):
    """Check the resolved 'frame_path' and 'mask_path' of every item."""
    for item in items:
        # Stored paths may be relative, so resolve them before checking.
        check_file_exists(resolve_path(item["frame_path"]), "Frame path")
        check_file_exists(resolve_path(item["mask_path"]), "Mask path")


def check_train_data(train_path):
    """Check frame and mask paths for each sample in a training pkl file.

    Args:
        train_path: Path to a pickle holding an iterable of items, each with
            ``"frame_path"`` and ``"mask_path"`` keys.

    Raises:
        AssertionError: If any referenced frame or mask file is missing.
    """
    print(f"[Train] Checking: {train_path}")
    data = _load_pickle(train_path)
    _check_frame_mask_items(data)
    print(f"[Train] ✅ Checked {train_path}. Samples: {len(data)}")


def check_val_data(val_path, reasoning=False):
    """Check the paths referenced by a validation pkl file.

    Args:
        val_path: Path to the validation pickle.
        reasoning: When true, the pickle is read as an iterable of items
            with ``"frame_path"`` / ``"mask_path"`` keys. Otherwise it is
            read as a mapping whose optional ``'images'`` and ``'labels'``
            entries map a class name to a list of paths.

    Raises:
        AssertionError: If any referenced file is missing.
    """
    tag = "Reasoning Val" if reasoning else "Non-Reasoning Val"
    print(f"[{tag}] Checking: {val_path}")
    data = _load_pickle(val_path)

    if reasoning:
        _check_frame_mask_items(data)
        print(f"[{tag}] ✅ Checked {val_path}. Samples: {len(data)}")
        return

    total_images = 0
    for image_list in data.get('images', {}).values():
        for image_path in image_list:
            check_file_exists(resolve_path(image_path), "Image path")
        total_images += len(image_list)

    for label_list in data.get('labels', {}).values():
        for label_path in label_list:
            check_file_exists(resolve_path(label_path), "Label path")

    # The reported sample count is the number of image paths only.
    print(f"[{tag}] ✅ Checked {val_path}. Samples: {total_images}")


def main():
    """Run the path checks over every pkl file found in ``DATA_DIR``."""
    train_paths, reasoning_paths, non_reasoning_paths = get_data_paths()

    for train_path in train_paths:
        check_train_data(train_path)

    for val_path in non_reasoning_paths:
        check_val_data(val_path, reasoning=False)

    for val_path in reasoning_paths:
        check_val_data(val_path, reasoning=True)


if __name__ == "__main__":
    main()