support cpu
Browse files
- README.md +3 -3
- __pycache__/utils.cpython-39.pyc +0 -0
- demo.py +10 -42
- inference.py +6 -6
- utils.py +7 -8
README.md
CHANGED
|
@@ -22,12 +22,12 @@ conda activate medversa
|
|
| 22 |
## Inference
|
| 23 |
``` python
|
| 24 |
from utils import *
|
|
|
|
| 25 |
|
| 26 |
# --- Launch Model ---
|
| 27 |
-
device = 'cuda'
|
| 28 |
model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
|
| 29 |
-
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
|
| 30 |
-
model.eval()
|
| 31 |
|
| 32 |
# --- Define examples ---
|
| 33 |
examples = [
|
|
|
|
| 22 |
## Inference
|
| 23 |
``` python
|
| 24 |
from utils import *
|
| 25 |
+
from torch import cuda
|
| 26 |
|
| 27 |
# --- Launch Model ---
|
| 28 |
+
device = 'cuda' if cuda.is_available() else 'cpu'
|
| 29 |
model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
|
| 30 |
+
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
|
|
|
|
| 31 |
|
| 32 |
# --- Define examples ---
|
| 33 |
examples = [
|
__pycache__/utils.cpython-39.pyc
CHANGED
|
Binary files a/__pycache__/utils.cpython-39.pyc and b/__pycache__/utils.cpython-39.pyc differ
|
|
|
demo.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import argparse
|
| 3 |
import torch
|
|
|
|
| 4 |
import torch.nn.functional as F
|
| 5 |
import torchvision.transforms.functional as TF
|
| 6 |
from torchvision import transforms
|
|
@@ -32,15 +33,14 @@ def parse_args():
|
|
| 32 |
args = parser.parse_args()
|
| 33 |
return args
|
| 34 |
|
| 35 |
-
device = 'cuda'
|
| 36 |
# Launch model
|
| 37 |
args = parse_args()
|
| 38 |
cfg = Config(args)
|
| 39 |
|
| 40 |
model_config = cfg.model_cfg
|
| 41 |
model_cls = registry.get_model_class(model_config.arch)
|
| 42 |
-
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
|
| 43 |
-
model.eval()
|
| 44 |
global global_images
|
| 45 |
global_images = None
|
| 46 |
|
|
@@ -146,7 +146,7 @@ def task_seg_2d(model, preds, hidden_states, image):
|
|
| 146 |
seg_feats = model.model_seg_2d.decoder(*feats)
|
| 147 |
seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
|
| 148 |
seg_probs = F.sigmoid(seg_preds)
|
| 149 |
-
seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
|
| 150 |
return seg_mask
|
| 151 |
else:
|
| 152 |
return None
|
|
@@ -165,7 +165,7 @@ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
|
|
| 165 |
new_img_embeds_list[-1] = last_feats
|
| 166 |
seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
|
| 167 |
seg_probs = F.sigmoid(seg_preds)
|
| 168 |
-
seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
|
| 169 |
return seg_mask
|
| 170 |
|
| 171 |
def task_det_2d(model, preds, hidden_states):
|
|
@@ -175,7 +175,7 @@ def task_det_2d(model, preds, hidden_states):
|
|
| 175 |
if target_states:
|
| 176 |
target_states = torch.cat(target_states).squeeze()
|
| 177 |
det_states = model.text_det(target_states).detach().cpu()
|
| 178 |
-
return det_states.numpy()
|
| 179 |
return torch.zeros_like(indices)
|
| 180 |
|
| 181 |
class StoppingCriteriaSub(StoppingCriteria):
|
|
@@ -240,7 +240,7 @@ def load_and_preprocess_image(image):
|
|
| 240 |
transforms.ToTensor(),
|
| 241 |
transforms.Normalize(mean, std)
|
| 242 |
])
|
| 243 |
-
image = transform(image).type(torch.bfloat16).cuda().unsqueeze(0)
|
| 244 |
return image
|
| 245 |
|
| 246 |
def load_and_preprocess_volume(image):
|
|
@@ -249,7 +249,7 @@ def load_and_preprocess_volume(image):
|
|
| 249 |
transform = tio.Compose([
|
| 250 |
tio.ZNormalization(masking_method=tio.ZNormalization.mean),
|
| 251 |
])
|
| 252 |
-
image = transform(image.unsqueeze(0)).type(torch.bfloat16)
|
| 253 |
return image
|
| 254 |
|
| 255 |
def read_image(image_path):
|
|
@@ -328,14 +328,14 @@ def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_
|
|
| 328 |
def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
|
| 329 |
num_imgs = len(images)
|
| 330 |
modal = modality.lower()
|
| 331 |
-
image_tensors = [read_image(img) for img in images]
|
| 332 |
if modality == 'ct':
|
| 333 |
time.sleep(2)
|
| 334 |
else:
|
| 335 |
time.sleep(1)
|
| 336 |
image_tensor = torch.cat(image_tensors)
|
| 337 |
|
| 338 |
-
with torch.autocast('cuda'):
|
| 339 |
with torch.no_grad():
|
| 340 |
generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 341 |
|
|
@@ -388,38 +388,6 @@ def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_s
|
|
| 388 |
|
| 389 |
return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
|
| 390 |
|
| 391 |
-
# my_dict = {}
|
| 392 |
-
# def gradio_interface(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
|
| 393 |
-
# if not images:
|
| 394 |
-
# return None, "Error: At least one image is required to proceed."
|
| 395 |
-
# if not prompt or not task or not modality:
|
| 396 |
-
# return None, "Error: Please provide prompt, select task and modality to proceed."
|
| 397 |
-
|
| 398 |
-
# generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, task, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 399 |
-
# output_images = []
|
| 400 |
-
|
| 401 |
-
# input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
|
| 402 |
-
# if generated_images is not None:
|
| 403 |
-
# for generated_image in generated_images:
|
| 404 |
-
# output_images.append(np.asarray(generated_image).astype(np.uint8))
|
| 405 |
-
# snapshot = (output_images[0], [])
|
| 406 |
-
# if seg_mask_2d is not None:
|
| 407 |
-
# snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
|
| 408 |
-
# if seg_mask_3d is not None:
|
| 409 |
-
# snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
|
| 410 |
-
# else:
|
| 411 |
-
# output_images = input_images.copy()
|
| 412 |
-
# snapshot = (output_images[0], [])
|
| 413 |
-
|
| 414 |
-
# my_dict['image'] = output_images
|
| 415 |
-
# my_dict['mask'] = None
|
| 416 |
-
# if seg_mask_2d is not None:
|
| 417 |
-
# my_dict['mask'] = seg_mask_2d
|
| 418 |
-
# if seg_mask_3d is not None:
|
| 419 |
-
# my_dict['mask'] = seg_mask_3d
|
| 420 |
-
|
| 421 |
-
# return output_text, snapshot, gr.update(maximum=len(output_images)-1)
|
| 422 |
-
|
| 423 |
def render(x):
|
| 424 |
if x > len(my_dict['image'])-1:
|
| 425 |
x = len(my_dict['image'])-1
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import argparse
|
| 3 |
import torch
|
| 4 |
+
from torch import cuda
|
| 5 |
import torch.nn.functional as F
|
| 6 |
import torchvision.transforms.functional as TF
|
| 7 |
from torchvision import transforms
|
|
|
|
| 33 |
args = parser.parse_args()
|
| 34 |
return args
|
| 35 |
|
| 36 |
+
device = 'cuda' if cuda.is_available() else 'cpu'
|
| 37 |
# Launch model
|
| 38 |
args = parse_args()
|
| 39 |
cfg = Config(args)
|
| 40 |
|
| 41 |
model_config = cfg.model_cfg
|
| 42 |
model_cls = registry.get_model_class(model_config.arch)
|
| 43 |
+
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
|
|
|
|
| 44 |
global global_images
|
| 45 |
global_images = None
|
| 46 |
|
|
|
|
| 146 |
seg_feats = model.model_seg_2d.decoder(*feats)
|
| 147 |
seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
|
| 148 |
seg_probs = F.sigmoid(seg_preds)
|
| 149 |
+
seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
|
| 150 |
return seg_mask
|
| 151 |
else:
|
| 152 |
return None
|
|
|
|
| 165 |
new_img_embeds_list[-1] = last_feats
|
| 166 |
seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
|
| 167 |
seg_probs = F.sigmoid(seg_preds)
|
| 168 |
+
seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
|
| 169 |
return seg_mask
|
| 170 |
|
| 171 |
def task_det_2d(model, preds, hidden_states):
|
|
|
|
| 175 |
if target_states:
|
| 176 |
target_states = torch.cat(target_states).squeeze()
|
| 177 |
det_states = model.text_det(target_states).detach().cpu()
|
| 178 |
+
return det_states.to(torch.float32).numpy()
|
| 179 |
return torch.zeros_like(indices)
|
| 180 |
|
| 181 |
class StoppingCriteriaSub(StoppingCriteria):
|
|
|
|
| 240 |
transforms.ToTensor(),
|
| 241 |
transforms.Normalize(mean, std)
|
| 242 |
])
|
| 243 |
+
image = transform(image).type(torch.bfloat16).unsqueeze(0)
|
| 244 |
return image
|
| 245 |
|
| 246 |
def load_and_preprocess_volume(image):
|
|
|
|
| 249 |
transform = tio.Compose([
|
| 250 |
tio.ZNormalization(masking_method=tio.ZNormalization.mean),
|
| 251 |
])
|
| 252 |
+
image = transform(image.unsqueeze(0)).type(torch.bfloat16)
|
| 253 |
return image
|
| 254 |
|
| 255 |
def read_image(image_path):
|
|
|
|
| 328 |
def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
|
| 329 |
num_imgs = len(images)
|
| 330 |
modal = modality.lower()
|
| 331 |
+
image_tensors = [read_image(img).to(device) for img in images]
|
| 332 |
if modality == 'ct':
|
| 333 |
time.sleep(2)
|
| 334 |
else:
|
| 335 |
time.sleep(1)
|
| 336 |
image_tensor = torch.cat(image_tensors)
|
| 337 |
|
| 338 |
+
with torch.autocast(device):
|
| 339 |
with torch.no_grad():
|
| 340 |
generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 341 |
|
|
|
|
| 388 |
|
| 389 |
return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
def render(x):
|
| 392 |
if x > len(my_dict['image'])-1:
|
| 393 |
x = len(my_dict['image'])-1
|
inference.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
from utils import *
|
|
|
|
| 2 |
|
| 3 |
# --- Launch Model ---
|
| 4 |
-
device = 'cuda'
|
| 5 |
model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
|
| 6 |
-
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device)
|
| 7 |
-
model.eval()
|
| 8 |
|
| 9 |
# --- Define examples ---
|
| 10 |
examples = [
|
|
@@ -85,14 +85,14 @@ temperature = 0.1
|
|
| 85 |
index = 0
|
| 86 |
demo_ex = examples[index]
|
| 87 |
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
| 88 |
-
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 89 |
print(output_text)
|
| 90 |
|
| 91 |
# --- Segment the lesion in the dermatology image ---
|
| 92 |
index = 6
|
| 93 |
demo_ex = examples[index]
|
| 94 |
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
| 95 |
-
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 96 |
print(output_text)
|
| 97 |
print(seg_mask_2d[0].shape) # H, W
|
| 98 |
|
|
@@ -100,7 +100,7 @@ print(seg_mask_2d[0].shape) # H, W
|
|
| 100 |
index = -2
|
| 101 |
demo_ex = examples[index]
|
| 102 |
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
| 103 |
-
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 104 |
print(output_text)
|
| 105 |
print(len(seg_mask_3d)) # Number of slices
|
| 106 |
print(seg_mask_3d[0].shape) # H, W
|
|
|
|
| 1 |
from utils import *
|
| 2 |
+
from torch import cuda
|
| 3 |
|
| 4 |
# --- Launch Model ---
|
| 5 |
+
device = 'cuda' if cuda.is_available() else 'cpu'
|
| 6 |
model_cls = registry.get_model_class('medomni') # medomni is the architecture name :)
|
| 7 |
+
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
|
|
|
|
| 8 |
|
| 9 |
# --- Define examples ---
|
| 10 |
examples = [
|
|
|
|
| 85 |
index = 0
|
| 86 |
demo_ex = examples[index]
|
| 87 |
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
| 88 |
+
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
|
| 89 |
print(output_text)
|
| 90 |
|
| 91 |
# --- Segment the lesion in the dermatology image ---
|
| 92 |
index = 6
|
| 93 |
demo_ex = examples[index]
|
| 94 |
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
| 95 |
+
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
|
| 96 |
print(output_text)
|
| 97 |
print(seg_mask_2d[0].shape) # H, W
|
| 98 |
|
|
|
|
| 100 |
index = -2
|
| 101 |
demo_ex = examples[index]
|
| 102 |
images, context, prompt, modality, task = demo_ex[0], demo_ex[1], demo_ex[2], demo_ex[3], demo_ex[4]
|
| 103 |
+
seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
|
| 104 |
print(output_text)
|
| 105 |
print(len(seg_mask_3d)) # Number of slices
|
| 106 |
print(seg_mask_3d[0].shape) # H, W
|
utils.py
CHANGED
|
@@ -133,7 +133,7 @@ def task_seg_2d(model, preds, hidden_states, image):
|
|
| 133 |
seg_feats = model.model_seg_2d.decoder(*feats)
|
| 134 |
seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
|
| 135 |
seg_probs = F.sigmoid(seg_preds)
|
| 136 |
-
seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
|
| 137 |
return seg_mask
|
| 138 |
else:
|
| 139 |
return None
|
|
@@ -152,7 +152,7 @@ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
|
|
| 152 |
new_img_embeds_list[-1] = last_feats
|
| 153 |
seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
|
| 154 |
seg_probs = F.sigmoid(seg_preds)
|
| 155 |
-
seg_mask = seg_probs.cpu().squeeze().numpy() >= 0.5
|
| 156 |
return seg_mask
|
| 157 |
|
| 158 |
def task_det_2d(model, preds, hidden_states):
|
|
@@ -227,7 +227,7 @@ def load_and_preprocess_image(image):
|
|
| 227 |
transforms.ToTensor(),
|
| 228 |
transforms.Normalize(mean, std)
|
| 229 |
])
|
| 230 |
-
image = transform(image).type(torch.bfloat16).cuda().unsqueeze(0)
|
| 231 |
return image
|
| 232 |
|
| 233 |
def load_and_preprocess_volume(image):
|
|
@@ -236,7 +236,7 @@ def load_and_preprocess_volume(image):
|
|
| 236 |
transform = tio.Compose([
|
| 237 |
tio.ZNormalization(masking_method=tio.ZNormalization.mean),
|
| 238 |
])
|
| 239 |
-
image = transform(image.unsqueeze(0)).type(torch.bfloat16)
|
| 240 |
return image
|
| 241 |
|
| 242 |
def read_image(image_path):
|
|
@@ -285,7 +285,6 @@ def generate(model, image_path, image, context, modal, task, num_imgs, prompt, n
|
|
| 285 |
seg_mask = task_seg_2d(model, preds, hidden_states, image)
|
| 286 |
output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
|
| 287 |
if sum(preds == model.seg_token_idx_3d):
|
| 288 |
-
ipdb.set_trace()
|
| 289 |
seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
|
| 290 |
output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
|
| 291 |
if sum(preds == model.det_token_idx):
|
|
@@ -304,17 +303,17 @@ def generate(model, image_path, image, context, modal, task, num_imgs, prompt, n
|
|
| 304 |
output_text = 'The main diagnosis is melanoma.'
|
| 305 |
return output_image, seg_mask_2d, seg_mask_3d, output_text
|
| 306 |
|
| 307 |
-
def generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
|
| 308 |
num_imgs = len(images)
|
| 309 |
modal = modality.lower()
|
| 310 |
-
image_tensors = [read_image(img) for img in images]
|
| 311 |
if modality == 'ct':
|
| 312 |
time.sleep(2)
|
| 313 |
else:
|
| 314 |
time.sleep(1)
|
| 315 |
image_tensor = torch.cat(image_tensors)
|
| 316 |
|
| 317 |
-
with torch.autocast('cuda'):
|
| 318 |
with torch.no_grad():
|
| 319 |
generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(model, images, image_tensor, context, modal, task, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 320 |
|
|
|
|
| 133 |
seg_feats = model.model_seg_2d.decoder(*feats)
|
| 134 |
seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
|
| 135 |
seg_probs = F.sigmoid(seg_preds)
|
| 136 |
+
seg_mask = seg_probs.to(dtype=torch.float32).cpu().squeeze().numpy() >= 0.5
|
| 137 |
return seg_mask
|
| 138 |
else:
|
| 139 |
return None
|
|
|
|
| 152 |
new_img_embeds_list[-1] = last_feats
|
| 153 |
seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
|
| 154 |
seg_probs = F.sigmoid(seg_preds)
|
| 155 |
+
seg_mask = seg_probs.to(dtype=torch.float32).cpu().squeeze().numpy() >= 0.5
|
| 156 |
return seg_mask
|
| 157 |
|
| 158 |
def task_det_2d(model, preds, hidden_states):
|
|
|
|
| 227 |
transforms.ToTensor(),
|
| 228 |
transforms.Normalize(mean, std)
|
| 229 |
])
|
| 230 |
+
image = transform(image).type(torch.bfloat16).unsqueeze(0)
|
| 231 |
return image
|
| 232 |
|
| 233 |
def load_and_preprocess_volume(image):
|
|
|
|
| 236 |
transform = tio.Compose([
|
| 237 |
tio.ZNormalization(masking_method=tio.ZNormalization.mean),
|
| 238 |
])
|
| 239 |
+
image = transform(image.unsqueeze(0)).type(torch.bfloat16)
|
| 240 |
return image
|
| 241 |
|
| 242 |
def read_image(image_path):
|
|
|
|
| 285 |
seg_mask = task_seg_2d(model, preds, hidden_states, image)
|
| 286 |
output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
|
| 287 |
if sum(preds == model.seg_token_idx_3d):
|
|
|
|
| 288 |
seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
|
| 289 |
output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
|
| 290 |
if sum(preds == model.det_token_idx):
|
|
|
|
| 303 |
output_text = 'The main diagnosis is melanoma.'
|
| 304 |
return output_image, seg_mask_2d, seg_mask_3d, output_text
|
| 305 |
|
| 306 |
+
def generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device):
|
| 307 |
num_imgs = len(images)
|
| 308 |
modal = modality.lower()
|
| 309 |
+
image_tensors = [read_image(img).to(device) for img in images]
|
| 310 |
if modality == 'ct':
|
| 311 |
time.sleep(2)
|
| 312 |
else:
|
| 313 |
time.sleep(1)
|
| 314 |
image_tensor = torch.cat(image_tensors)
|
| 315 |
|
| 316 |
+
with torch.autocast(device):
|
| 317 |
with torch.no_grad():
|
| 318 |
generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(model, images, image_tensor, context, modal, task, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
|
| 319 |
|