import shutil
import traceback
from io import BytesIO
from urllib.parse import urlparse

import cv2
import numpy as np
import pydicom
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BitImageProcessor, BlipImageProcessor


@torch.no_grad()
def model_inference(image, text, model, image_processor, tokenizer):
    """
    Score one image against one text prompt and build a similarity map.

    Args:
        image (str or PIL.Image.Image): Image path (.dcm/.png/.jpg/.jpeg) or
            an already loaded PIL image; resolved through ``load_image``.
        text: Prompt passed to ``tokenizer``.
        model: Object exposing ``device`` and ``compute_logits``; the latter
            must return a mapping with "logits" and "similarity_scores".
        image_processor: Callable whose output has a "pixel_values" entry.
            Must be a ``BlipImageProcessor`` or ``BitImageProcessor`` for the
            map interpolation step.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``.

    Returns:
        tuple: ``(similarity_prob, similarity_map)`` where ``similarity_prob``
        is the sigmoid of the model logits and ``similarity_map`` is the
        sigmoid of the patch scores resized to the original image size.

    Raises:
        ValueError: If the image cannot be loaded or has an unsupported type.
    """
    image = load_image(image)
    if image is None:
        # load_image passes through the None that dicom_to_pil_image returns
        # on failure; fail here with a clear message instead of an
        # AttributeError on `image.size`.
        raise ValueError("Failed to load image: DICOM conversion returned None")

    # PIL reports (width, height); the interpolation helper wants (height, width).
    (width, height) = image.size
    image_size = (height, width)

    image_processor_outputs = image_processor(image)
    processed_image = torch.FloatTensor(
        np.array(image_processor_outputs["pixel_values"])
    ).to(model.device)

    tokenized_text = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    output = model.compute_logits(processed_image, [tokenized_text])

    similarity_prob = output["logits"].sigmoid()

    # Flatten the patch scores; the helper reshapes them into a square grid.
    similarity_scores = output["similarity_scores"].view(-1)
    similarity_scores = interpolate_similarity_scores(
        similarity_scores, image_size, image_processor
    )
    similarity_map = similarity_scores.sigmoid()[0]

    return similarity_prob, similarity_map


@torch.no_grad()
def model_inference_multiple_text(image, text_list, model, image_processor, tokenizer):
    """
    Run ``model_inference`` once per prompt and stack the results.

    Args:
        image (str or PIL.Image.Image): Image path or PIL image.
        text_list: Iterable of prompts, each scored independently.
        model, image_processor, tokenizer: Forwarded to ``model_inference``.

    Returns:
        tuple: ``(probs, similarity_maps)``, each a ``torch.stack`` of the
        per-prompt outputs in the order of ``text_list``.
    """
    # TODO: batch inference
    # Load/decode the image once; load_image returns a PIL image unchanged,
    # so the per-prompt calls below do not hit the file system again.
    image = load_image(image)

    probs, similarity_maps = [], []
    for text in text_list:
        prob, similarity_map = model_inference(
            image, text, model, image_processor, tokenizer
        )
        probs.append(prob)
        similarity_maps.append(similarity_map)

    return torch.stack(probs), torch.stack(similarity_maps)


