File size: 3,347 Bytes
e1aa279
 
 
 
 
 
 
 
 
 
 
 
 
 
ce545f0
 
 
 
 
e1aa279
 
 
 
 
 
 
 
 
 
 
 
ce545f0
e1aa279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from datetime import datetime
from math import e
from typing import List

import numpy as np
import torch


class CXRMate2Dataset:
    def __init__(self, dataset, history=1):
        self.dataset = dataset
        self.history = history
        self.study_id_to_index = dict(zip(self.dataset['study_id'], range(len(self.dataset)), strict=True))

    def __getitem__(self, key):
        if not isinstance(key, int):
            return self.dataset[key]

        batch = self.dataset[key]

        if 'views' not in batch:
            batch['views'] = [None] * len(batch['images'])
        
        # Set None study_datetimes to a default value:
        batch['study_datetime'] = datetime(1, 1, 1, 0, 0) if batch['study_datetime'] is None else batch['study_datetime']

        # Datetime for current study:
        batch['image_datetime'] = [batch['study_datetime'] for _ in batch['images']] 

        if self.history:
            
            if batch['prior_study_ids']:
            
                # Sort by datetime to ensure correct order:
                assert all(i is not None and not (isinstance(i, float) and np.isnan(i)) for i in batch['prior_study_datetimes'])
                prior_study_ids = [i for _, i in sorted(zip(batch['prior_study_datetimes'], batch['prior_study_ids'], strict=True))]
                prior_study_ids = prior_study_ids[-self.history:]
                
                # prior_study_datetimes = sorted(batch['prior_study_datetimes'])[-self.history:]
                prior_study_indices = [self.study_id_to_index[i] for i in prior_study_ids]
                prior_studies = [self.dataset[i] for i in prior_study_indices]

                # Datetime of prior studies:
                batch['prior_study_datetime'] = [i['study_datetime'] for i in prior_studies]
                            
                # Add prior images and their datetime:
                for study in prior_studies:

                    if 'views' not in study:
                        study['views'] = [None] * len(study['images'])

                    for image, view in zip(study['images'], study['views'], strict=True):
                        batch['images'].insert(0, image)
                        batch['views'].insert(0, view)
                        batch['image_datetime'].insert(0, study['study_datetime'])     
                
                # Prior findings and impressions:
                batch['prior_findings'] = [None if i is None else i['findings'] for i in prior_studies]     
                batch['prior_impression'] = [
                    None if i is None else i['impression'] for i in prior_studies
                ]    
            else:
                batch['prior_study_datetime'] = [None]
                batch['prior_findings'] = [None]
                batch['prior_impression'] = [None] 

        return batch

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        return getattr(self.dataset, name)
    
    def __getitems__(self, keys: List):
        batch = [self.__getitem__(key) for key in keys]

        keys = set().union(*(d.keys() for d in batch))
        batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
        batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}

        return batch