Spaces:
Sleeping
Sleeping
| import random | |
| import time | |
| from functools import partial | |
| from typing import List | |
| import deepinv as dinv | |
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric | |
### Config
# Device string handed to the models, physics operators, datasets and tensors of this demo:
# run model inference on NVIDIA gpu if available, otherwise fall back to the CPU.
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
| ### Gradio Utils | |
def resize_tensor_within_box(tensor_img: torch.Tensor, max_size: int = 512):
    """Downscale a batched image so that it fits inside a ``max_size`` x ``max_size`` box.

    The aspect ratio is preserved up to integer truncation. An image that already fits
    is returned unchanged; images are never upscaled.

    Args:
        tensor_img: 4-D tensor whose last two dimensions are read as height and width.
        max_size: side length, in pixels, of the bounding box.

    Returns:
        ``tensor_img`` itself when no resizing is needed, otherwise the resized tensor.
    """
    _, _, h, w = tensor_img.shape
    scale = min(max_size / h, max_size / w)
    if scale >= 1.0:
        return tensor_img
    # Clamp to one pixel: truncation sends the short side of a very elongated image
    # (e.g. 1 x 600) to 0, which is not a valid target size for the resize.
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    return transforms.functional.resize(tensor_img, [new_h, new_w], antialias=True)
def generate_imgs_from_user(image,
                            physics: PhysicsWithGenerator,  # use_gen: bool,
                            baseline: BaselineModel, model: EvalModel,
                            metrics: List[Metric]):
    """Run the degradation + reconstruction pipeline on a user-supplied image.

    Args:
        image: image accepted by ``transforms.ToTensor`` (described as a PIL image by the
            original author), or ``None`` when the user supplied nothing.
        physics: physics wrapper; its ``name`` decides the channel conversion below.
        baseline: baseline reconstruction model, forwarded to ``generate_imgs``.
        model: evaluated reconstruction model, forwarded to ``generate_imgs``.
        metrics: metrics forwarded to ``generate_imgs``.

    Returns:
        A tuple of eight ``None`` when ``image`` is ``None``; otherwise the 8-tuple
        returned by ``generate_imgs`` (always called with ``use_gen=True``).
    """
    # Happens when user image is missing: one None per output slot
    if image is None:
        return (None,) * 8

    # image -> torch.Tensor / (1, C, H, W) / move to DEVICE_STR
    to_tensor = transforms.ToTensor()
    x = to_tensor(image).unsqueeze(0).to(DEVICE_STR)

    # Resize img within a 512x512 box
    x = resize_tensor_within_box(x)

    # 3-channel inputs are converted to grayscale for the CT and MRI physics
    if x.shape[1] == 3 and physics.name in ('CT', 'MRI'):
        x = transforms.Grayscale(num_output_channels=1)(x)
        if physics.name == 'MRI':
            # not working because MRI physics has a fixed img size
            # MRI gets a second, all-zero channel appended
            x = torch.cat((x, torch.zeros_like(x)), dim=1)

    return generate_imgs(x, physics, True, baseline, model, metrics)
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
                               physics: PhysicsWithGenerator,  # use_gen: bool,
                               baseline: BaselineModel, model: EvalModel,
                               metrics: List[Metric]):
    """Run the degradation + reconstruction pipeline on ``dataset[idx]``.

    Args:
        dataset: indexable dataset; ``dataset[idx]`` must provide ``unsqueeze``
            (the original comment gives its shape as (C, H, W)).
        idx: index of the sample to load.
        physics, baseline, model, metrics: forwarded unchanged to ``generate_imgs``.

    Returns:
        The 8-tuple returned by ``generate_imgs`` (always called with ``use_gen=True``).
    """
    ### Load 1 image and add a leading batch dimension: (C, H, W) -> (1, C, H, W)
    sample = dataset[idx]
    batch = sample.unsqueeze(0)
    return generate_imgs(x=batch, physics=physics, use_gen=True,
                         baseline=baseline, model=model, metrics=metrics)
