| | import argparse |
| | import os |
| | import sys |
| | from pathlib import Path |
| | import torch |
| | from diffusers import ( |
| | StableDiffusionPipeline, |
| | UNet2DConditionModel, |
| | AutoencoderKL, |
| | DDPMScheduler, |
| | ) |
| | from transformers import CLIPTextModel, CLIPTokenizer |
| | from PIL import Image |
| |
|
| | |
| | |
# Import the project-local Mamba helpers. Nothing below can work without them,
# so fail fast with an actionable message instead of a bare traceback.
try:
    from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba
    print("Successfully imported Mamba utilities from msd_utils.py")
except ImportError as e:
    # Most likely cause: the script is being run from a different directory.
    print(f"ERROR: Could not import from msd_utils.py. Make sure it's in the same directory.")
    print(f"Import Error: {e}")
    sys.exit(1)
except Exception as e:
    # msd_utils.py itself raised during import (e.g. a missing mamba dependency).
    print(f"ERROR: An unexpected error occurred while importing msd_utils.py: {e}")
    sys.exit(1)
| | |
| |
|
def parse_args(argv=None):
    """Parse command-line arguments for Mamba-UNet Stable Diffusion inference.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default, and the original behavior) argparse reads ``sys.argv[1:]``.
            Passing an explicit list makes the parser testable/reusable.

    Returns:
        argparse.Namespace with checkpoint, generation, and Mamba settings.
    """
    parser = argparse.ArgumentParser(description="Generate images using a fine-tuned Stable Diffusion Mamba UNet checkpoint.")
    parser.add_argument(
        "--base_model", type=str, default="runwayml/stable-diffusion-v1-5",
        help="Path or Hub ID of the base Stable Diffusion model used for training (e.g., 'runwayml/stable-diffusion-v1-5')."
    )
    parser.add_argument(
        "--checkpoint_dir", type=str, required=True,
        help="Path to the specific checkpoint directory (e.g., 'sd-mamba-mscoco-urltext-5k-run1/checkpoint-5000')."
    )
    parser.add_argument(
        "--unet_subfolder", type=str, default="unet_mamba",
        help="Name of the subfolder within the checkpoint directory containing the saved UNet weights."
    )
    parser.add_argument(
        "--prompt", type=str, default="a garden",
        help="Text prompt for image generation."
    )
    parser.add_argument(
        "--output_path", type=str, default="generated_image_mamba.png",
        help="Path to save the generated image."
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for generation ('cuda' or 'cpu')."
    )
    parser.add_argument(
        "--seed", type=int, default=12345,
        help="Optional random seed for reproducibility."
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=30,
        help="Number of denoising steps."
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.5,
        help="Scale for classifier-free guidance."
    )

    parser.add_argument(
        "--width", type=int, default=512,
        help="Width of the generated image."
    )
    parser.add_argument(
        "--height", type=int, default=512,
        help="Height of the generated image."
    )

    # BUGFIX: these three options previously declared both a default AND
    # required=True. argparse ignores `default` for required options, so the
    # documented defaults were dead code and callers were forced to supply
    # values that already had sensible defaults. Dropping required=True makes
    # the defaults effective; existing invocations that pass the flags
    # explicitly are unaffected.
    parser.add_argument(
        "--mamba_d_state", type=int, default=16,
        help="Mamba ssm state dimension used during training."
    )
    parser.add_argument(
        "--mamba_d_conv", type=int, default=4,
        help="Mamba ssm convolution dimension used during training."
    )
    parser.add_argument(
        "--mamba_expand", type=int, default=2,
        help="Mamba ssm expansion factor used during training."
    )

    parser.add_argument(
        "--pipeline_dtype", type=str, default="float32", choices=["float32", "float16"],
        help="Run pipeline inference in float32 or float16. float32 is generally more stable."
    )

    args = parser.parse_args(argv)
    return args
