# cxrmate-2 / dataset.py
# Last commit (anicolson): "Update dataset.py" -- ce545f0, verified
from datetime import datetime
from math import e
from typing import List
import numpy as np
import torch
class CXRMate2Dataset:
def __init__(self, dataset, history=1):
self.dataset = dataset
self.history = history
self.study_id_to_index = dict(zip(self.dataset['study_id'], range(len(self.dataset)), strict=True))
def __getitem__(self, key):
if not isinstance(key, int):
return self.dataset[key]
batch = self.dataset[key]
if 'views' not in batch:
batch['views'] = [None] * len(batch['images'])
# Set None study_datetimes to a default value:
batch['study_datetime'] = datetime(1, 1, 1, 0, 0) if batch['study_datetime'] is None else batch['study_datetime']
# Datetime for current study:
batch['image_datetime'] = [batch['study_datetime'] for _ in batch['images']]
if self.history:
if batch['prior_study_ids']:
# Sort by datetime to ensure correct order:
assert all(i is not None and not (isinstance(i, float) and np.isnan(i)) for i in batch['prior_study_datetimes'])
prior_study_ids = [i for _, i in sorted(zip(batch['prior_study_datetimes'], batch['prior_study_ids'], strict=True))]
prior_study_ids = prior_study_ids[-self.history:]
# prior_study_datetimes = sorted(batch['prior_study_datetimes'])[-self.history:]
prior_study_indices = [self.study_id_to_index[i] for i in prior_study_ids]
prior_studies = [self.dataset[i] for i in prior_study_indices]
# Datetime of prior studies:
batch['prior_study_datetime'] = [i['study_datetime'] for i in prior_studies]
# Add prior images and their datetime:
for study in prior_studies:
if 'views' not in study:
study['views'] = [None] * len(study['images'])
for image, view in zip(study['images'], study['views'], strict=True):
batch['images'].insert(0, image)
batch['views'].insert(0, view)
batch['image_datetime'].insert(0, study['study_datetime'])
# Prior findings and impressions:
batch['prior_findings'] = [None if i is None else i['findings'] for i in prior_studies]
batch['prior_impression'] = [
None if i is None else i['impression'] for i in prior_studies
]
else:
batch['prior_study_datetime'] = [None]
batch['prior_findings'] = [None]
batch['prior_impression'] = [None]
return batch
def __len__(self):
return len(self.dataset)
def __getattr__(self, name):
return getattr(self.dataset, name)
def __getitems__(self, keys: List):
batch = [self.__getitem__(key) for key in keys]
keys = set().union(*(d.keys() for d in batch))
batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
return batch