anicolson committed on
Commit
6648ce8
·
verified ·
1 Parent(s): ce40000

Upload processor

Browse files
preprocessor_config.json CHANGED
@@ -1,4 +1,7 @@
1
  {
 
 
 
2
  "crop_size": {
3
  "height": 518,
4
  "width": 518
 
1
  {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_cxrmate2.CXRMate2Processor"
4
+ },
5
  "crop_size": {
6
  "height": 518,
7
  "width": 518
processing_cxrmate2.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import math
import random
from cProfile import label  # NOTE(review): 'label' is never used in this file — looks like an accidental IDE auto-import; confirm and remove.
from typing import Dict, List, Optional, Union

import pandas as pd  # NOTE(review): unused in this module as far as visible — confirm before removing.
import torch
import torch.nn.functional as F
import transformers
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from utils import compute_time_delta  # Local helper; presumably returns the difference between two datetimes in seconds (time_delta_map divides by 3600) — TODO confirm units.
+
18
# Radiographic view sort order, from least to most preferred position.
# Ordered by oblique, lateral, AP, and then PA views so that PA views are
# closest in position to the generated tokens (and oblique is furthest).
# `sort_images` ranks each image by its view label's index in this list.
VIEW_ORDER = [
    None,
    'nan',  # PadChest.
    'SWIMMERS',
    'LPO',
    'RAO',
    'LAO',
    'OBLICUA',  # PadChest.
    'AP LLD',
    'AP RLD',
    'PA LLD',
    'PA RLD',
    'LLD',  # PadChest.
    'XTABLE LATERAL',
    'RL',
    'LL',
    'Lateral',
    'LATERAL',
    'AP AXIAL',
    'ANTEROPOSTERIOR',  # PadChest.
    'AP',
    'GENERICA',  # PadChest (PA).
    'POSTEROANTERIOR',  # PadChest.
    'PA',
]
44
+
45
+
46
class CXRMate2Processor(transformers.ProcessorMixin):
    """Processor combining an image processor and a tokenizer to build model inputs
    (pixel values, token ids, token type ids, time deltas, attention masks) for
    report generation with prior-study context."""

    attributes = ['image_processor', 'tokenizer']
    image_processor_class = 'AutoImageProcessor'
    tokenizer_class = 'AutoTokenizer'
    # NOTE(review): 'token_type_to_token_type_id' does not match any __init__
    # parameter (the constructor takes 'token_type_to_token'), so this kwarg
    # name may never be recognised when the processor is saved/loaded — confirm
    # which name is intended.
    valid_kwargs = [
        'token_type_to_token_type_id',
        'max_generated_tokens',
    ]
55
+
56
    def __init__(
        self,
        image_processor,
        tokenizer,
        token_type_to_token: Dict[str, str],  # NOTE(review): annotation was Dict[str, int], but values are token *strings* fed to convert_tokens_to_ids below.
        max_generated_tokens: int,
        embeddings_per_image: int,
        image_token: str,
        max_train_images_per_study: int,  # This includes current and prior images.
        generate_findings_token: str,
        generate_impression_token: str,
        convert_to_rgb: bool = False,
        **kwargs,
    ):
        """
        Argument/s:
            image_processor - image processor used to turn raw images into pixel values.
            tokenizer - report tokenizer; must define bos, sep, eos, and pad tokens.
            token_type_to_token - maps a token type name (e.g. 'findings', 'image') to its special token string.
            max_generated_tokens - truncation length for the findings/impression report during training.
            embeddings_per_image - number of placeholder positions reserved per image in the sequence.
            image_token - placeholder token repeated embeddings_per_image times per image slot.
            max_train_images_per_study - cap on the number of images per study (current and prior combined).
            generate_findings_token - prompt token used when only the findings section is to be generated.
            generate_impression_token - prompt token used when only the impression section is to be generated.
            convert_to_rgb - convert decoded images to RGB before image processing.
        """
        super().__init__(image_processor, tokenizer)

        self.token_type_to_token = token_type_to_token
        self.max_generated_tokens = max_generated_tokens
        self.embeddings_per_image = embeddings_per_image
        self.image_token = image_token
        self.max_train_images_per_study = max_train_images_per_study

        self.generate_findings_token = generate_findings_token
        self.generate_impression_token = generate_impression_token

        self.convert_to_rgb = convert_to_rgb

        self.generate_findings_token_id = self.tokenizer.convert_tokens_to_ids(self.generate_findings_token)
        self.generate_impression_token_id = self.tokenizer.convert_tokens_to_ids(self.generate_impression_token)

        # Monotonically decreasing map from a raw time delta (seconds, presumably — confirm against compute_time_delta)
        # to (0, 1]: delta 0 -> 1.0, delta inf -> 0.0.
        self.time_delta_map = lambda x: 1 / math.sqrt((x / 3600) + 1)
        self.time_delta_monotonic_inversion = True
        self.zero_time_delta_value = self.time_delta_map(0.0)
        self.inf_time_delta_value = self.time_delta_map(float('inf'))

        # Token type ids for prior-report sections and for the prompt sections:
        self.prior_section_token_type_ids = [self.tokenizer.convert_tokens_to_ids(self.token_type_to_token[i]) for i in ['prior_findings', 'prior_impression']]
        self.section_token_type_ids = [self.tokenizer.convert_tokens_to_ids(self.token_type_to_token[i]) for i in ['indication', 'history', 'comparison', 'technique']]

        assert self.tokenizer.bos_token_id is not None, 'Tokenizer must have a bos_token_id.'
        assert self.tokenizer.sep_token_id is not None, 'Tokenizer must have a sep_token_id.'
        assert self.tokenizer.eos_token_id is not None, 'Tokenizer must have a eos_token_id.'
        assert self.tokenizer.pad_token_id is not None, 'Tokenizer must have a pad_token_id.'
98
+
99
+ def __call__(
100
+ self,
101
+ images: ImageInput,
102
+ image_datetime: Union[List[float], None] = None,
103
+ findings: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
104
+ impression: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
105
+ views: Union[List[str]] = None,
106
+ indication: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
107
+ history: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
108
+ comparison: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
109
+ technique: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
110
+
111
+ study_datetime: Union[float, None] = None,
112
+
113
+ # Priors:
114
+ prior_findings: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
115
+ prior_impression: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], None] = None,
116
+ prior_study_datetime: Union[List[float], None] = None,
117
+
118
+ train: bool = False,
119
+ **kwargs,
120
+ ) -> BatchFeature:
121
+
122
+ batch_size = len(images)
123
+
124
+ if views is None:
125
+ views = [[None for _, _ in enumerate(i)] for i in images]
126
+
127
+ batch = {
128
+ 'input_ids': {i: [] for i in range(batch_size)},
129
+ 'token_type_ids': {i: [] for i in range(batch_size)},
130
+ 'time_deltas': {i: [] for i in range(batch_size)},
131
+ 'time_deltas_mask': {i: [] for i in range(batch_size)},
132
+ 'attention_mask': [],
133
+ }
134
+
135
+ non_causal_2d_attention_mask = {i: [] for i in range(batch_size)}
136
+ causal_2d_attention_mask = []
137
+
138
+ # Map the prior study time delta values using the time delta map:
139
+ if prior_study_datetime is not None:
140
+ prior_study_time_deltas = [
141
+ [self.time_delta_map(compute_time_delta(j, k)) if j is not None else float('nan') for j in i] for i, k in zip(prior_study_datetime, study_datetime, strict=True)
142
+ ]
143
+
144
+ # Findings and impression sections from prior studies:
145
+ for i, token_type_id in zip([prior_findings, prior_impression], self.prior_section_token_type_ids, strict=True):
146
+ if not i:
147
+ continue
148
+ assert len(i) == batch_size, f'Length of {i} must be equal to the batch size: {batch_size}.'
149
+ for j in range(len(i)):
150
+ if not i[j]:
151
+ continue
152
+ for k in range(len(i[j])):
153
+ if not i[j][k]:
154
+ continue
155
+ batch['input_ids'][j].append(self.tokenizer.encode(i[j][k], add_special_tokens=False, return_tensors='pt')[0])
156
+ batch['token_type_ids'][j].append(torch.full((batch['input_ids'][j][-1].shape[-1],), token_type_id, dtype=torch.long))
157
+ non_causal_2d_attention_mask[j].append((batch['input_ids'][j][-1] != self.tokenizer.pad_token_id).long())
158
+ batch['time_deltas'][j].append(
159
+ torch.full(
160
+ (batch['input_ids'][j][-1].shape[-1],),
161
+ prior_study_time_deltas[j][k] if prior_study_time_deltas is not None and prior_study_time_deltas[j][k] is not None else float('nan'),
162
+ dtype=torch.float32,
163
+ ),
164
+ )
165
+ batch['time_deltas_mask'][j].append(torch.full((batch['input_ids'][j][-1].shape[-1],), 1.0, dtype=torch.float32))
166
+
167
+ # Sections of the report for the prompt:
168
+ for i, token_type_id in zip([indication, history, comparison, technique], self.section_token_type_ids, strict=True):
169
+ if not i:
170
+ continue
171
+ assert len(i) == batch_size, f'Length of {i} must be equal to the batch size: {batch_size}.'
172
+ for j, k in enumerate(i):
173
+ if not k:
174
+ continue
175
+ batch['input_ids'][j].append(self.tokenizer.encode(k, add_special_tokens=False, return_tensors='pt')[0])
176
+ batch['token_type_ids'][j].append(torch.full((batch['input_ids'][j][-1].shape[-1],), token_type_id, dtype=torch.long))
177
+ non_causal_2d_attention_mask[j].append((batch['input_ids'][j][-1] != self.tokenizer.pad_token_id).long())
178
+ batch['time_deltas'][j].append(
179
+ torch.full((batch['input_ids'][j][-1].shape[-1],), self.zero_time_delta_value, dtype=torch.float32),
180
+ )
181
+ batch['time_deltas_mask'][j].append(torch.full((batch['input_ids'][j][-1].shape[-1],), 1.0, dtype=torch.float32))
182
+
183
+ # Labels; findings and impression:
184
+ if train:
185
+ batch['label_ids'] = []
186
+ for i, (j, k) in enumerate(zip(findings, impression, strict=True)):
187
+
188
+ if j is not None and k is not None:
189
+ report = f'{self.tokenizer.bos_token}{j}{self.tokenizer.sep_token}{k}{self.tokenizer.eos_token}'
190
+ elif j is not None and k is None:
191
+ report = f'{self.generate_findings_token}{j}{self.tokenizer.eos_token}'
192
+ elif j is None and k is not None:
193
+ report = f'{self.generate_impression_token}{k}{self.tokenizer.eos_token}'
194
+ else:
195
+ raise ValueError('Both findings and impression cannot be None.')
196
+
197
+ report_ids = self.tokenizer.encode(
198
+ report,
199
+ truncation=True,
200
+ max_length=self.max_generated_tokens + 1, # +1 to account for the bias between input and target.
201
+ return_tensors='pt',
202
+ add_special_tokens=False,
203
+ )[0]
204
+
205
+ # Labels for the decoder (shifted right by one for autoregression):
206
+ batch['label_ids'].append(report_ids[1:].clone())
207
+
208
+ # Remove last token identifier to match the sequence length of the labels:
209
+ batch['input_ids'][i].append(report_ids[:-1])
210
+
211
+ report_token_type_ids = self.token_ids_to_token_type_ids(token_ids=batch['input_ids'][i][-1])
212
+ batch['token_type_ids'][i].append(report_token_type_ids)
213
+
214
+ causal_2d_attention_mask.append((batch['input_ids'][i][-1] != self.tokenizer.pad_token_id).long())
215
+
216
+ batch['time_deltas'][i].append(
217
+ torch.full((batch['input_ids'][i][-1].shape[-1],), self.zero_time_delta_value, dtype=torch.float32),
218
+ )
219
+
220
+ batch['time_deltas_mask'][i].append(torch.full((batch['input_ids'][i][-1].shape[-1],), 0.0, dtype=torch.float32))
221
+
222
+ else: # Add special tokens for generation:
223
+ for i in range(batch_size):
224
+
225
+ bos_token_id = self.tokenizer.bos_token_id
226
+ batch['token_type_ids'][i].append(torch.tensor([self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['findings'])], dtype=torch.long))
227
+
228
+ batch['input_ids'][i].append(torch.tensor([bos_token_id], dtype=torch.long))
229
+
230
+ causal_2d_attention_mask.append(torch.tensor([1], dtype=torch.long))
231
+
232
+ batch['time_deltas'][i].append(torch.tensor([self.zero_time_delta_value], dtype=torch.float32))
233
+ batch['time_deltas_mask'][i].append(torch.tensor([0.0], dtype=torch.float32))
234
+
235
+ # Map the image time delta values using the time delta map:
236
+ image_time_deltas = [[self.time_delta_map(compute_time_delta(j, k)) if j is not None else float('nan') for j in i] for i, k in zip(image_datetime, study_datetime, strict=True)]
237
+
238
+ # Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
239
+ for i in range(len(images)):
240
+ if len(images[i]) > self.max_train_images_per_study:
241
+ paired = list(zip(images[i], views[i], image_time_deltas[i], strict=True))
242
+ sampled_pairs = random.sample(paired, self.max_train_images_per_study)
243
+ images[i], views[i], image_time_deltas[i] = map(list, zip(*sampled_pairs, strict=True))
244
+
245
+ # Sort based on views:
246
+ images, views, image_time_deltas = self.sort_images(images, views, image_time_deltas)
247
+
248
+ # Images:
249
+ max_images = max(len(i) for i in images)
250
+ for i in range(batch_size):
251
+ for j in range(max_images):
252
+ if j < len(images[i]):
253
+ if isinstance(images[i][j], bytes):
254
+ image = Image.open(io.BytesIO(images[i][j]))
255
+ if self.convert_to_rgb:
256
+ image = image.convert('RGB')
257
+ images[i][j] = self.image_processor(image, return_tensors='pt')['pixel_values'].squeeze(0)
258
+
259
+ batch['time_deltas'][i].insert(j, torch.full((self.embeddings_per_image,), image_time_deltas[i][j]))
260
+ batch['time_deltas_mask'][i].insert(j, torch.full((self.embeddings_per_image,), 1.0))
261
+
262
+ token_type_id = self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['image']) if image_time_deltas[i][j] == self.zero_time_delta_value else self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['prior_image'])
263
+ batch['token_type_ids'][i].insert(j, torch.full((self.embeddings_per_image,), token_type_id))
264
+
265
+ non_causal_2d_attention_mask[i].insert(j, torch.full((self.embeddings_per_image,), 1))
266
+
267
+ else:
268
+
269
+ batch['time_deltas'][i].insert(j, torch.full((self.embeddings_per_image,), 0.0))
270
+ batch['time_deltas_mask'][i].insert(j, torch.full((self.embeddings_per_image,), 0.0))
271
+
272
+ batch['token_type_ids'][i].insert(j, torch.full((self.embeddings_per_image,), self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['image'])))
273
+
274
+ non_causal_2d_attention_mask[i].insert(j, torch.full((self.embeddings_per_image,), 0))
275
+
276
+
277
+ images[i] = torch.stack(images[i])
278
+ batch['input_ids'][i].insert(0, self.tokenizer.encode(self.image_token * self.embeddings_per_image * max_images, add_special_tokens=False, return_tensors='pt')[0])
279
+
280
+ batch['pixel_values'] = pad_sequence(images, batch_first=True, padding_value=0.0)
281
+
282
+ # Concatenate input_ids, token_type_ids, time_deltas, and time_deltas_mask:
283
+ batch['input_ids'] = [torch.cat(j, dim=0) for j in batch['input_ids'].values()]
284
+ batch['token_type_ids'] = [torch.cat(j, dim=0) for j in batch['token_type_ids'].values()]
285
+ batch['time_deltas'] = [torch.cat(j, dim=0) for j in batch['time_deltas'].values()]
286
+ batch['time_deltas_mask'] = [torch.cat(j, dim=0) for j in batch['time_deltas_mask'].values()]
287
+
288
+ # Concatentate, and convert label_ids into padded sequences:
289
+ if train:
290
+ batch['label_ids'] = [F.pad(i, (len(j) - len(i), 0), 'constant', self.tokenizer.pad_token_id) for i, j in zip(batch['label_ids'], batch['input_ids'], strict=True)]
291
+ batch['label_ids'] = pad_sequence(batch['label_ids'], batch_first=True, padding_value=self.tokenizer.pad_token_id)
292
+
293
+ # Convert input_ids, token_type_ids, time_deltas, and time_deltas_mask into padded sequences:
294
+ batch['input_ids'] = pad_sequence(batch['input_ids'], batch_first=True, padding_value=self.tokenizer.pad_token_id)
295
+ batch['token_type_ids'] = pad_sequence(batch['token_type_ids'], batch_first=True, padding_value=0)
296
+ batch['time_deltas'] = pad_sequence(batch['time_deltas'], batch_first=True, padding_value=0)
297
+ batch['time_deltas_mask'] = pad_sequence(batch['time_deltas_mask'], batch_first=True, padding_value=0)
298
+
299
+ # Assert that time_delta values are between zero_time_delta_value and inf_time_delta_value:
300
+ check_1 = torch.all((batch['time_deltas'][~torch.isnan(batch['time_deltas'])] <= max([self.zero_time_delta_value, self.inf_time_delta_value])))
301
+ check_2 = torch.all((batch['time_deltas'][~torch.isnan(batch['time_deltas'])] >= min([self.zero_time_delta_value, self.inf_time_delta_value])))
302
+ assert check_1 & check_2, 'Time delta values must be between zero_time_delta_value and inf_time_delta_value, or NaN if the time delta is missing.'
303
+
304
+ # Mixed causality mask:
305
+ non_causal_2d_attention_mask = [torch.cat(j, dim=0) for j in non_causal_2d_attention_mask.values()]
306
+ batch['attention_mask'] = self.create_4d_mixed_causality_attention_mask(
307
+ non_causal_2d_attention_mask,
308
+ causal_2d_attention_mask,
309
+ dtype=batch['pixel_values'].dtype,
310
+ )
311
+
312
+ if not train:
313
+ batch['initial_attention_mask'] = batch['attention_mask'].clone() # For the first iteration of generation.
314
+ batch['attention_mask'] = (batch['attention_mask'].squeeze(1).diagonal(dim1=1, dim2=2) == 0.0).long()
315
+
316
+ # Create position_ids from time_deltas and attention_mask:
317
+ batch['position_ids'] = self.position_ids_from_time_deltas_and_attention_mask(batch['time_deltas'], batch['attention_mask'])
318
+
319
+ rows, cols = (batch['input_ids'] == self.tokenizer.sep_token_id).nonzero(as_tuple=True)
320
+ assert all(batch['token_type_ids'][rows, cols] == self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['findings']))
321
+
322
+ rows, cols = (batch['input_ids'] == self.tokenizer.bos_token_id).nonzero(as_tuple=True)
323
+ assert all(batch['token_type_ids'][rows, cols] == self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['findings']))
324
+
325
+ return BatchFeature(data=batch)
326
+
327
+ @staticmethod
328
+ def sort_images(images, views, image_time_deltas):
329
+ def sort_by_view(images, views, time_deltas):
330
+ paired = list(zip(images, views, time_deltas, strict=True))
331
+ sorted_pairs = sorted(paired, key=lambda x: VIEW_ORDER.index(x[1]))
332
+ sorted_images, sorted_views, sorted_time_deltas = map(list, zip(*sorted_pairs, strict=True))
333
+ return sorted_images, sorted_views, sorted_time_deltas
334
+
335
+ # Apply sorting to each set of images, views, and time deltas:
336
+ sorted_results = [sort_by_view(i, j, k) for i, j, k in zip(images, views, image_time_deltas, strict=True)]
337
+
338
+ sorted_images = [result[0] for result in sorted_results]
339
+ sorted_views = [result[1] for result in sorted_results]
340
+ sorted_time_deltas = [result[2] for result in sorted_results]
341
+
342
+ return sorted_images, sorted_views, sorted_time_deltas
343
+
344
+ def token_ids_to_token_type_ids(self, token_ids, num_report_tokens=None):
345
+ findings_id = self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['findings'])
346
+ impression_id = self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['impression'])
347
+ sep_id = self.tokenizer.sep_token_id
348
+
349
+ # Initialize all as 'findings':
350
+ token_type_ids = torch.full_like(token_ids, findings_id)
351
+
352
+ # Detect SEP positions:
353
+ sep_positions = (token_ids == sep_id).nonzero(as_tuple=True)[0] # 1-D tensor of indices.
354
+
355
+ if sep_positions.numel() > 0:
356
+ # Use the first [SEP] as the split point; change anything after it to 'impression' (this is fine as it will be treated as invalid for RL):
357
+ first_sep = sep_positions[0].item()
358
+ if first_sep + 1 < token_type_ids.numel():
359
+ token_type_ids[first_sep + 1:] = impression_id
360
+
361
+ return token_type_ids if num_report_tokens is None else token_type_ids[-num_report_tokens:]
362
+
363
+ def create_4d_mixed_causality_attention_mask(self, non_causal_attention_mask, causal_attention_mask, dtype=torch.float32):
364
+ attention_mask = []
365
+
366
+ max_len = max([len(i) + len(j) for i, j in zip(non_causal_attention_mask, causal_attention_mask, strict=True)])
367
+
368
+ for i in range(len(non_causal_attention_mask)):
369
+ attention_mask.append(
370
+ self.create_3d_mixed_causality_attention_mask(
371
+ non_causal_attention_mask[i],
372
+ causal_attention_mask[i],
373
+ dtype=dtype,
374
+ )
375
+ )
376
+ pad_len = max_len - attention_mask[-1].shape[-1]
377
+ attention_mask[-1] = F.pad(attention_mask[-1], (0, pad_len, 0, pad_len, 0, 0), 'constant', torch.finfo(dtype).min)
378
+ attention_mask = torch.stack(attention_mask)
379
+
380
+ return attention_mask
381
+
382
+ @staticmethod
383
+ def create_3d_mixed_causality_attention_mask(non_causal_1d_attention_mask, causal_1d_attention_mask, dtype=torch.float32):
384
+
385
+ # Expand to 2D (seq_len x seq_len):
386
+ upper_left = non_causal_1d_attention_mask[:, None] * non_causal_1d_attention_mask[None, :]
387
+
388
+ if causal_1d_attention_mask is not None:
389
+
390
+ prompt_seq_len = non_causal_1d_attention_mask.shape[-1]
391
+ report_seq_len = causal_1d_attention_mask.shape[-1]
392
+
393
+ # Lower right of attention matrix (causal attention with lower triangular masking):
394
+ causal_mask = torch.tril(torch.ones(report_seq_len, report_seq_len, device=causal_1d_attention_mask.device))
395
+ lower_right = causal_1d_attention_mask[:, None] * causal_1d_attention_mask[None, :]
396
+ lower_right = lower_right * causal_mask
397
+
398
+ # Upper right of attention matrix (zeroes):
399
+ upper_right = torch.zeros(prompt_seq_len, report_seq_len, dtype=torch.long, device=causal_1d_attention_mask.device)
400
+
401
+ # Lower left of attention matrix:
402
+ lower_left = non_causal_1d_attention_mask[None, :] * causal_1d_attention_mask[:, None]
403
+
404
+ # Concatenate blocks:
405
+ left = torch.cat((upper_left, lower_left), dim=0)
406
+ right = torch.cat((upper_right, lower_right), dim=0)
407
+ mixed_causality_3d_attention_mask = torch.cat((left, right), dim=-1)
408
+ else:
409
+ mixed_causality_3d_attention_mask = upper_left
410
+
411
+ # Convert dtype and apply masking rules:
412
+ mixed_causality_3d_attention_mask = mixed_causality_3d_attention_mask.to(dtype=dtype)
413
+ mixed_causality_3d_attention_mask[mixed_causality_3d_attention_mask == 0] = torch.finfo(mixed_causality_3d_attention_mask.dtype).min
414
+ mixed_causality_3d_attention_mask[mixed_causality_3d_attention_mask == 1] = 0.0
415
+
416
+ # Add head dimension:
417
+ mixed_causality_3d_attention_mask = mixed_causality_3d_attention_mask.unsqueeze(0)
418
+
419
+ return mixed_causality_3d_attention_mask
420
+
421
    def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
        """
        Derive position identifiers by stably sorting each row's mapped time deltas.

        Argument/s:
            time_deltas - (batch, seq) mapped time delta values (NaN where missing).
            attention_mask - 2-D (batch, seq) mask, or the 4-D mixed-causality mask.

        Returns:
            position_ids - (batch, seq) positions; masked-out positions are set to 1.
        """

        # Set NaNs to inf_time_delta_value:
        time_deltas = torch.nan_to_num(time_deltas, nan=self.inf_time_delta_value)

        # Convert attention mask to 2D if it is 4D (a position is valid iff it attends to itself, i.e. diagonal == 0.0):
        if attention_mask.dim() == 4:
            attention_mask = (attention_mask.squeeze(1).diagonal(dim1=1, dim2=2) == 0.0).long()

        # Push masked positions to the end of the sort (+/- inf depending on sort direction):
        mask_value = float('inf') if self.time_delta_monotonic_inversion else -float('inf')
        masked_time_deltas = torch.where(attention_mask == 1, time_deltas, mask_value)

        # Sort time deltas and get indices (stable=True keeps the original order for ties):
        sorted_time_deltas, col_indices = masked_time_deltas.sort(
            dim=1, descending=not self.time_delta_monotonic_inversion, stable=True
        )

        num_rows, num_cols = time_deltas.shape

        # Scatter each element's rank in the sorted order back to its original column:
        row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
        position_ids = torch.zeros_like(col_indices, device=time_deltas.device)
        position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()

        # Apply the attention mask to zero out invalid positions
        position_ids = position_ids.masked_fill(attention_mask == 0, 1)  # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285.

        # Each row must be a permutation of 0..max with only '1' possibly repeated (from the masked fill):
        for i in range(position_ids.shape[0]):
            assert self.validate_position_ids(position_ids[i])

        return position_ids
