| import SimpleITK as sitk |
| import numpy as np |
| import os |
| from tqdm import tqdm |
| import pandas as pd |
| import hashlib |
|
|
| |
# Input volumes: brain-windowed CT scans and their ground-truth masks.
scans_path = '/data/wxh/Medical/tmz/metrics/brain_bleed/SAMIHS/BHSD/ct_scans_brain_window'
masks_path = '/data/wxh/Medical/tmz/metrics/brain_bleed/SAMIHS/BHSD/label_192/ground truths'

# Outputs: per-slice .npy files for images and masks, plus the CSV index.
scans_path_2d = '/data/wxh/Medical/tmz/metrics/brain_bleed/SAMIHS/BHSD/ct_2d'
masks_path_2d = '/data/wxh/Medical/tmz/metrics/brain_bleed/SAMIHS/BHSD/mask_2d'
csv_out = '/data/wxh/Medical/tmz/metrics/brain_bleed/SAMIHS/BHSD/bhsd_2d_index.csv'

os.makedirs(scans_path_2d, exist_ok=True)
os.makedirs(masks_path_2d, exist_ok=True)

# File-name endings accepted as NIfTI volumes (used with str.endswith).
ALLOW_EXT = ('.nii', '.nii.gz')
# Prefix for every generated slice file name.
TAG = 'BHSD'


def _strip_nifti_ext(fname: str) -> str:
    """Return *fname* without a trailing ``.nii.gz`` or ``.nii`` extension.

    Only the suffix is removed; ``.nii``/``.gz`` appearing elsewhere in the
    name is preserved, so the result round-trips with the
    ``{case_id}.nii.gz`` / ``{case_id}.nii`` mask-path lookup.
    """
    for ext in ('.nii.gz', '.nii'):
        if fname.endswith(ext):
            return fname[:-len(ext)]
    return fname


# Case ids (mask file names minus extension) that have a ground-truth mask.
ground_truth_ids = {_strip_nifti_ext(f)
                    for f in os.listdir(masks_path) if f.endswith(ALLOW_EXT)}
|
|
def removesuffix(name: str, suffix: str) -> str:
    """Return *name* with *suffix* removed from its end, if present.

    Mirrors ``str.removesuffix`` (Python 3.9+), including the empty-suffix
    case: an empty *suffix* leaves *name* unchanged. The explicit guard is
    needed because ``name[:-0]`` is ``name[:0]``, i.e. the empty string.

    Args:
        name: The string to trim.
        suffix: The ending to remove.

    Returns:
        *name* without *suffix*, or *name* itself if it does not end with it.
    """
    if suffix and name.endswith(suffix):
        return name[:-len(suffix)]
    return name
|
|
def get_case_id_from_scan(fname: str) -> str:
    """Derive, from a ct_scans_brain_window file name, the case id shared with its mask.

    Any directory component is discarded. The first matching suffix among
    ``_brain.nii.gz``, ``_brain.nii``, ``.nii.gz`` and ``.nii`` (checked in
    that order) is cut off; if none matches, only the last extension is
    dropped via ``os.path.splitext``.
    """
    stem = os.path.basename(fname)
    known_suffixes = ('_brain.nii.gz', '_brain.nii', '.nii.gz', '.nii')
    matched = next((s for s in known_suffixes if stem.endswith(s)), None)
    if matched is None:
        return os.path.splitext(stem)[0]
    return stem[:len(stem) - len(matched)]
|
|
def stable_fold(case_id: str, k: int = 5) -> int:
    """Assign *case_id* to a fold in ``range(k)`` using its MD5 digest.

    The mapping depends only on the case id and is deterministic across
    runs, so every slice of one case lands in the same fold and no case
    leaks across folds.
    """
    digest = hashlib.md5(case_id.encode('utf-8')).digest()
    return int.from_bytes(digest, 'big') % k
|
|
rows = []
row_id = 0

# Sorted so that row ids and CSV order are reproducible; os.listdir returns
# entries in arbitrary, filesystem-dependent order.
for pa in tqdm(sorted(os.listdir(scans_path))):
    if not pa.endswith(ALLOW_EXT):
        continue

    scan_fp = os.path.join(scans_path, pa)
    case_id = get_case_id_from_scan(pa)

    # Only scans that have a ground-truth mask are exported.
    if case_id not in ground_truth_ids:
        continue

    # Prefer the compressed mask; fall back to the uncompressed one.
    mask_fp_gz = os.path.join(masks_path, f"{case_id}.nii.gz")
    mask_fp_nii = os.path.join(masks_path, f"{case_id}.nii")
    mask_fp = mask_fp_gz if os.path.exists(mask_fp_gz) else mask_fp_nii
    if not os.path.exists(mask_fp):
        print(f"[WARN] mask not found for {case_id}, skip.")
        continue

    scan_img = sitk.ReadImage(scan_fp)
    mask_img = sitk.ReadImage(mask_fp)
    scan_arr = sitk.GetArrayFromImage(scan_img)
    label_arr = sitk.GetArrayFromImage(mask_img)

    # Slices are paired by index, so the two volumes must have identical
    # shapes; otherwise slices would be misaligned or label_arr[i] could
    # go out of range.
    if scan_arr.shape != label_arr.shape:
        print(f"[WARN] shape mismatch for {case_id}: "
              f"scan {scan_arr.shape} vs mask {label_arr.shape}, skip.")
        continue

    # Collapse every non-zero label value into a single foreground class.
    label_arr = (label_arr > 0).astype(np.uint8)

    # The fold depends only on the case id, so all slices of a case share it.
    fold_id = stable_fold(case_id, k=5)

    # Axis 0 is iterated as the slice axis (SimpleITK's GetArrayFromImage
    # returns arrays in z, y, x order).
    for i, (scan_slice, mask_slice) in enumerate(zip(scan_arr, label_arr)):
        img_name = f"{TAG}_{case_id}_{i:03d}.npy"
        mask_name = f"{TAG}_{case_id}_{i:03d}_seg.npy"

        np.save(os.path.join(scans_path_2d, img_name), scan_slice)
        np.save(os.path.join(masks_path_2d, mask_name), mask_slice)

        rows.append({
            "id": row_id,
            # Paths relative to the BHSD directory holding ct_2d/, mask_2d/
            # and the CSV index.
            "img": f"ct_2d/{img_name}",
            "gt": f"mask_2d/{mask_name}",
            "fold": fold_id,
            # Count of foreground pixels in this slice (0.0 for empty slices).
            "label": float(mask_slice.sum()),
        })
        row_id += 1

df = pd.DataFrame(rows, columns=["id", "img", "gt", "fold", "label"])
df.to_csv(csv_out, index=False)
print(f"Done. Wrote {len(df)} rows to {csv_out}")