Safetensors
custom_code
jonggwon-park committed on
Commit
5718512
·
1 Parent(s): 6205663

implement inference code

Browse files
Files changed (2) hide show
  1. inference.py +51 -0
  2. utils.py +167 -0
inference.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import torch
4
+ from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
5
+
6
+ from utils import model_inference
7
+
8
+ # Suppress all UserWarning messages (not only specific ones) for cleaner logs
9
+ warnings.filterwarnings("ignore", category=UserWarning)
10
+
11
+
12
def load_model(device, dtype, model_id="Deepnoid/RadZero"):
    """
    Load the tokenizer, image processor and model from a single checkpoint.
    Args:
        device: Placement for the model weights; forwarded unchanged to
            ``AutoModel.from_pretrained`` as ``device_map``.
        dtype (torch.dtype): Weight dtype; forwarded as ``torch_dtype``.
        model_id (str): Hub repository id or local path used for all three
            ``from_pretrained`` calls. Defaults to "Deepnoid/RadZero".
    Returns:
        dict: Mapping with the keys "tokenizer", "image_processor" and
            "model". The keys match the keyword parameters of
            ``model_inference`` so the dict can be unpacked with ``**``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_processor = AutoImageProcessor.from_pretrained(model_id)

    # trust_remote_code=True runs the modelling code shipped with the
    # checkpoint, so only point model_id at repositories you trust.
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )

    return {
        "tokenizer": tokenizer,
        "image_processor": image_processor,
        "model": model,
    }
30
+
31
+
32
if __name__ == "__main__":
    # Setup constants: use the GPU when one is available and fall back to the
    # CPU otherwise, so the example also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # load models
    models = load_model(device, dtype)

    # load image
    image_path = "cxr_image.jpg"

    # inference: `models` carries the tokenizer / image_processor / model
    # keyword arguments expected by model_inference
    similarity_prob, similarity_map = model_inference(
        image_path, "There is fibrosis", **models
    )

    print(similarity_prob)
    print(similarity_map.min())
    print(similarity_map.max())
    print(similarity_map.shape)
utils.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import traceback
3
+ from io import BytesIO
4
+ from urllib.parse import urlparse
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import pydicom
9
+ import requests
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from PIL import Image
13
+ from transformers import BitImageProcessor, BlipImageProcessor
14
+
15
+
16
@torch.no_grad()
def model_inference(image, text, model, image_processor, tokenizer):
    """
    Score one image against one text prompt.
    Args:
        image (str or PIL.Image.Image): Image path or already loaded image;
            resolved through ``load_image``.
        text (str): Prompt to compare the image with.
        model: Model exposing ``compute_logits`` and a ``device`` attribute.
        image_processor: Callable whose output has a "pixel_values" entry.
        tokenizer: Tokenizer called with ``return_tensors="pt"``.
    Returns:
        tuple: ``(similarity_prob, similarity_map)`` where the first item is
            the sigmoid of the model's "logits" output and the second is the
            sigmoid of the similarity scores after they were resized to the
            original image resolution (first element along dim 0).
    """
    pil_image = load_image(image)

    # PIL reports (width, height); the resize helper expects (height, width).
    img_w, img_h = pil_image.size
    original_hw = (img_h, img_w)

    pixel_values = np.array(image_processor(pil_image)["pixel_values"])
    image_tensor = torch.FloatTensor(pixel_values).to(model.device)

    text_inputs = tokenizer(
        text, padding=True, truncation=True, return_tensors="pt"
    )
    text_inputs = text_inputs.to(model.device)

    # compute_logits receives the tokenized prompt wrapped in a list.
    result = model.compute_logits(image_tensor, [text_inputs])
    prob = result["logits"].sigmoid()

    # Flatten the raw scores before resizing them to the image resolution.
    flat_scores = result["similarity_scores"].view(-1)
    resized_scores = interpolate_similarity_scores(
        flat_scores, original_hw, image_processor
    )
    heatmap = resized_scores.sigmoid()[0]

    return prob, heatmap
50
+
51
+
52
@torch.no_grad()
def model_inference_multiple_text(image, text_list, model, image_processor, tokenizer):
    """
    Score one image against several text prompts, one prompt at a time.
    Args:
        image (str or PIL.Image.Image): Image path or already loaded image.
        text_list (iterable of str): Prompts to evaluate; must yield at least
            one prompt because the results are combined with ``torch.stack``.
        model, image_processor, tokenizer: Forwarded to ``model_inference``.
    Returns:
        tuple: ``(probs, similarity_maps)``, the per-prompt outputs of
            ``model_inference`` stacked along a new leading dimension.
    """
    # Read and decode the image a single time; model_inference accepts an
    # already loaded PIL image, so the file is not re-read for every prompt.
    image = load_image(image)

    # TODO: batch inference
    probs, similarity_maps = [], []
    for text in text_list:
        prob, similarity_map = model_inference(
            image, text, model, image_processor, tokenizer
        )
        probs.append(prob)
        similarity_maps.append(similarity_map)

    return torch.stack(probs), torch.stack(similarity_maps)
64
+
65
+
66
def interpolate_similarity_scores(similarity_scores, origin_size, image_processor):
    """
    Resize patch-level similarity scores to the original image resolution.
    Args:
        similarity_scores (torch.Tensor): Scores for a square patch grid; the
            tensor is viewed as (1, 1, grid, grid), where grid is the integer
            square root of the size of its last dimension.
        origin_size (tuple): ``(height, width)`` of the original image.
        image_processor: Processor that produced the model input. Its type
            selects how the scores are mapped back onto the image.
    Returns:
        torch.Tensor: Score map of shape (1, height, width) on the same device
            and with the same dtype as ``similarity_scores``.
    Raises:
        ValueError: If ``image_processor`` is neither a BlipImageProcessor nor
            a BitImageProcessor.
    """
    import math  # local import: only this helper needs it

    (height, width) = origin_size
    # Exact integer square root (no float rounding).
    patch_size = math.isqrt(similarity_scores.shape[-1])
    scores = similarity_scores.view(1, 1, patch_size, patch_size)

    if isinstance(image_processor, BlipImageProcessor):
        # XrayDINOv2: the scores are stretched over the whole image.
        interpolated_scores = F.interpolate(
            scores,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        return interpolated_scores.squeeze(1)

    if isinstance(image_processor, BitImageProcessor):
        # The scores are mapped onto a centred square whose side is the
        # shorter image edge; everything outside that square is padding.
        shortest = min(height, width)

        interpolated_scores = F.interpolate(
            scores,
            size=(shortest, shortest),
            mode="bilinear",
            align_corners=False,
        )

        cropped_left = (width - shortest) // 2
        cropped_top = (height - shortest) // 2

        # -999 becomes ~0 after the caller's sigmoid. Allocate the padding on
        # the scores' device/dtype so both branches return comparable tensors.
        original_size_map = torch.full(
            (height, width),
            -999.0,
            dtype=interpolated_scores.dtype,
            device=interpolated_scores.device,
        )
        original_size_map[
            cropped_top : cropped_top + shortest, cropped_left : cropped_left + shortest
        ] = interpolated_scores.view(shortest, shortest)

        return original_size_map.unsqueeze(0)

    # Previously this case fell through to an unbound local variable.
    raise ValueError(f"Unsupported image processor type: {type(image_processor)}")
103
+
104
+
105
+ # copy from https://github.com/MIT-LCP/mimic-code/issues/1013
106
def dicom_to_pil_image(input_file_path, save_dir=None):
    """
    Extract the image from a DICOM file and return it as a PIL.Image object.
    Args:
        input_file_path (str): Path to the input DICOM file.
        save_dir (str, optional): When given, the original DICOM file (not the
            converted image) is copied there with ``shutil.copy2`` after a
            successful conversion.
    Returns:
        PIL.Image.Image or None: Histogram-equalized 8-bit grayscale image, or
            None when any step fails (the traceback is printed instead of
            being raised).
    """
    try:
        # Read the DICOM and extract the image.
        dcm_file = pydicom.dcmread(input_file_path)
        raw_image = dcm_file.pixel_array

        # Explicit check instead of `assert`, which is stripped under -O.
        if raw_image.ndim != 2:
            raise ValueError("Expecting single channel (grayscale) image.")

        # Normalize pixels to be in [0, 255]. Work in float64 so that
        # subtracting the minimum cannot overflow narrow integer dtypes.
        shifted = raw_image.astype(np.float64)
        shifted -= shifted.min()
        peak = shifted.max()
        if peak > 0:
            # A constant image has peak == 0; skip the division and keep zeros.
            shifted /= peak
        rescaled_image = (shifted * 255).astype(np.uint8)

        # Correct image inversion.
        if dcm_file.PhotometricInterpretation == "MONOCHROME1":
            rescaled_image = cv2.bitwise_not(rescaled_image)

        # Perform histogram equalization.
        final_image = cv2.equalizeHist(rescaled_image)

        # Convert to PIL Image and return
        image = Image.fromarray(final_image)

        if save_dir is not None:
            shutil.copy2(input_file_path, save_dir)

        return image
    except Exception:
        # Best effort by design: report the failure and signal it with None.
        print(traceback.format_exc())
        return None
142
+
143
+
144
def load_image(image):
    """
    Load an image from a file path or a PIL.Image object.
    Args:
        image (str, os.PathLike or PIL.Image.Image): Path to a .dcm, .png,
            .jpg or .jpeg file (extension matched case-insensitively), or an
            already loaded PIL.Image object, which is returned unchanged.
    Returns:
        PIL.Image.Image: Loaded image.
    Raises:
        ValueError: If the path has an unsupported extension, if a DICOM file
            could not be converted, or if ``image`` is neither a path nor a
            PIL.Image object.
    """
    import os  # local import: only needed for path-like support

    # Accept pathlib.Path and other path-like objects in addition to str.
    if isinstance(image, os.PathLike):
        image = os.fspath(image)

    if isinstance(image, str):
        lowered = image.lower()
        if lowered.endswith(".dcm"):
            loaded = dicom_to_pil_image(image)
            if loaded is None:
                # dicom_to_pil_image reports failures by returning None; fail
                # here instead of handing None to the caller.
                raise ValueError(f"Failed to load DICOM image: {image}")
            return loaded
        if lowered.endswith((".png", ".jpg", ".jpeg")):
            return Image.open(image)
        raise ValueError(f"Invalid image type: {image}")

    if not isinstance(image, Image.Image):
        raise ValueError(f"Invalid image type: {type(image)}")

    return image