| |
|
def main():
    """Entry point: rebuild the Mamba-modified UNet from the base config, load
    fine-tuned weights from a checkpoint, assemble a StableDiffusionPipeline,
    and generate a single image for the given prompt."""
    args = parse_args()

    # Echo the full configuration so runs are easy to reproduce from logs.
    print(f"--- Configuration ---")
    print(f"Base Model: {args.base_model}")
    print(f"Checkpoint Dir: {args.checkpoint_dir}")
    print(f"UNet Subfolder: {args.unet_subfolder}")
    print(f"Prompt: '{args.prompt}'")
    print(f"Output Path: {args.output_path}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")
    print(f"Inference Steps: {args.num_inference_steps}")
    print(f"Guidance Scale: {args.guidance_scale}")
    print(f"Pipeline dtype: {args.pipeline_dtype}")

    print(f"Resolution: {args.width}x{args.height}")

    print(f"Mamba Params: d_state={args.mamba_d_state}, d_conv={args.mamba_d_conv}, expand={args.mamba_expand}")
    print(f"--------------------")

    device = torch.device(args.device)
    # Only two values are possible here because --pipeline_dtype is restricted
    # by argparse choices to "float32" / "float16".
    pipeline_torch_dtype = torch.float32 if args.pipeline_dtype == "float32" else torch.float16

    # Device-local generator for reproducible sampling when a seed is given.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    # These must match the values used during training; otherwise the saved
    # state dict keys/shapes will not line up with the rebuilt Mamba modules
    # and the strict load below will fail.
    mamba_kwargs = {
        'd_state': args.mamba_d_state,
        'd_conv': args.mamba_d_conv,
        'expand': args.mamba_expand,
    }
    print("Prepared Mamba kwargs for UNet replacement.")

    # Load the frozen base components. VAE and text encoder are forced to
    # float32 regardless of --pipeline_dtype.
    # NOTE(review): if the UNet runs in float16 this produces a mixed-dtype
    # pipeline — confirm diffusers handles the dtype boundary as intended.
    print(f"Loading base components from {args.base_model}...")
    try:
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model, subfolder="tokenizer")
        scheduler = DDPMScheduler.from_pretrained(args.base_model, subfolder="scheduler")

        vae = AutoencoderKL.from_pretrained(args.base_model, subfolder="vae", torch_dtype=torch.float32).to(device)
        text_encoder = CLIPTextModel.from_pretrained(args.base_model, subfolder="text_encoder", torch_dtype=torch.float32).to(device)
        print("Base components loaded.")
    except Exception as e:
        print(f"ERROR: Failed to load base components from {args.base_model}. Check path/name.")
        print(f"Error details: {e}")
        sys.exit(1)

    # Build an *untrained* UNet with the base model's architecture; only the
    # config is needed since the fine-tuned weights are loaded further down.
    print(f"Creating UNet structure from {args.base_model} config...")
    try:
        unet_config = UNet2DConditionModel.load_config(args.base_model, subfolder="unet")

        # NOTE(review): from_config may ignore the torch_dtype kwarg in some
        # diffusers versions — verify the instantiated UNet actually carries
        # the requested dtype.
        unet = UNet2DConditionModel.from_config(unet_config, torch_dtype=pipeline_torch_dtype)
        print("Base UNet structure created.")
    except Exception as e:
        print(f"ERROR: Failed to create UNet structure from config {args.base_model}.")
        print(f"Error details: {e}")
        sys.exit(1)

    # Swap self-attention for Mamba blocks so the module tree matches the
    # checkpoint that was saved during training.
    print(f"Replacing UNet Self-Attention with Mamba blocks (using provided parameters)...")
    try:
        unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
        print("UNet structure modified with Mamba blocks.")
    except Exception as e:
        print(f"ERROR: Failed during UNet modification with Mamba blocks.")
        print(f"Error details: {e}")
        sys.exit(1)

    unet_weights_dir = Path(args.checkpoint_dir) / args.unet_subfolder
    print(f"Attempting to load fine-tuned UNet weights from: {unet_weights_dir}")

    # Validate the directory up front to give a targeted error message rather
    # than a generic file-not-found later.
    if not unet_weights_dir.is_dir():
        print(f"ERROR: UNet weights directory not found: {unet_weights_dir}")
        print(f"Please ensure '--checkpoint_dir' points to the correct checkpoint folder (e.g., checkpoint-5000)")
        print(f"and '--unet_subfolder' is correct (likely 'unet_mamba').")
        sys.exit(1)

    try:
        print(f"Loading state dict from {unet_weights_dir}...")

        # Prefer the safetensors file; fall back to the legacy .bin format.
        state_dict_path_safe = unet_weights_dir / "diffusion_pytorch_model.safetensors"
        state_dict_path_bin = unet_weights_dir / "diffusion_pytorch_model.bin"

        if state_dict_path_safe.exists():
            # Imported lazily so safetensors is only required when the file
            # format is actually present.
            from safetensors.torch import load_file
            unet_state_dict = load_file(state_dict_path_safe, device="cpu")
            print(f"Loaded state dict from {state_dict_path_safe}")
        elif state_dict_path_bin.exists():
            unet_state_dict = torch.load(state_dict_path_bin, map_location="cpu")
            print(f"Loaded state dict from {state_dict_path_bin}")
        else:
            raise FileNotFoundError(f"Neither safetensors nor bin file found in {unet_weights_dir}")

        # strict=True is deliberate: any key mismatch means the Mamba
        # parameters (or base architecture) differ from training, and that
        # should fail loudly instead of silently skipping weights.
        load_result = unet.load_state_dict(unet_state_dict, strict=True)
        print(f"UNet state dict loaded successfully. Load result: {load_result}")
        del unet_state_dict  # free the CPU copy before moving the model to the device
        print("Fine-tuned UNet weights loaded.")

    except Exception as e:
        print(f"ERROR: Failed to load UNet weights from {unet_weights_dir}.")
        print(f"Make sure the directory exists and contains the model weights ('diffusion_pytorch_model.safetensors' or '.bin').")
        print(f"Also ensure Mamba parameters match those used during training.")
        print(f"Error details: {e}")
        sys.exit(1)

    unet = unet.to(device)
    unet.eval()  # inference only: disable dropout / training-mode behavior
    print("UNet moved to device and set to eval mode.")

    # Assemble the pipeline from the individually prepared components.
    # The safety checker is explicitly disabled; outputs are unfiltered.
    print("Creating Stable Diffusion Pipeline with modified UNet...")
    try:
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )

        print("Pipeline created successfully.")
    except Exception as e:
        print(f"ERROR: Failed to create Stable Diffusion Pipeline.")
        print(f"Error details: {e}")
        sys.exit(1)

    print(f"Generating image for prompt: '{args.prompt}'...")
    try:
        with torch.no_grad():
            # Autocast is enabled only for float16; it is a no-op for float32.
            # NOTE(review): float16 autocast on CPU may be unsupported by some
            # torch versions — confirm cpu + float16 is an intended combination.
            with torch.autocast(device_type=args.device.split(":")[0], dtype=pipeline_torch_dtype, enabled=(pipeline_torch_dtype != torch.float32)):
                result = pipeline(
                    prompt=args.prompt,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    width=args.width,
                    height=args.height,
                )
            image = result.images[0]

        print("Image generation complete.")

    except Exception as e:
        print(f"ERROR: Image generation failed.")
        print(f"Error details: {e}")
        sys.exit(1)

    # Create the output directory if needed, then save via PIL (format inferred
    # from the --output_path extension).
    try:
        output_dir = Path(args.output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)
        image.save(args.output_path)
        print(f"Image saved successfully to: {args.output_path}")
    except Exception as e:
        print(f"ERROR: Failed to save image to {args.output_path}.")
        print(f"Error details: {e}")
        sys.exit(1)
| |
|
# Standard script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()