Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import PIL.Image | |
| from pathlib import Path | |
| import pandas as pd | |
| from diffusers.pipelines import StableDiffusionPipeline | |
| import torch | |
| import argparse | |
| import os | |
| import warnings | |
| from safetensors.torch import load_file | |
| import yaml | |
# Silence all Python warnings process-wide (this also hides library
# deprecation warnings).
warnings.filterwarnings("ignore")
################################################################################
# Define the default parameters
OUTPUT_DIR = "OUTPUT"
# GPU index for the module-level ``device`` string. Index 0 is the only GPU
# present on a single-GPU host and matches the ``device="0"`` default used by
# ``predict``; "cuda:1" fails with an invalid-device error on such hosts.
cuda_device = 0
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"
# Title passed to the Gradio interface.
TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"
# Help texts passed as ``info=`` to the corresponding input widgets.
INFO_ABOUT_TEXT_PROMPT = "Text prompt for generating the X-Ray"
INFO_ABOUT_GUIDANCE_SCALE = "Guidance Scale determines the strength of the guidance signal"
INFO_ABOUT_INFERENCE_STEPS = "Number of inference steps to use for generating the X-ray"
# Prompts offered in the "Input Text" dropdown; the first one is its default.
EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]
| ################################################################################ | |
def load_adapted_unet(unet_pretraining_type, pipe):
    """
    Patch ``pipe.unet`` in place with the weights of one fine-tuning strategy.

    Parameters:
        unet_pretraining_type (str): Strategy name. "freeze" leaves the U-Net
            untouched; "svdiff" and "lorav2" use dedicated loaders; any other
            name is loaded from
            ``<name>_diffusion_pytorch_model.safetensors`` in the working
            directory.
        pipe (StableDiffusionPipeline): Pipeline whose U-Net is modified.
    Returns:
        None

    NOTE(review): ``load_unet_for_svdiff`` is not imported anywhere in this
    file, so the "svdiff" branch raises NameError unless it is provided
    elsewhere.
    """
    base_model_id = "runwayml/stable-diffusion-v1-5"
    weights_root = ""

    if unet_pretraining_type == "freeze":
        # Frozen baseline: keep the pretrained U-Net exactly as loaded.
        return

    if unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")
        shifts_path = os.path.join(
            os.path.join(weights_root, "unet"), "spectral_shifts.safetensors"
        )
        pipe.unet = load_unet_for_svdiff(
            base_model_id,
            spectral_shifts_ckpt=shifts_path,
            subfolder="unet",
        )
        # Call perform_svd() once on every submodule that provides it.
        svd_modules = (
            layer for layer in pipe.unet.modules() if hasattr(layer, "perform_svd")
        )
        for svd_layer in svd_modules:
            svd_layer.perform_svd()
        return

    if unet_pretraining_type == "lorav2":
        pipe.unet.load_attn_procs(
            os.path.join(weights_root, "pytorch_lora_weights.safetensors")
        )
        return

    # Remaining strategies ship a U-Net state dict. With strict=False, keys
    # missing from the file keep their current (pretrained) values; the
    # missing/unexpected key report is printed for inspection.
    checkpoint_name = unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors"
    tuned_weights = load_file(checkpoint_name)
    load_report = pipe.unet.load_state_dict(tuned_weights, strict=False)
    print(load_report)
def loadSDModel(unet_pretraining_type, cuda_device):
    """
    Build a Stable Diffusion v1.5 pipeline carrying the requested fine-tuned U-Net.

    Parameters:
        unet_pretraining_type (str): Fine-tuning strategy forwarded to
            ``load_adapted_unet``.
        cuda_device (str): Accepted for call compatibility but not read here;
            the callers in this file move the returned pipeline to a device
            themselves.
    Returns:
        pipe (StableDiffusionPipeline): Pipeline with the adapted U-Net and
        its safety checker removed.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", revision="fp16"
    )
    load_adapted_unet(unet_pretraining_type, pipe)
    # Drop the safety checker so the pipeline does not post-filter its images.
    pipe.safety_checker = None
    return pipe
# Tunable-parameter counts shown in the bar plot, keyed by fine-tuning strategy.
# NOTE(review): the unit of these numbers is not stated anywhere in this file.
_NUM_TUNABLE_PARAMS = {
    "full": 86,
    "attention": 26.7,
    "bias": 0.343,
    "norm": 0.2,
    "norm_bias_attention": 26.7,
    "lorav2": 0.8,
    "svdiff": 0.222,
    "difffit": 0.581,
}


def _generate_xray(
    unet_pretraining_type, input_text, guidance_scale, num_inference_steps, device
):
    """
    Load the pipeline for one fine-tuning strategy and generate a single X-ray.

    Parameters:
        unet_pretraining_type (str): Fine-tuning strategy passed to loadSDModel.
        input_text (str): Text prompt.
        guidance_scale (int): Guidance scale passed to the pipeline.
        num_inference_steps (int): Number of inference steps.
        device (str): CUDA device index, used as ``cuda:{device}`` when CUDA is
            available; the pipeline runs on "cpu" otherwise.
    Returns:
        PIL.Image: The first generated image (224x224).
    """
    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
    print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
    sd_pipeline = loadSDModel(
        unet_pretraining_type=unet_pretraining_type,
        cuda_device=cuda_device,
    )
    sd_pipeline.to(cuda_device)
    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    return result_image["images"][0]


def _tunable_params_bar_plot(unet_pretraining_type):
    """Return a bar plot comparing "full" with the selected strategy's tunable parameters."""
    df = pd.DataFrame(
        {
            "Fine-Tuning Strategy": list(_NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(_NUM_TUNABLE_PARAMS.values()),
        }
    )
    print(df)
    # Keep only the full fine-tuning baseline and the selected strategy.
    df = df[
        df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])
    ].reset_index(drop=True)
    return gr.BarPlot(
        value=df,
        x="Fine-Tuning Strategy",
        y="Number of Tunable Parameters",
        title="Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type),
        vertical=True,
        height=300,
        width=300,
        interactive=True,
    )


