|
|
import os
|
|
|
import json
|
|
|
import torch
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
from pathlib import Path
|
|
|
from typing import List, Dict, Optional
|
|
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
|
# Label names, in the order used for every label vector in this module.
# They are also used as column names when reading labels from the annotation
# CSV, so the spelling must match the CSV header exactly.
#
# An earlier 18-class variant of this list (kept disabled in the original
# source) additionally contained "切除术后" between "胶质潴留" and "弥漫性病变".
TARGET_CLASSES = [
    # TI-RADS grades
    "TI-RADS 1级",
    "TI-RADS 2级",
    "TI-RADS 3级",
    "TI-RADS 4a级",
    "TI-RADS 4b级",
    "TI-RADS 4c级",
    "TI-RADS 5级",
    # Other findings
    "钙化",
    "甲亢",
    "囊肿",
    "淋巴结",
    "胶质潴留",
    "弥漫性病变",
    "结节性甲状腺肿",
    "桥本氏甲状腺炎",
    "反应性",
    "转移性",
]

# Positions in TARGET_CLASSES of the classes that the sampler treats as rare.
# With the list above these are "TI-RADS 4b级", "TI-RADS 4c级", "TI-RADS 5级"
# and "转移性".
# NOTE(review): the indices are positional. In the disabled 18-class list,
# index 16 was "反应性"; after "切除术后" was removed it became "转移性".
# Confirm that index 16 still points at the intended class.
RARE_CLASS_INDICES = [4, 5, 6, 16]
|
|
|
|
|
|
class ThyroidMultiLabelDataset(Dataset):
    """Case-level multi-label image dataset.

    One item corresponds to one case. All ``.jpg``/``.png``/``.bmp`` files
    found in ``<data_root>/<case_path>/Images`` are loaded, transformed and
    stacked into a single tensor, and paired with the case's label vector
    taken from the ``TARGET_CLASSES`` columns of the annotation CSV (which is
    indexed by its ``case_path`` column).
    """

    # Glob patterns searched inside each case's "Images" directory.
    _IMAGE_PATTERNS = ("*.jpg", "*.png", "*.bmp")
    # Channel statistics used by the default Normalize step.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self,
                 data_root: str,
                 annotation_csv: str,
                 split_json: Optional[str] = None,
                 split_type: str = 'train',
                 val_json_path: Optional[str] = None,
                 test_json_path: Optional[str] = None,
                 img_size: int = 224,
                 max_images_per_case: Optional[int] = 20,
                 transform=None):
        """Read the annotation CSV, resolve the case list and set up transforms.

        Args:
            data_root: Directory that the case paths are relative to.
            annotation_csv: CSV file with a ``case_path`` column plus one
                column per entry of ``TARGET_CLASSES``.
            split_json: JSON file listing the cases of this split (a list of
                objects with a ``rel_path`` key). When given it takes
                precedence over ``val_json_path``/``test_json_path``.
            split_type: ``'train'`` enables augmentation and random image
                subsampling; any other value uses the deterministic pipeline.
            val_json_path: Split file whose cases are excluded from the
                training set (only used when ``split_json`` is not given and
                ``split_type == 'train'``).
            test_json_path: Same as ``val_json_path`` for the test split.
            img_size: Side length images are resized to by the default
                transforms; also the size of the all-zero placeholder image.
            max_images_per_case: Upper bound on images loaded per case;
                ``None`` (or 0) disables the limit.
            transform: Optional callable applied to each PIL image. When
                ``None`` a default pipeline is built from ``split_type``.
        """
        super().__init__()
        self.data_root = Path(data_root)
        self.img_size = img_size
        self.max_images_per_case = max_images_per_case
        self.split_type = split_type

        self.df_labels = pd.read_csv(annotation_csv)
        self.df_labels.set_index('case_path', inplace=True)

        self.case_list = self._get_split_cases(split_json, val_json_path, test_json_path)

        # Compare against None explicitly so that a caller-supplied transform
        # is never discarded just because it happens to evaluate as falsy.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = self._build_default_transform(split_type, img_size)

        print(f"[{split_type.upper()}] Loaded {len(self.case_list)} cases.")

    @classmethod
    def _build_default_transform(cls, split_type, img_size):
        """Return the default pipeline: augmented for 'train', plain otherwise."""
        steps = [transforms.Resize((img_size, img_size))]
        if split_type == 'train':
            steps.extend([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
            ])
        steps.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=cls._NORM_MEAN, std=cls._NORM_STD),
        ])
        return transforms.Compose(steps)

    @staticmethod
    def _read_rel_paths(json_path):
        """Return the ``rel_path`` value of every entry in a split JSON file."""
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default encoding (case paths may contain non-ASCII characters).
        with open(json_path, 'r', encoding='utf-8') as f:
            return [item['rel_path'] for item in json.load(f)]

    def _get_split_cases(self, split_json, val_json_path, test_json_path):
        """Resolve the list of case paths belonging to this split.

        * ``split_json`` given: the cases listed there, in file order,
          restricted to those that have a row in the annotation CSV.
        * otherwise, for ``split_type == 'train'``: every case in the CSV that
          is not listed in the validation/test split files, sorted.
        * otherwise: an empty list.
        """
        all_cases_in_csv = set(self.df_labels.index.tolist())

        if split_json:
            listed = self._read_rel_paths(split_json)
            valid_cases = [c for c in listed if c in all_cases_in_csv]
            skipped = len(listed) - len(valid_cases)
            if skipped:
                # Surface the mismatch instead of silently shrinking the split.
                print(f"Warning: {skipped} case(s) listed in {split_json} "
                      f"are missing from the annotation CSV and were skipped.")
            return valid_cases

        if self.split_type == 'train':
            exclude_cases = set()
            for held_out_json in (val_json_path, test_json_path):
                if held_out_json:
                    exclude_cases.update(self._read_rel_paths(held_out_json))
            return sorted(all_cases_in_csv - exclude_cases)

        return []

    def __len__(self):
        """Return the number of cases in this split."""
        return len(self.case_list)

    def __getitem__(self, idx):
        """Load one case.

        Returns:
            dict with
              ``images``: stacked image tensor (one entry per loaded image;
                  a single all-zero ``3 x img_size x img_size`` image when
                  nothing could be loaded),
              ``labels``: float tensor with one value per ``TARGET_CLASSES``
                  entry (all zeros if the label lookup raises ``KeyError``),
              ``num_images``: number of entries in ``images``,
              ``case_id``: the case's relative path.
        """
        case_rel_path = self.case_list[idx]
        img_dir = self.data_root / case_rel_path / "Images"

        try:
            label_vec = self.df_labels.loc[case_rel_path, TARGET_CLASSES].values.astype(np.float32)
            label_tensor = torch.tensor(label_vec)
        except KeyError:
            print(f"Warning: Label for {case_rel_path} not found in CSV. Using zeros.")
            label_tensor = torch.zeros(len(TARGET_CLASSES))

        image_files = sorted(
            path
            for pattern in self._IMAGE_PATTERNS
            for path in img_dir.glob(pattern)
        )

        limit = self.max_images_per_case
        if limit and len(image_files) > limit:
            if self.split_type == 'train':
                # Random subset for training. Sampling indices (rather than
                # the Path objects themselves) keeps image_files a plain list
                # of paths and avoids building a NumPy object array.
                chosen = np.random.choice(len(image_files), limit, replace=False)
                image_files = [image_files[i] for i in chosen]
            else:
                # Deterministic prefix for evaluation.
                image_files = image_files[:limit]

        images = []
        for img_path in image_files:
            try:
                # The context manager closes the file handle even when
                # decoding fails part-way.
                with Image.open(img_path) as raw:
                    img = raw.convert('RGB')
                if self.transform is not None:
                    img = self.transform(img)
                images.append(img)
            except Exception as e:
                # Best effort: an unreadable image must not abort the batch,
                # but it should not disappear without a trace either.
                print(f"Warning: failed to load image {img_path}: {e}")

        if not images:
            # Placeholder so that downstream stacking always has an input.
            images = [torch.zeros(3, self.img_size, self.img_size)]

        images_stack = torch.stack(images)

        return {
            'images': images_stack,
            'labels': label_tensor,
            'num_images': len(images),
            'case_id': case_rel_path
        }

    def get_sampler_weights(self, rare_weight: float = 10.0, base_weight: float = 1.0):
        """Compute one sampling weight per case.

        A case gets ``rare_weight`` when any of the classes at
        ``RARE_CLASS_INDICES`` is labelled 1 for it, and ``base_weight``
        otherwise.

        Args:
            rare_weight: Weight for cases containing a rare class (default 10).
            base_weight: Weight for all other cases (default 1).

        Returns:
            1-D ``torch.double`` tensor aligned with ``self.case_list``.
        """
        weights = []
        for case_rel_path in self.case_list:
            label_vec = self.df_labels.loc[case_rel_path, TARGET_CLASSES].values
            is_rare = any(label_vec[i] == 1 for i in RARE_CLASS_INDICES)
            weights.append(rare_weight if is_rare else base_weight)

        return torch.tensor(weights, dtype=torch.double)
