File size: 4,807 Bytes
75b1a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# MyDataset_rib.py
import os, json
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

# Fixed disease label indices: background = 0, diseases numbered from 1.
DISEASES = ["Atelectasis", "Calcification", "Cardiomegaly", "Consolidation",
            "Diffuse Nodule", "Effusion", "Emphysema", "Fibrosis", "Fracture",
            "Mass", "Nodule", "Pleural Thickening", "Pneumothorax"]
# Map disease name -> 1-based label index (1..len(DISEASES)).
DISEASE_TO_IDX = {name: i+1 for i, name in enumerate(DISEASES)}  # 1..len(DISEASES)

def _load_gray_int(path):
    """Load the image at *path* as single-channel grayscale, as an int32 array."""
    gray = Image.open(path).convert("L")
    return np.asarray(gray, dtype=np.int32)

def _remap_to_sequential(label_map_int):
    # 把任意像素值集合映射到 0..K,保持0为背景
    vals = np.unique(label_map_int)
    vals = vals[vals != 0]
    if len(vals) == 0:
        return label_map_int.astype(np.int32), 0
    out = np.zeros_like(label_map_int, dtype=np.int32)
    for i, v in enumerate(vals, start=1):
        out[label_map_int == v] = i
    return out, int(out.max())

def _merge_diseases_from_attn_list(attn_list, base_dir, hw):
    """Merge per-disease binary mask images into one (H, W) int32 label map.

    *attn_list* holds ``(disease_name, path)`` pairs; names missing from
    ``DISEASE_TO_IDX`` are skipped, and relative paths are resolved against
    *base_dir*. Masks are painted in descending area order so that smaller
    lesions, written last, cannot be hidden underneath larger ones.

    Returns ``(label_map, max_label_index)``.
    """
    height, width = hw
    entries = []
    for disease_name, mask_path in attn_list:
        label = DISEASE_TO_IDX.get(disease_name, None)
        if label is None:
            continue
        if not os.path.isabs(mask_path):
            mask_path = os.path.join(base_dir, mask_path)
        mask = _load_gray_int(mask_path) > 0
        entries.append((int(mask.sum()), label, mask))
    # Sort big-to-small (ties broken by label index): large regions first,
    # small ones overwrite them.
    entries.sort(key=lambda entry: (-entry[0], entry[1]))
    merged = np.zeros((height, width), dtype=np.int32)
    for area, label, mask in entries:
        if area == 0:
            continue
        merged[mask] = label
    return merged, int(merged.max())

class MyDataset_rib(Dataset):
    """Rib/organ/disease-conditioned training dataset.

    Each line of the JSONL file at ``args.train_data_prompt`` describes one
    sample with keys ``file_name`` (target image), ``organ`` (multi-organ
    label image, pixel value = organ id), ``prompt`` (text), and optionally
    ``attn_list`` (list of ``(disease_name, mask_path)`` pairs).

    ``__getitem__`` returns:
        conditioning_pixel_values: (3, H, W) float32 in [0, 1]
            — channels are disease / organ / rib label maps.
        pixel_values: (3, H, W) float32 in [-1, 1] target image.
        input_ids: tokenized prompt ids.
    """

    def __init__(self, args, tokenizer):
        self.data_dir   = args.train_data_dir
        self.prompt_dir = args.train_data_prompt  # path to the JSONL prompt file
        self.tokenizer  = tokenizer
        # One JSON object per non-blank line; explicit encoding so decoding
        # does not depend on the platform locale.
        with open(self.prompt_dir, 'rt', encoding='utf-8') as f:
            self.data = [json.loads(l) for l in f if l.strip()]

    def __len__(self):
        return len(self.data)

    def _abs(self, p):
        # Resolve a possibly-relative path against the dataset root.
        return p if os.path.isabs(p) else os.path.join(self.data_dir, p)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Target image, scaled from [0, 255] to [-1, 1].
        img = np.array(Image.open(self._abs(item['file_name'])).convert('RGB'),
                       dtype=np.float32)
        img = img / 127.5 - 1.0
        pixel_values = torch.from_numpy(img).permute(2, 0, 1)  # (3,H,W)

        # Organ labels (one image, pixel value = organ id), remapped to 0..K.
        organ_map_raw = _load_gray_int(self._abs(item['organ']))
        organ_map, organ_max = _remap_to_sequential(organ_map_raw)

        # Rib labels come from a parallel tree:
        # <data_dir>/rib/<first component of file_name>/<basename>.
        # NOTE(review): file_name is assumed to use '/' separators — confirm
        # on Windows-generated manifests. Used as a single binary channel.
        rib_path = os.path.join(self.data_dir, "rib",
                                item['file_name'].split('/')[0],
                                os.path.basename(item['file_name']))
        rib_bin = (_load_gray_int(rib_path) > 0).astype(np.float32)

        H, W = organ_map.shape

        # Disease channel: merge the per-disease binary masks, if any.
        attn_list = item.get('attn_list', [])
        if attn_list:
            disease_map, disease_max = _merge_diseases_from_attn_list(
                attn_list, self.data_dir, (H, W))
        else:
            disease_map, disease_max = np.zeros((H, W), dtype=np.int32), 0

        # Scale each label channel into [0, 1]; pure scaling, so the relative
        # ordering of label values is preserved.
        def _norm(ch, maxv):
            return ch.astype(np.float32) / float(maxv) if maxv > 0 else ch.astype(np.float32)

        cond = np.stack([_norm(disease_map, disease_max),
                         _norm(organ_map, organ_max),
                         rib_bin], axis=0).astype(np.float32)  # (3,H,W)
        conditioning_pixel_values = torch.from_numpy(cond)

        enc = self.tokenizer(
            item['prompt'],
            max_length=self.tokenizer.model_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            "conditioning_pixel_values": conditioning_pixel_values,  # (3,H,W) in [0,1]
            "pixel_values": pixel_values,                            # (3,H,W) in [-1,1]
            "input_ids": enc.input_ids.squeeze(0),
        }