def _run_pipeline_and_plot(
    unet_pretraining_type, input_text, guidance_scale, num_inference_steps, device
):
    """Generate the X-ray, then build the parameter bar plot; return both."""
    result_pil_image = _generate_xray(
        unet_pretraining_type, input_text, guidance_scale, num_inference_steps, device
    )
    bar_plot = _tunable_params_bar_plot(unet_pretraining_type)
    return result_pil_image, bar_plot


def _predict_using_default_params():
    """
    Generate an X-ray and bar plot with fixed default settings.

    Used by ``predict`` as a fallback when generation with the user's inputs
    fails. Settings: "full" fine-tuning, prompt
    "No acute cardiopulmonary abnormality.", guidance scale 4, 75 inference
    steps, CUDA device "0".

    Returns:
        tuple: (generated X-ray as PIL.Image, gr.BarPlot of tunable parameters)
    """
    return _run_pipeline_and_plot(
        unet_pretraining_type="full",
        input_text="No acute cardiopulmonary abnormality.",
        guidance_scale=4,
        num_inference_steps=75,
        device="0",
    )


def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
):
    """
    Generate a Chest X-ray for the selected PEFT type, prompt, guidance scale, and step count.

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        input_text (str): The text prompt to use for generating the X-ray
        guidance_scale (int): The guidance scale to use for generating the X-ray
        num_inference_steps (int): The number of inference steps to use for generating the X-ray
        device (str): CUDA device index, used as ``cuda:{device}`` when CUDA is
            available; ignored on CPU-only hosts.
        OUTPUT_DIR (str): Not used; kept so existing callers keep working
            (nothing is written to disk).
    Returns:
        result_pil_image (PIL.Image): The generated X-ray image
        bar_plot (gr.BarPlot): The number of tunable parameters for the selected PEFT Type

    If generation with the given inputs raises an ``Exception``, the error is
    printed and the output of ``_predict_using_default_params()`` is returned
    instead. If that also fails, its exception propagates.
    """
    try:
        return _run_pipeline_and_plot(
            unet_pretraining_type=unet_pretraining_type,
            input_text=input_text,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            device=device,
        )
    except Exception:
        # Best-effort fallback: report the failure instead of swallowing it.
        # KeyboardInterrupt/SystemExit are not Exception subclasses and still
        # propagate.
        import traceback

        print(
            "Generation failed for PEFT type {!r}; falling back to default "
            "parameters.".format(unet_pretraining_type)
        )
        traceback.print_exc()
        return _predict_using_default_params()
# Create a Gradio interface
#
# Input Parameters:
#   1. PEFT Type: (Dropdown) The type of PEFT to use for generating the X-ray
#   2. Input Text: (Dropdown) The text prompt to use for generating the X-ray
#   3. Guidance Scale: (Slider) The guidance scale to use for generating the X-ray
#   4. Num Inference Steps: (Slider) The number of inference steps to use for generating the X-ray
# Output Parameters:
#   1. Generated X-ray Image: (Image) The generated X-ray image
#   2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT Type
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            ["full", "difffit", "norm", "bias", "attention", "norm_bias_attention"],
            value="full",
            label="PEFT Type",
        ),
        gr.Dropdown(
            EXAMPLE_TEXT_PROMPTS,
            label="Input Text",
            info=INFO_ABOUT_TEXT_PROMPT,
            value=EXAMPLE_TEXT_PROMPTS[0],
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=4,
            step=1,
            info=INFO_ABOUT_GUIDANCE_SCALE,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=75,
            step=1,
            info=INFO_ABOUT_INFERENCE_STEPS,
            label="Num Inference Steps",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    # NOTE(review): live=True re-runs predict on input changes, and every run
    # reloads the whole Stable Diffusion pipeline before generating; consider
    # live=False (explicit Submit) for a model this heavy.
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

if __name__ == "__main__":
    # Launch the Gradio interface only when run as a script, so importing this
    # module does not start a server. NOTE(review): Gradio does not create
    # share links when hosted on Hugging Face Spaces; share=True matters only
    # for local runs.
    iface.launch(share=True)