452
+
453
+ @staticmethod
454
+ def validate_position_ids(tensor, repeat_value=1):
455
+ unique, counts = torch.unique(tensor, return_counts=True)
456
+
457
+ # Check if all integers from 0 to tensor.max() exist:
458
+ full_range = torch.arange(0, tensor.max() + 1, device=tensor.device)
459
+ if not torch.equal(unique.sort()[0], full_range):
460
+ return False
461
+
462
+ # Check for repeated values except for repeat_value:
463
+ repeated = unique[counts > 1]
464
+ if repeated.nelement() == 0:
465
+ return True
466
+ if not (repeated.numel() == 1 and repeated.item() == repeat_value):
467
+ return False
468
+
469
+ return True
470
+
471
+ def batch_decode(self, *args, **kwargs):
472
+ return self.tokenizer.batch_decode(*args, **kwargs)
473
+
474
+ def decode(self, *args, **kwargs):
475
+ return self.tokenizer.decode(*args, **kwargs)
476
+
477
+ @property
478
+ def model_input_names(self):
479
+ tokenizer_input_names = self.tokenizer.model_input_names
480
+ image_processor_input_names = self.image_processor.model_input_names
481
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
482
+
483
+ def split_and_decode_sections(self, token_ids):
484
+ """
485
+ Split the token identifiers into sections, then convert the token identifiers into strings.
486
+
487
+ Argument/s:
488
+ token_ids - token identifiers.
489
+
490
+ Returns:
491
+ token_type_ids - token type identifiers.
492
+ """
493
+
494
+ sections = {'findings': [], 'impression': []}
495
+ for i in token_ids:
496
+ findings_start_idx = (i == self.tokenizer.bos_token_id).int().argmax().item()
497
+ findings_end_idx = (i == self.tokenizer.sep_token_id).int().argmax().item()
498
+ sections['findings'].append(self.tokenizer.decode(i[findings_start_idx:findings_end_idx], skip_special_tokens=True))
499
+ impression_start_idx = findings_end_idx + 1
500
+ impression_end_idx = (i == self.tokenizer.eos_token_id).int().argmax().item()
501
+ sections['impression'].append(self.tokenizer.decode(i[impression_start_idx:impression_end_idx], skip_special_tokens=True))
502
+
503
+ return tuple(sections.values())
504
+
505
    def update_batch_for_rl(self, batch, completion_ids):
        """
        Extend a generation batch with sampled completion tokens for RL fine-tuning.

        Rebuilds the 4-D attention mask from initial_attention_mask (prompt part) plus
        causal attention over the completions, extends token_type_ids, time_deltas(_mask)
        and position_ids, and derives input_ids/label_ids (shifted by one) from completion_ids.

        Argument/s:
            batch - batch produced by __call__ with train=False (must still contain initial_attention_mask).
            completion_ids - (batch, prompt_len + 1 + num_completion_tokens) token identifiers including the prompt.

        Returns:
            The updated batch (modified in place and returned).
        """

        batch_size, prompt_len = batch['token_type_ids'].shape

        # Number of completion tokens:
        num_completion_tokens = completion_ids.shape[1] - prompt_len - 1  # -1 for offset between input and label ids.

        # Update mask for completion tokens:
        completion_mask = (completion_ids[:, -(num_completion_tokens + 1):] != self.tokenizer.pad_token_id).float()  # +1 to ignore offset.
        batch['completion_mask'] = completion_mask
        completion_mask_expanded = completion_mask[:, None, None, 1:]  # Start from 1 to reintroduce offset.
        completion_mask_expanded_t = completion_mask[:, None, 1:, None]  # Start from 1 to reintroduce offset.

        # Upper right of the new mask: prompt tokens never attend to completions.
        upper_right = torch.zeros(batch_size, 1, prompt_len, num_completion_tokens, dtype=batch['initial_attention_mask'].dtype, device=completion_ids.device)

        # Bottom right: causal (lower-triangular) attention over the completion tokens, restricted to non-pad positions.
        bottom_right = torch.tril(torch.ones(num_completion_tokens, num_completion_tokens, device=completion_ids.device)).bool()
        bottom_right = bottom_right.unsqueeze(0).unsqueeze(0)
        bottom_right = bottom_right.expand(batch_size, -1, -1, -1)
        bottom_right = bottom_right * completion_mask_expanded * completion_mask_expanded_t

        # Lower left: completion tokens attend to the valid prompt positions (2-D attention_mask from generation).
        lower_left = batch['attention_mask'][:, None, None, :]
        lower_left = lower_left.expand(-1, -1, num_completion_tokens, -1)
        lower_left = lower_left * completion_mask_expanded_t

        # Convert 0/1 masks to additive form (0 -> dtype min, 1 -> 0.0):
        right = torch.cat((upper_right, bottom_right), dim=2)
        right[right == 0] = torch.finfo(right.dtype).min
        right[right == 1] = 0.0

        lower_left[lower_left == 0] = torch.finfo(lower_left.dtype).min
        lower_left[lower_left == 1] = 0.0

        batch['attention_mask'] = torch.cat((batch['initial_attention_mask'], lower_left), dim=2)
        batch['attention_mask'] = torch.cat((batch['attention_mask'], right), dim=3)

        # initial_attention_mask was the 4D attention mask, whereas attention_mask was the 2D attention mask (i.e., not needed now that attention_mask is 4D):
        batch.pop('initial_attention_mask', None)

        # Convert remaining batch elements:
        new_token_type_ids = torch.stack([self.token_ids_to_token_type_ids(
            token_ids=i[-num_completion_tokens:],
            # special_token_ids=[self.tokenizer.sep_token_id],
            # token_type_id_sections=[self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['findings']), self.tokenizer.convert_tokens_to_ids(self.token_type_to_token['impression'])],
        ) for i in completion_ids])
        batch['token_type_ids'] = torch.cat((batch['token_type_ids'], new_token_type_ids), dim=1)
        batch['time_deltas'] = torch.nn.functional.pad(batch['time_deltas'], (0, num_completion_tokens), value=0.0)
        batch['time_deltas_mask'] = torch.nn.functional.pad(batch['time_deltas_mask'], (0, num_completion_tokens), value=0.0)

        # Completion positions continue from each row's maximum prompt position:
        start_values = batch['position_ids'].max(dim=1).values + 1
        end_values = start_values + num_completion_tokens
        position_ids = torch.stack([torch.arange(i, j, device=batch['position_ids'].device) for i, j in zip(start_values, end_values)])
        batch['position_ids'] = torch.cat((batch['position_ids'], position_ids), dim=1)

        # Inputs and labels shifted by one for autoregression:
        batch['label_ids'] = completion_ids[:, 1:].clone()
        batch['input_ids'] = completion_ids[:, :-1]

        # Convert token identifiers that weren't sampled to pad_token_id:
        # (everything up to and including the first [BOS] belongs to the prompt, not the sampled completion)
        for i in range(batch_size):
            idx = (batch['label_ids'][i] == self.tokenizer.bos_token_id).nonzero(as_tuple=False)[0, 0].item()
            batch['label_ids'][i][:idx+1] = self.tokenizer.pad_token_id

        return batch
processor_config.json CHANGED
@@ -1,4 +1,7 @@
1
  {
 
 
 
2
  "convert_to_rgb": false,
3
  "embeddings_per_image": 128,
4
  "generate_findings_token": "<|reserved_special_token_1|>",
 
1
  {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_cxrmate2.CXRMate2Processor"
4
+ },
5
  "convert_to_rgb": false,
6
  "embeddings_per_image": 128,
7
  "generate_findings_token": "<|reserved_special_token_1|>",
tokenizer_config.json CHANGED
@@ -2049,6 +2049,9 @@
2049
  "special": true
2050
  }
2051
  },
 
 
 
2052
  "bos_token": "<|begin_of_text|>",
2053
  "clean_up_tokenization_spaces": true,
2054
  "eos_token": "<|end_of_text|>",
 
2049
  "special": true
2050
  }
2051
  },
2052
+ "auto_map": {
2053
+ "AutoProcessor": "processing_cxrmate2.CXRMate2Processor"
2054
+ },
2055
  "bos_token": "<|begin_of_text|>",
2056
  "clean_up_tokenization_spaces": true,
2057
  "eos_token": "<|end_of_text|>",