def interpolate_similarity_scores(similarity_scores, origin_size, image_processor):
    """
    Resize a square grid of patch scores to the original image resolution.

    Args:
        similarity_scores (torch.Tensor): Patch scores whose last dimension
            is a perfect square; reshaped to ``(1, 1, p, p)``.
        origin_size (tuple): ``(height, width)`` of the original image.
        image_processor: ``BlipImageProcessor`` (scores are stretched over the
            full image) or ``BitImageProcessor`` (scores are mapped onto the
            centered square crop of side ``min(height, width)``).

    Returns:
        torch.Tensor: Scores of shape ``(1, height, width)``.

    Raises:
        ValueError: If ``image_processor`` is neither supported type.
    """
    (height, width) = origin_size
    patch_size = int(similarity_scores.shape[-1] ** 0.5)
    scores = similarity_scores.view(1, 1, patch_size, patch_size)

    if isinstance(image_processor, BlipImageProcessor):  # XrayDINOv2
        interpolated_scores = F.interpolate(
            scores,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        return interpolated_scores.squeeze(1)

    if isinstance(image_processor, BitImageProcessor):
        shortest = min(height, width)
        interpolated_scores = F.interpolate(
            scores,
            size=(shortest, shortest),
            mode="bilinear",
            align_corners=False,
        )
        cropped_left = (width - shortest) // 2
        cropped_top = (height - shortest) // 2

        # Pixels outside the centered crop get a large negative score, which
        # the caller's sigmoid turns into ~0.
        # NOTE: the canvas uses torch defaults (CPU, default dtype), so this
        # branch's result does not follow the device of `similarity_scores`.
        original_size_map = torch.full((height, width), -999.0)
        original_size_map[
            cropped_top : cropped_top + shortest,
            cropped_left : cropped_left + shortest,
        ] = interpolated_scores.view(shortest, shortest)
        return original_size_map.unsqueeze(0)

    # Previously this fell through to an UnboundLocalError at `return`.
    raise ValueError(f"Unsupported image processor: {type(image_processor)}")


# copy from https://github.com/MIT-LCP/mimic-code/issues/1013
def dicom_to_pil_image(input_file_path, save_dir=None):
    """
    Extract the image from a DICOM file and return it as a PIL.Image object.

    The pixel data is min-max normalized to [0, 255], inverted when the
    photometric interpretation is MONOCHROME1, and histogram-equalized.

    Args:
        input_file_path (str): Path to the input DICOM file.
        save_dir (str, optional): If given, the source DICOM file is copied
            to this location with ``shutil.copy2``.

    Returns:
        PIL.Image.Image or None: Processed image, or ``None`` if any step
        fails (the traceback is printed instead of raised).
    """
    try:
        # Read the DICOM and extract the image.
        dcm_file = pydicom.dcmread(input_file_path)
        raw_image = dcm_file.pixel_array

        # Explicit raise rather than `assert`, which is stripped under -O.
        if len(raw_image.shape) != 2:
            raise ValueError("Expecting single channel (grayscale) image.")

        # Normalize pixels to be in [0, 255].
        shifted_image = raw_image - raw_image.min()
        peak = shifted_image.max()
        if peak > 0:
            normalized_image = shifted_image / peak
        else:
            # Constant-valued image: avoid 0/0 producing NaNs before the
            # uint8 cast.
            normalized_image = np.zeros(shifted_image.shape, dtype=np.float64)
        rescaled_image = (normalized_image * 255).astype(np.uint8)

        # Correct image inversion.
        if dcm_file.PhotometricInterpretation == "MONOCHROME1":
            rescaled_image = cv2.bitwise_not(rescaled_image)

        # Perform histogram equalization.
        final_image = cv2.equalizeHist(rescaled_image)

        # Convert to PIL Image and return
        image = Image.fromarray(final_image)

        if save_dir is not None:
            shutil.copy2(input_file_path, save_dir)

        return image

    except Exception:
        # Best-effort by design: report the failure and signal it with None.
        print(traceback.format_exc())
        return None


def load_image(image):
    """
    Load an image from a file path or a PIL.Image object.

    Args:
        image (str or PIL.Image.Image): Path to the image file or a
            PIL.Image object. Paths are dispatched on their (case-insensitive)
            extension: ".dcm" goes through ``dicom_to_pil_image``; ".png",
            ".jpg" and ".jpeg" are opened with ``Image.open``.

    Returns:
        PIL.Image.Image: Processed image. May be ``None`` for a ".dcm" path
        whose conversion failed, since ``dicom_to_pil_image`` returns ``None``
        on error.

    Raises:
        ValueError: If the path has an unsupported extension, or ``image`` is
            neither a ``str`` nor a ``PIL.Image.Image``.
    """
    if isinstance(image, str):
        lowered_path = image.lower()
        if lowered_path.endswith(".dcm"):
            return dicom_to_pil_image(image)
        if lowered_path.endswith((".png", ".jpg", ".jpeg")):
            return Image.open(image)
        raise ValueError(f"Invalid image type: {image}")

    if not isinstance(image, Image.Image):
        raise ValueError(f"Invalid image type: {type(image)}")

    return image