Update dataset.py
Browse files- dataset.py +6 -3
dataset.py
CHANGED
|
@@ -12,8 +12,11 @@ class CXRMate2Dataset:
|
|
| 12 |
self.history = history
|
| 13 |
self.study_id_to_index = dict(zip(self.dataset['study_id'], range(len(self.dataset)), strict=True))
|
| 14 |
|
| 15 |
-
def __getitem__(self,
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
if 'views' not in batch:
|
| 19 |
batch['views'] = [None] * len(batch['images'])
|
|
@@ -26,7 +29,7 @@ class CXRMate2Dataset:
|
|
| 26 |
|
| 27 |
if self.history:
|
| 28 |
|
| 29 |
-
if
|
| 30 |
|
| 31 |
# Sort by datetime to ensure correct order:
|
| 32 |
assert all(i is not None and not (isinstance(i, float) and np.isnan(i)) for i in batch['prior_study_datetimes'])
|
|
|
|
| 12 |
self.history = history
|
| 13 |
self.study_id_to_index = dict(zip(self.dataset['study_id'], range(len(self.dataset)), strict=True))
|
| 14 |
|
| 15 |
+
def __getitem__(self, key):
|
| 16 |
+
if not isinstance(key, int):
|
| 17 |
+
return self.dataset[key]
|
| 18 |
+
|
| 19 |
+
batch = self.dataset[key]
|
| 20 |
|
| 21 |
if 'views' not in batch:
|
| 22 |
batch['views'] = [None] * len(batch['images'])
|
|
|
|
| 29 |
|
| 30 |
if self.history:
|
| 31 |
|
| 32 |
+
if batch['prior_study_ids']:
|
| 33 |
|
| 34 |
# Sort by datetime to ensure correct order:
|
| 35 |
assert all(i is not None and not (isinstance(i, float) and np.isnan(i)) for i in batch['prior_study_datetimes'])
|