def generate_random_imgs_from_dataset(dataset: EvalDataset,
                                      physics: PhysicsWithGenerator,
                                      # use_gen: bool,
                                      baseline: BaselineModel,
                                      model: EvalModel,
                                      metrics: List[Metric]):
    """Run the pipeline on a uniformly drawn sample of ``dataset``.

    Returns:
        The drawn index followed by the eight values returned by
        ``generate_imgs_from_dataset`` for that index (nine values in total).
    """
    # randint is inclusive on both ends, hence the -1
    idx = random.randint(0, len(dataset) - 1)
    outputs = generate_imgs_from_dataset(dataset, idx, physics, baseline, model, metrics)
    return (idx, *outputs)
def _log_cuda_memory(stage: str) -> None:
    """Print current and peak CUDA memory statistics (in MB), each line tagged with ``stage``."""
    print(f"[{stage}] CUDA current allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"[{stage}] CUDA current reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    print(f"[{stage}] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
    print(f"[{stage}] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")


def _timed_reconstruction(reconstructor, y, physics_op, sync_cuda: bool):
    """Call ``reconstructor(y=y, physics=physics_op)`` without gradients; return ``(output, seconds)``."""
    if sync_cuda:
        # Drain previously queued kernels so that they are not billed to this call.
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        out = reconstructor(y=y, physics=physics_op)
    if sync_cuda:
        # CUDA kernels are launched asynchronously: wait for them before stopping the clock.
        torch.cuda.synchronize()
    return out, time.perf_counter() - start


def generate_imgs(x: torch.Tensor,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  baseline: BaselineModel, model: EvalModel,
                  metrics: List[Metric]):
    """Degrade ``x`` with ``physics``, reconstruct it with ``model`` and ``baseline``, and format the results.

    Args:
        x: ground-truth image batch; indexed as a 4-D tensor (the callers pass (1, C, H, W)).
        physics: wrapper exposing ``physics``, ``physics_generator``, ``generator``, ``name`` and
            ``display_saved_params()``; calling it as ``physics(x, use_gen)`` yields the observation.
        use_gen: forwarded as the second argument of ``physics(...)``.
        baseline: baseline model, called as ``baseline(y=..., physics=...)``.
        model: evaluated model, called as ``model(y=..., physics=...)``.
        metrics: objects with a ``name`` attribute, callable as ``metric(estimate, reference)``.

    Returns:
        An 8-tuple: PIL images of the ground truth, the observation, the model output and the
        baseline output; the string returned by ``physics.display_saved_params()``; and three
        metric strings (observation, model output, baseline output). The two output strings
        end with the measured inference time.
    """
    _log_cuda_memory("Before inference")

    # Tell the physics (and its generators) the size of the incoming image
    img_size = x.shape[1:]
    if hasattr(physics.physics, 'tensor_size'):
        physics.physics.tensor_size = img_size
    elif hasattr(physics.physics, 'imsize'):
        physics.physics.imsize = img_size

    if physics.physics_generator is not None:  # we only change physic params but not noise levels
        if hasattr(physics.physics_generator, 'tensor_size'):
            physics.physics_generator.tensor_size = img_size
            physics.generator.tensor_size = img_size
        if hasattr(physics.physics_generator, 'imsize'):
            physics.physics_generator.imsize = img_size
            physics.generator.imsize = img_size

    ### Compute y
    with torch.no_grad():
        y = physics(x, use_gen)  # possible reduction in img shape due to Blurring

    ### Compute x_hat from RAM & DPIR, timing each call separately
    sync_cuda = x.is_cuda
    out, ram_time = _timed_reconstruction(model, y, physics.physics, sync_cuda)
    out_baseline, dpir_time = _timed_reconstruction(baseline, y, physics.physics, sync_cuda)

    ### Process tensors before metric computation
    if "Blur" in physics.name:
        # Centre-crop the reference and the outputs to the spatial size of y
        top, bottom = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
        left, right = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
        x = x[..., top:bottom, left:right]
        out = out[..., top:bottom, left:right]
        if out_baseline.shape != out.shape:
            out_baseline = out_baseline[..., top:bottom, left:right]

    ### Process y when y shape is different from x shape
    if physics.name == 'MRI' or physics.name == 'CT':
        y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
    else:
        y_plot = y.clone()

    ### Metrics: one "<name> = <value>" line per metric
    lines_y, lines_out, lines_baseline = [], [], []
    for metric in metrics:
        lines_y.append(f"{metric.name} = {metric(y_plot, x).item():.4f}" + "\n")
        lines_out.append(f"{metric.name} = {metric(out, x).item():.4f}" + "\n")
        lines_baseline.append(f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n")
    metrics_y = "".join(lines_y)
    metrics_out = "".join(lines_out) + f"Inference time = {ram_time:.3f}s"
    metrics_out_baseline = "".join(lines_baseline) + f"Inference time = {dpir_time:.3f}s"

    ### Processing images for plotting :
    #     - clip value outside of [0,1]
    #     - shape (1, C, H, W) -> (C, H, W)
    #     - torch.Tensor object -> Pil object
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    to_pil = transforms.ToPILImage()
    x_pil = to_pil(process_img(x)[0].to('cpu'))
    y_pil = to_pil(process_img(y_plot)[0].to('cpu'))
    out_pil = to_pil(process_img(out)[0].to('cpu'))
    out_baseline_pil = to_pil(process_img(out_baseline)[0].to('cpu'))

    ### Free memory
    del x, y, out, out_baseline, y_plot
    torch.cuda.empty_cache()

    _log_cuda_memory("After inference")

    return (x_pil, y_pil, out_pil, out_baseline_pil, physics.display_saved_params(),
            metrics_y, metrics_out, metrics_out_baseline)
# Factory shortcuts with `device_str` pre-bound to DEVICE_STR, so that callers
# only have to pass the remaining arguments.
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
def get_dataset(dataset_name):
    """Build the dataset called ``dataset_name`` together with its default physics and baseline.

    'MRI' and 'CT' each allow a single physics of the same name and use the 'DPIR_MRI' /
    'DPIR_CT' baseline. Any other name gets the inpainting / SR / blur physics list, with
    'MotionBlur_hard' selected and the 'DPIR' baseline.

    Returns:
        ``(dataset, idx, physics, baseline, available_physics)`` where ``idx`` is always 0.
    """
    if dataset_name in ('MRI', 'CT'):
        # Medical datasets: physics and dataset share the same name
        physics_name = dataset_name
        available_physics = [dataset_name]
        baseline_name = 'DPIR_' + dataset_name
    else:
        physics_name = 'MotionBlur_hard'
        available_physics = ['Inpainting', 'SR', 'MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
        baseline_name = 'DPIR'

    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    return dataset, 0, physics, baseline, available_physics
# global variables shared by all users
# The evaluated model is instantiated once, at import time, and put in eval mode.
ram_model = EvalModel(device_str=DEVICE_STR)
ram_model.eval()
# Metrics requested by name; only "PSNR" is used here.
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)
# Pre-bind the shared model and metrics so that the UI callbacks only supply the remaining arguments.
generate_imgs_from_user_partial = partial(generate_imgs_from_user, model=ram_model, metrics=psnr)
generate_imgs_from_dataset_partial = partial(generate_imgs_from_dataset, model=ram_model, metrics=psnr)
generate_random_imgs_from_dataset_partial = partial(generate_random_imgs_from_dataset, model=ram_model, metrics=psnr)
### Gradio Blocks interface
print(f"[Init] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
print(f"[Init] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")

title = "Reconstruct Anything Model Demo"  # displayed on gradio tab and in the gradio page
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)
    gr.Markdown(
        """
        This demo showcases the performance of the **Reconstruct Anything Model (RAM)** across a variety of inverse problems on both natural and MRI images.
        Select a dataset and a physics task below (e.g., inpainting, super-resolution, deblurring...).
        Note: The parameters of the selected physics — such as noise levels, blur kernels, or inpainting masks — are randomly generated before reconstruction, leveraging the [deepinverse library](https://deepinv.github.io/deepinv/).
        For more details on the method, check out our [paper on arXiv](https://arxiv.org/abs/2503.08915).
        """
    )

    ### USER-SPECIFIC VARIABLES
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    available_physics_placeholder = gr.State(['Inpainting', 'SR', 'MotionBlur_medium', 'MotionBlur_hard',
                                              'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
    # Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
    # Solution: using lambda expression
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_hard"))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))

    print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
    print(f"[Render] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")

    # Without this decorator the function below is never called, so no component and no
    # event listener would ever be created. `gr.render` builds the layout from the current
    # State values and rebuilds it when one of them changes; the `key=` arguments identify
    # the components across rebuilds.
    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
    def dynamic_layout(dataset, physics, available_physics):
        """Build the widgets and event listeners for the current dataset / physics selection."""
        ### LAYOUT
        # Display images
        with gr.Row():
            gt_img = gr.Image(label="Ground-truth image", interactive=True, key='gt_img')
            observed_img = gr.Image(label="Observed image", interactive=False, key='observed_img')
            model_a_out = gr.Image(label="RAM output", interactive=False, key='ram_out')
            model_b_out = gr.Image(label="DPIR output", interactive=False, key='dpir_out')

        # Manage datasets and display metric values
        # NOTE(review): the nesting of the rows/columns below is inferred (four columns, one
        # under each image) — confirm against the intended layout.
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                run_button = gr.Button("Demo on above image", size='md')
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset", size='md')
                    load_random_button = gr.Button("Run on random image from dataset", size='md')
            with gr.Column(scale=1, min_width=160):
                observed_metrics = gr.Textbox(label="Observed metric", lines=2, key='metrics')
            with gr.Column(scale=1, min_width=160):
                out_a_metric = gr.Textbox(label="RAM output metrics", lines=2, key='ram_metrics')
            with gr.Column(scale=1, min_width=160):
                out_b_metric = gr.Textbox(label="DPIR output metrics", lines=2, key='dpir_metrics')

        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics,
                                          label="Physics",
                                          value=physics.name)
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                          label="Datasets",
                                          value=dataset.name)
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1, label="Sample index",
                                       key='idx_slider')

            # with gr.Column(scale=1):
            #     with gr.Row():
            #         key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
            #                                    label="Updatable Key")
            #         value_text = gr.Textbox(label="Update Value")
            #     update_button = gr.Button("Manually update parameter value", size='md')

            with gr.Column(scale=1):
                physics_params = gr.Textbox(label="Physics parameters",
                                            lines=5,
                                            value=physics.display_saved_params())

        ### Event listeners
        # Switching dataset replaces the dataset, the physics, the baseline and the physics list
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder, model_b_placeholder, available_physics_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
        # update_button.click(fn=physics.update_and_display_params,
        #                     inputs=[key_selector, value_text], outputs=physics_params)
        run_button.click(fn=generate_imgs_from_user_partial,
                         inputs=[gt_img,
                                 physics_placeholder,
                                 # use_generator_button,
                                 model_b_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset_partial,
                          inputs=[dataset_placeholder,
                                  idx_slider,
                                  physics_placeholder,
                                  # use_generator_button,
                                  model_b_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=generate_random_imgs_from_dataset_partial,
                                 inputs=[dataset_placeholder,
                                         physics_placeholder,
                                         # use_generator_button,
                                         model_b_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])

interface.launch()