File size: 3,970 Bytes
75b1a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# MyDataset_rib.py
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

# Finding names that qualify a record: a record is kept only when the first
# entry of its "attn_list" names one of these (compared case-insensitively).
# Note: "Pacemaker" is a device rather than a disease, but it is accepted too.
DISEASES = [
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "Diffuse Nodule",
    "Edema",
    "Effusion",
    "Emphysema",
    "Enlarged Cardiomediastinum",
    "Fibrosis",
    "Fracture",
    "Mass",
    "Nodule",
    "Opacity",
    "Pleural Thickening",
    "Pneumothorax",
    "Pacemaker",
]
# Lower-cased lookup set for the case-insensitive membership test.
DISEASES_LOWER = set(map(str.lower, DISEASES))


def _resolve_path(rel_or_abs, base_dir):
    if not rel_or_abs:
        return None
    return rel_or_abs if os.path.isabs(rel_or_abs) else os.path.join(base_dir, rel_or_abs)


def _load_gray_int(path):
    """Read an image as single-channel 8-bit grayscale and return it as ``int32``.

    The image is opened in a ``with`` block so the file handle is released as
    soon as the pixels have been decoded. A bare ``Image.open`` can keep the
    file open, e.g. for multi-frame formats, which leaks descriptors when
    called once per sample.

    Args:
        path: path of the image file.

    Returns:
        2-D ``int32`` array of grayscale values (0..255 after ``convert("L")``).
    """
    with Image.open(path) as im:
        arr = np.array(im.convert("L"))
    return arr.astype(np.int32)


def _load_bin_mask_resize(path, hw):
    """Load a mask image as a binary float32 array of shape ``hw``.

    The image is converted to grayscale and resized with nearest-neighbour
    sampling when its size differs from ``hw``, so no intermediate values are
    invented. Any non-zero pixel becomes 1.0, all others 0.0.

    Args:
        path: path of the mask image.
        hw: target ``(H, W)``.

    Returns:
        ``float32`` array of shape ``(H, W)`` containing only 0.0 and 1.0.
    """
    H, W = hw
    # Close the source file deterministically; convert() yields an in-memory copy.
    with Image.open(path) as src:
        im = src.convert("L")
    if im.size != (W, H):  # PIL sizes are (width, height)
        im = im.resize((W, H), resample=Image.NEAREST)
    arr = np.array(im)
    return (arr > 0).astype(np.float32)


def _remap_to_sequential(label_map_int):
    vals = np.unique(label_map_int)
    vals = vals[vals != 0]
    if len(vals) == 0:
        return label_map_int.astype(np.int32), 0

    out = np.zeros_like(label_map_int, dtype=np.int32)
    for i, v in enumerate(vals, start=1):
        out[label_map_int == v] = i
    return out, int(out.max())


def _norm(ch, maxv):
    return ch.astype(np.float32) / float(maxv) if maxv > 0 else ch.astype(np.float32)


def _get_first_disease_name(item):
    attn_list = item.get("attn_list", [])
    if (
        len(attn_list) > 0
        and isinstance(attn_list[0], (list, tuple))
        and len(attn_list[0]) >= 1
    ):
        return str(attn_list[0][0]).strip()
    return None


def _is_target_record(item):
    """Return True when the record's first attention entry names a known finding.

    The name is matched case-insensitively against ``DISEASES_LOWER``. Records
    without a usable first entry are rejected.
    """
    name = _get_first_disease_name(item)
    return name is not None and name.lower() in DISEASES_LOWER


class MyDataset_rib(Dataset):
    """Paired (original, edited) 3-channel label images with tokenized prompts.

    Records come from a JSON-lines file (``args.train_data_prompt``). Each one
    is read for ``organ``, ``rib`` and optionally ``mask`` image paths. These
    are resolved against ``args.train_data_dir`` unless absolute. A ``prompt``
    string is also read. Only records whose first ``attn_list`` entry names a
    finding in ``DISEASES`` (case-insensitive) are kept.

    Each sample stacks channels ``[disease, organ, rib]`` at the organ map's
    resolution:

    * ``original``: all-zero disease channel.
    * ``edited``: binary lesion mask in the disease channel.

    Organ and rib ids are renumbered 1..K and divided by K.
    """

    def __init__(self, args, tokenizer):
        self.data_dir = args.train_data_dir
        self.prompt_dir = args.train_data_prompt  # path of the JSON-lines record file
        self.tokenizer = tokenizer

        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(self.prompt_dir, "rt", encoding="utf-8") as f:
            raw_data = [json.loads(line) for line in f if line.strip()]

        self.data = [item for item in raw_data if _is_target_record(item)]
        print(f"Loaded {len(self.data)} valid records from {len(raw_data)} total records.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with ``original_pixel_values`` and ``edited_pixel_values``,
            both float32 tensors of shape ``(3, H, W)``, and ``input_ids``, the
            prompt tokenized and padded/truncated to
            ``tokenizer.model_max_length``.

        Raises:
            ValueError: if the record has no ``organ`` or ``rib`` path.
        """
        item = self.data[idx]

        organ_path = _resolve_path(item.get("organ", ""), self.data_dir)
        rib_path = _resolve_path(item.get("rib", ""), self.data_dir)
        mask_path = _resolve_path(item.get("mask", ""), self.data_dir)
        if organ_path is None or rib_path is None:
            # Fail with a clear message instead of an opaque error from PIL.
            raise ValueError(f"record {idx} is missing an 'organ' or 'rib' image path")

        # organ: its shape defines the output resolution for every channel
        organ_map, organ_max = _remap_to_sequential(_load_gray_int(organ_path))
        H, W = organ_map.shape

        # rib: align to the organ grid (like the disease mask below) so the
        # channels can be stacked. NEAREST keeps label ids intact; values come
        # from an "L" image, so they fit in uint8.
        rib_map_raw = _load_gray_int(rib_path)
        if rib_map_raw.shape != (H, W):
            rib_img = Image.fromarray(rib_map_raw.astype(np.uint8))
            rib_img = rib_img.resize((W, H), resample=Image.NEAREST)
            rib_map_raw = np.asarray(rib_img).astype(np.int32)
        rib_map, rib_max = _remap_to_sequential(rib_map_raw)

        # disease channel: binary 0/1; stays all-zero when the record has no
        # mask or the mask file does not exist (deliberate best-effort)
        disease_ch = np.zeros((H, W), dtype=np.float32)
        if mask_path is not None and os.path.exists(mask_path):
            disease_ch = _load_bin_mask_resize(mask_path, (H, W))

        organ_ch = _norm(organ_map, organ_max)
        rib_ch = _norm(rib_map, rib_max)
        empty_ch = np.zeros((H, W), dtype=np.float32)

        # copy=False: the channels are already float32, so no extra copy is made.
        original = np.stack([empty_ch, organ_ch, rib_ch], axis=0).astype(np.float32, copy=False)
        edited = np.stack([disease_ch, organ_ch, rib_ch], axis=0).astype(np.float32, copy=False)

        enc = self.tokenizer(
            item["prompt"],
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return {
            "original_pixel_values": torch.from_numpy(original),
            "edited_pixel_values": torch.from_numpy(edited),
            "input_ids": enc.input_ids.squeeze(0),
        }