# MyDataset_rib.py
"""Dataset yielding organ/rib conditioning maps with an optional disease mask.

Each non-blank line of the prompt file is parsed as one JSON record. A record
is kept only when the first entry of its ``attn_list`` names one of the
diseases in ``DISEASES`` (case-insensitive). The fields read from a kept
record are ``organ``, ``rib``, ``mask`` (image paths) and ``prompt``.
"""
import json
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

DISEASES = [
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "Diffuse Nodule",
    "Edema",
    "Effusion",
    "Emphysema",
    "Enlarged Cardiomediastinum",
    "Fibrosis",
    "Fracture",
    "Mass",
    "Nodule",
    "Opacity",
    "Pleural Thickening",
    "Pneumothorax",
    "Pacemaker",
]

# Lower-cased lookup set so record filtering is case-insensitive and O(1).
DISEASES_LOWER = {d.lower() for d in DISEASES}


def _resolve_path(rel_or_abs, base_dir):
    """Return an absolute-or-joined path, or ``None`` for an empty value.

    Absolute paths are returned unchanged; anything else is joined onto
    ``base_dir``. Falsy input (``""`` or ``None``) yields ``None``.
    """
    if not rel_or_abs:
        return None
    if os.path.isabs(rel_or_abs):
        return rel_or_abs
    return os.path.join(base_dir, rel_or_abs)


def _load_gray_int(path):
    """Load an image as 8-bit grayscale and return it as an int32 array."""
    # The context manager closes the underlying file deterministically
    # instead of leaving it to the garbage collector.
    with Image.open(path) as im:
        arr = np.array(im.convert("L"))
    return arr.astype(np.int32)


def _load_bin_mask_resize(path, hw):
    """Load a mask, resize it to ``hw = (H, W)`` and binarise it.

    Resizing uses nearest-neighbour sampling so no intermediate gray levels
    are introduced. Returns a float32 array with 1.0 wherever the grayscale
    value is greater than zero and 0.0 elsewhere.
    """
    H, W = hw
    with Image.open(path) as im:
        gray = im.convert("L")
        # PIL sizes are (width, height), hence the swapped order.
        if gray.size != (W, H):
            gray = gray.resize((W, H), resample=Image.NEAREST)
        arr = np.array(gray)
    return (arr > 0).astype(np.float32)


def _remap_to_sequential(label_map_int):
    """Renumber the non-zero labels of a label map to ``1..K``.

    Labels are renumbered in ascending order of their original value; zero
    stays zero. Returns ``(remapped_int32_map, K)`` where ``K`` is the number
    of distinct non-zero labels (0 when the map has none).
    """
    vals = np.unique(label_map_int)
    vals = vals[vals != 0]
    if vals.size == 0:
        return label_map_int.astype(np.int32), 0
    # ``vals`` is sorted, so searchsorted gives every pixel the rank of its
    # label in a single pass instead of one masked assignment per label.
    ranks = np.searchsorted(vals, label_map_int) + 1
    out = np.where(label_map_int != 0, ranks, 0).astype(np.int32)
    return out, int(vals.size)


def _norm(ch, maxv):
    """Return ``ch`` as float32, divided by ``maxv`` when ``maxv > 0``."""
    as_float = ch.astype(np.float32)
    if maxv > 0:
        return as_float / float(maxv)
    return as_float


def _get_first_disease_name(item):
    """Return the stripped first element of ``item["attn_list"][0]``.

    Returns ``None`` when ``attn_list`` is missing, ``null``, empty, not a
    list/tuple, or when its first entry is not a non-empty list/tuple.
    """
    attn_list = item.get("attn_list")
    # An explicit JSON ``null`` or a non-sequence value must not crash the
    # filtering pass; such records are simply treated as having no disease.
    if not isinstance(attn_list, (list, tuple)) or not attn_list:
        return None
    first = attn_list[0]
    if isinstance(first, (list, tuple)) and len(first) >= 1:
        return str(first[0]).strip()
    return None


def _is_target_record(item):
    """Return True when the record's first disease name is in ``DISEASES``."""
    disease_name = _get_first_disease_name(item)
    if disease_name is None:
        return False
    return disease_name.lower() in DISEASES_LOWER


class MyDataset_rib(Dataset):
    """Dataset of (original, edited) 3-channel maps plus tokenised prompts.

    ``args`` must provide ``train_data_dir`` (base directory for relative
    image paths) and ``train_data_prompt`` (path of the JSON-lines prompt
    file). ``tokenizer`` must be callable and expose ``model_max_length``.
    """

    def __init__(self, args, tokenizer):
        self.data_dir = args.train_data_dir
        self.prompt_dir = args.train_data_prompt
        self.tokenizer = tokenizer

        # Explicit UTF-8 so prompt decoding does not depend on the locale.
        with open(self.prompt_dir, "rt", encoding="utf-8") as f:
            raw_data = [json.loads(line) for line in f if line.strip()]

        self.data = [item for item in raw_data if _is_target_record(item)]
        print(f"Loaded {len(self.data)} valid records from {len(raw_data)} total records.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build one sample.

        Returns a dict with:
          * ``original_pixel_values``: float32 tensor (3, H, W) with channels
            ``[zeros, organ, rib]``.
          * ``edited_pixel_values``: float32 tensor (3, H, W) with channels
            ``[disease_mask, organ, rib]``.
          * ``input_ids``: the tokenizer's ``input_ids`` with dim 0 squeezed.

        Raises:
            ValueError: if the record has no organ or rib path, or the organ
                and rib maps differ in size.
        """
        item = self.data[idx]

        organ_path = _resolve_path(item.get("organ", ""), self.data_dir)
        rib_path = _resolve_path(item.get("rib", ""), self.data_dir)
        mask_path = _resolve_path(item.get("mask", ""), self.data_dir)

        # Fail with a message naming the record rather than an opaque error
        # from inside the image loader.
        if organ_path is None or rib_path is None:
            raise ValueError(
                f"Record {idx} is missing a required 'organ' or 'rib' path."
            )

        # organ
        organ_map, organ_max = _remap_to_sequential(_load_gray_int(organ_path))

        # rib
        rib_map, rib_max = _remap_to_sequential(_load_gray_int(rib_path))

        # The channels are stacked below, so both maps must share one size.
        if rib_map.shape != organ_map.shape:
            raise ValueError(
                f"Record {idx}: rib map shape {rib_map.shape} does not match "
                f"organ map shape {organ_map.shape} "
                f"({rib_path!r} vs {organ_path!r})."
            )

        H, W = organ_map.shape

        # disease channel: binary 0/1, all zeros when no mask file is present
        if mask_path is not None and os.path.exists(mask_path):
            disease_ch = _load_bin_mask_resize(mask_path, (H, W))
        else:
            disease_ch = np.zeros((H, W), dtype=np.float32)

        # Label indices are divided by the label count (when non-zero).
        organ_ch = _norm(organ_map, organ_max)
        rib_ch = _norm(rib_map, rib_max)
        empty_ch = np.zeros((H, W), dtype=np.float32)

        original = np.stack([empty_ch, organ_ch, rib_ch], axis=0).astype(np.float32)
        edited = np.stack([disease_ch, organ_ch, rib_ch], axis=0).astype(np.float32)

        enc = self.tokenizer(
            item["prompt"],
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return {
            "original_pixel_values": torch.from_numpy(original),
            "edited_pixel_values": torch.from_numpy(edited),
            "input_ids": enc.input_ids.squeeze(0),
        }