anicolson commited on
Commit
e1aa279
·
verified ·
1 Parent(s): 6648ce8

Upload processor

Browse files
Files changed (2) hide show
  1. dataset.py +80 -0
  2. processing_cxrmate2.py +8 -0
dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from math import e
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ class CXRMate2Dataset:
10
+ def __init__(self, dataset, history=1):
11
+ self.dataset = dataset
12
+ self.history = history
13
+ self.study_id_to_index = dict(zip(self.dataset['study_id'], range(len(self.dataset)), strict=True))
14
+
15
+ def __getitem__(self, idx):
16
+ batch = self.dataset[idx]
17
+
18
+ if 'views' not in batch:
19
+ batch['views'] = [None] * len(batch['images'])
20
+
21
+ # Set None study_datetimes to a default value:
22
+ batch['study_datetime'] = datetime(1, 1, 1, 0, 0) if batch['study_datetime'] is None else batch['study_datetime']
23
+
24
+ # Datetime for current study:
25
+ batch['image_datetime'] = [batch['study_datetime'] for _ in batch['images']]
26
+
27
+ if self.history:
28
+
29
+ if batch['prior_study_ids'] is not None:
30
+
31
+ # Sort by datetime to ensure correct order:
32
+ assert all(i is not None and not (isinstance(i, float) and np.isnan(i)) for i in batch['prior_study_datetimes'])
33
+ prior_study_ids = [i for _, i in sorted(zip(batch['prior_study_datetimes'], batch['prior_study_ids'], strict=True))]
34
+ prior_study_ids = prior_study_ids[-self.history:]
35
+
36
+ # prior_study_datetimes = sorted(batch['prior_study_datetimes'])[-self.history:]
37
+ prior_study_indices = [self.study_id_to_index[i] for i in prior_study_ids]
38
+ prior_studies = [self.dataset[i] for i in prior_study_indices]
39
+
40
+ # Datetime of prior studies:
41
+ batch['prior_study_datetime'] = [i['study_datetime'] for i in prior_studies]
42
+
43
+ # Add prior images and their datetime:
44
+ for study in prior_studies:
45
+
46
+ if 'views' not in study:
47
+ study['views'] = [None] * len(study['images'])
48
+
49
+ for image, view in zip(study['images'], study['views'], strict=True):
50
+ batch['images'].insert(0, image)
51
+ batch['views'].insert(0, view)
52
+ batch['image_datetime'].insert(0, study['study_datetime'])
53
+
54
+ # Prior findings and impressions:
55
+ batch['prior_findings'] = [None if i is None else i['findings'] for i in prior_studies]
56
+ batch['prior_impression'] = [
57
+ None if i is None else i['impression'] for i in prior_studies
58
+ ]
59
+ else:
60
+ batch['prior_study_datetime'] = [None]
61
+ batch['prior_findings'] = [None]
62
+ batch['prior_impression'] = [None]
63
+
64
+ return batch
65
+
66
+ def __len__(self):
67
+ return len(self.dataset)
68
+
69
+ def __getattr__(self, name):
70
+ return getattr(self.dataset, name)
71
+
72
+ def __getitems__(self, keys: List):
73
+ batch = [self.__getitem__(key) for key in keys]
74
+
75
+ keys = set().union(*(d.keys() for d in batch))
76
+ batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
77
+ batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
78
+
79
+ return batch
80
+
processing_cxrmate2.py CHANGED
@@ -15,6 +15,11 @@ from transformers.image_utils import ImageInput
15
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
16
  from utils import compute_time_delta
17
 
 
 
 
 
 
18
  # Ordered by oblique, lateral, AP, and then PA views so that PA views are closest in position to the generated tokens (and oblique is furtherest).
19
  VIEW_ORDER = [
20
  None,
@@ -564,3 +569,6 @@ class CXRMate2Processor(transformers.ProcessorMixin):
564
 
565
 
566
  return batch
 
 
 
 
15
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
16
  from utils import compute_time_delta
17
 
18
+ try:
19
+ from .dataset import CXRMate2Dataset
20
+ except ImportError:
21
+ from dataset import CXRMate2Dataset
22
+
23
  # Ordered by oblique, lateral, AP, and then PA views so that PA views are closest in position to the generated tokens (and oblique is furtherest).
24
  VIEW_ORDER = [
25
  None,
 
569
 
570
 
571
  return batch
572
+
573
+ def wrap_dataset(self, dataset):
574
+ return CXRMate2Dataset(dataset)