|
|
|
|
|
|
def collate_fn(batch):
    """Collate per-case samples into one batch.

    Cases may contribute different numbers of images, so the image tensors
    are concatenated along dim 0 rather than stacked; the per-case counts are
    returned alongside so callers can tell which rows belong to which case.

    Args:
        batch: Iterable of dicts with the keys ``'images'``, ``'labels'``,
            ``'num_images'`` and ``'case_id'``.

    Returns:
        dict with
          ``images``: all image tensors concatenated along dim 0,
          ``labels``: the label tensors stacked along a new dim 0,
          ``num_instances_per_case``: ``torch.long`` tensor of image counts,
          ``case_ids``: list of case identifiers in batch order.
    """
    samples = list(batch)

    return {
        'images': torch.cat([sample['images'] for sample in samples], dim=0),
        'labels': torch.stack([sample['labels'] for sample in samples]),
        'num_instances_per_case': torch.tensor(
            [sample['num_images'] for sample in samples], dtype=torch.long
        ),
        'case_ids': [sample['case_id'] for sample in samples],
    }
|
|
|
|
|
|
def create_dataloaders(config):
    """Build the train, validation and test data loaders.

    Reads ``data_root``, ``annotation_csv``, ``val_json``, ``test_json``,
    ``img_size``, ``max_images_per_case`` and ``num_workers`` from
    ``config['data']`` and ``batch_size`` from ``config['training']``.

    * Train: every annotated case not listed in the validation/test split
      files, drawn through a ``WeightedRandomSampler`` built from the
      dataset's sampler weights; incomplete final batches are dropped.
    * Validation: cases from ``val_json``, unshuffled.
    * Test: cases from ``test_json``, unshuffled, with no per-case image cap.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    data_cfg = config['data']
    root_dir = data_cfg['data_root']
    labels_csv = data_cfg['annotation_csv']
    val_split = data_cfg['val_json']
    test_split = data_cfg['test_json']

    def _eval_loader(dataset):
        # Validation and test loaders share the same fixed-order settings.
        return DataLoader(
            dataset,
            batch_size=config['training']['batch_size'],
            shuffle=False,
            num_workers=data_cfg['num_workers'],
            collate_fn=collate_fn,
            pin_memory=True,
        )

    # --- training split ---
    train_set = ThyroidMultiLabelDataset(
        data_root=root_dir,
        annotation_csv=labels_csv,
        split_type='train',
        val_json_path=val_split,
        test_json_path=test_split,
        img_size=data_cfg['img_size'],
        max_images_per_case=data_cfg['max_images_per_case'],
    )

    print("Calculating sampler weights for class balance...")
    case_weights = train_set.get_sampler_weights()
    balanced_sampler = WeightedRandomSampler(case_weights, len(case_weights))

    train_loader = DataLoader(
        train_set,
        batch_size=config['training']['batch_size'],
        sampler=balanced_sampler,
        num_workers=data_cfg['num_workers'],
        collate_fn=collate_fn,
        pin_memory=True,
        drop_last=True,
    )

    # --- validation split ---
    val_set = ThyroidMultiLabelDataset(
        data_root=root_dir,
        annotation_csv=labels_csv,
        split_type='val',
        split_json=val_split,
        img_size=data_cfg['img_size'],
        max_images_per_case=data_cfg['max_images_per_case'],
    )
    val_loader = _eval_loader(val_set)

    # --- test split: every image of each case is used ---
    test_set = ThyroidMultiLabelDataset(
        data_root=root_dir,
        annotation_csv=labels_csv,
        split_type='test',
        split_json=test_split,
        img_size=data_cfg['img_size'],
        max_images_per_case=None,
    )
    test_loader = _eval_loader(test_set)

    return train_loader, val_loader, test_loader