Update processing_cxrmate2.py
Browse files- processing_cxrmate2.py +6 -3
processing_cxrmate2.py
CHANGED
|
@@ -4,6 +4,7 @@ import random
|
|
| 4 |
from io import BytesIO
|
| 5 |
from typing import Dict, List, Union
|
| 6 |
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import requests
|
| 9 |
import torch
|
|
@@ -15,7 +16,6 @@ from transformers.feature_extraction_utils import BatchFeature
|
|
| 15 |
from transformers.image_utils import ImageInput
|
| 16 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 17 |
from utils import compute_time_delta
|
| 18 |
-
import cv2
|
| 19 |
|
| 20 |
try:
|
| 21 |
from .dataset import CXRMate2Dataset
|
|
@@ -277,7 +277,10 @@ class CXRMate2Processor(transformers.ProcessorMixin):
|
|
| 277 |
batch['time_deltas_mask'][i].append(torch.tensor([0.0], dtype=torch.float32))
|
| 278 |
|
| 279 |
# Map the image time delta values using the time delta map:
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
# Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
|
| 283 |
for i in range(len(images)):
|
|
@@ -632,4 +635,4 @@ class CXRMate2Processor(transformers.ProcessorMixin):
|
|
| 632 |
return batch
|
| 633 |
|
| 634 |
def wrap_dataset(self, dataset):
|
| 635 |
-
return CXRMate2Dataset(dataset)
|
|
|
|
| 4 |
from io import BytesIO
|
| 5 |
from typing import Dict, List, Union
|
| 6 |
|
| 7 |
+
import cv2
|
| 8 |
import numpy as np
|
| 9 |
import requests
|
| 10 |
import torch
|
|
|
|
| 16 |
from transformers.image_utils import ImageInput
|
| 17 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 18 |
from utils import compute_time_delta
|
|
|
|
| 19 |
|
| 20 |
try:
|
| 21 |
from .dataset import CXRMate2Dataset
|
|
|
|
| 277 |
batch['time_deltas_mask'][i].append(torch.tensor([0.0], dtype=torch.float32))
|
| 278 |
|
| 279 |
# Map the image time delta values using the time delta map:
|
| 280 |
+
if study_datetime is not None:
|
| 281 |
+
image_time_deltas = [[self.time_delta_map(compute_time_delta(j, k)) if j is not None else float('nan') for j in i] for i, k in zip(image_datetime, study_datetime, strict=True)]
|
| 282 |
+
else:
|
| 283 |
+
image_time_deltas = [[float('nan') for _ in range(len(i))] for i in images]
|
| 284 |
|
| 285 |
# Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
|
| 286 |
for i in range(len(images)):
|
|
|
|
| 635 |
return batch
|
| 636 |
|
| 637 |
def wrap_dataset(self, dataset):
|
| 638 |
+
return CXRMate2Dataset(dataset)
|