anicolson commited on
Commit
e114fa5
·
verified ·
1 Parent(s): 1246380

Update processing_cxrmate2.py

Browse files
Files changed (1) hide show
  1. processing_cxrmate2.py +6 -3
processing_cxrmate2.py CHANGED
@@ -4,6 +4,7 @@ import random
4
  from io import BytesIO
5
  from typing import Dict, List, Union
6
 
 
7
  import numpy as np
8
  import requests
9
  import torch
@@ -15,7 +16,6 @@ from transformers.feature_extraction_utils import BatchFeature
15
  from transformers.image_utils import ImageInput
16
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
17
  from utils import compute_time_delta
18
- import cv2
19
 
20
  try:
21
  from .dataset import CXRMate2Dataset
@@ -277,7 +277,10 @@ class CXRMate2Processor(transformers.ProcessorMixin):
277
  batch['time_deltas_mask'][i].append(torch.tensor([0.0], dtype=torch.float32))
278
 
279
  # Map the image time delta values using the time delta map:
280
- image_time_deltas = [[self.time_delta_map(compute_time_delta(j, k)) if j is not None else float('nan') for j in i] for i, k in zip(image_datetime, study_datetime, strict=True)]
 
 
 
281
 
282
  # Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
283
  for i in range(len(images)):
@@ -632,4 +635,4 @@ class CXRMate2Processor(transformers.ProcessorMixin):
632
  return batch
633
 
634
  def wrap_dataset(self, dataset):
635
- return CXRMate2Dataset(dataset)
 
4
  from io import BytesIO
5
  from typing import Dict, List, Union
6
 
7
+ import cv2
8
  import numpy as np
9
  import requests
10
  import torch
 
16
  from transformers.image_utils import ImageInput
17
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
18
  from utils import compute_time_delta
 
19
 
20
  try:
21
  from .dataset import CXRMate2Dataset
 
277
  batch['time_deltas_mask'][i].append(torch.tensor([0.0], dtype=torch.float32))
278
 
279
  # Map the image time delta values using the time delta map:
280
+ if study_datetime is not None:
281
+ image_time_deltas = [[self.time_delta_map(compute_time_delta(j, k)) if j is not None else float('nan') for j in i] for i, k in zip(image_datetime, study_datetime, strict=True)]
282
+ else:
283
+ image_time_deltas = [[float('nan') for _ in range(len(i))] for i in images]
284
 
285
  # Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
286
  for i in range(len(images)):
 
635
  return batch
636
 
637
  def wrap_dataset(self, dataset):
638
+ return CXRMate2Dataset(dataset)