"""Generate images with a fine-tuned Stable Diffusion UNet whose
self-attention layers were replaced by Mamba blocks during training.

The script rebuilds the modified UNet architecture (base config + Mamba
replacement), loads the fine-tuned checkpoint weights into it, assembles a
StableDiffusionPipeline, and renders a single image for the given prompt.
"""

import argparse
import os
import sys
from pathlib import Path

import torch
from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    DDPMScheduler,
)
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image

# --- Crucial: Import Mamba utilities ---
# Ensure msd_utils.py is in the same directory or Python path.
try:
    from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba
    print("Successfully imported Mamba utilities from msd_utils.py")
except ImportError as e:
    print("ERROR: Could not import from msd_utils.py. Make sure it's in the same directory.")
    print(f"Import Error: {e}")
    sys.exit(1)
except Exception as e:
    print(f"ERROR: An unexpected error occurred while importing msd_utils.py: {e}")
    sys.exit(1)
# --- End Mamba Import ---


def parse_args():
    """Parse command-line arguments for Mamba-UNet inference.

    The Mamba hyper-parameters (--mamba_d_state, --mamba_d_conv,
    --mamba_expand) are required and MUST match the values used during
    training, otherwise the saved UNet state dict will not load.
    """
    parser = argparse.ArgumentParser(
        description="Generate images using a fine-tuned Stable Diffusion Mamba UNet checkpoint."
    )
    parser.add_argument(
        "--base_model",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        help="Path or Hub ID of the base Stable Diffusion model used for training (e.g., 'runwayml/stable-diffusion-v1-5').",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Path to the specific checkpoint directory (e.g., 'sd-mamba-mscoco-urltext-5k-run1/checkpoint-5000').",
    )
    parser.add_argument(
        "--unet_subfolder",
        type=str,
        default="unet_mamba",
        help="Name of the subfolder within the checkpoint directory containing the saved UNet weights.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A photo of an astronaut riding a horse on the moon",
        help="Text prompt for image generation.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="generated_image_mamba.png",
        help="Path to save the generated image.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for generation ('cuda' or 'cpu').",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Optional random seed for reproducibility.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=30,
        help="Number of denoising steps.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="Scale for classifier-free guidance.",
    )
    # --- Mamba Parameters (MUST match training) ---
    # NOTE: these are required, so argparse ignores any `default=`; the user
    # must always supply the exact values that were used during training.
    parser.add_argument(
        "--mamba_d_state",
        type=int,
        required=True,
        help="Mamba ssm state dimension used during training.",
    )
    parser.add_argument(
        "--mamba_d_conv",
        type=int,
        required=True,
        help="Mamba ssm convolution dimension used during training.",
    )
    parser.add_argument(
        "--mamba_expand",
        type=int,
        required=True,
        help="Mamba ssm expansion factor used during training.",
    )
    # --- End Mamba Parameters ---
    parser.add_argument(
        "--pipeline_dtype",
        type=str,
        default="float32",
        choices=["float32", "float16"],
        help="Run pipeline inference in float32 or float16. float32 is generally more stable.",
    )
    return parser.parse_args()


def _print_config(args):
    """Echo the effective configuration so runs are self-documenting."""
    print("--- Configuration ---")
    print(f"Base Model: {args.base_model}")
    print(f"Checkpoint Dir: {args.checkpoint_dir}")
    print(f"UNet Subfolder: {args.unet_subfolder}")
    print(f"Prompt: '{args.prompt}'")
    print(f"Output Path: {args.output_path}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")
    print(f"Inference Steps: {args.num_inference_steps}")
    print(f"Guidance Scale: {args.guidance_scale}")
    print(f"Pipeline dtype: {args.pipeline_dtype}")
    print(f"Mamba Params: d_state={args.mamba_d_state}, d_conv={args.mamba_d_conv}, expand={args.mamba_expand}")
    print("--------------------")


def _load_base_components(base_model, device):
    """Load tokenizer, scheduler, VAE and text encoder from the base model.

    VAE and text encoder are kept in float32 for numerical stability and
    moved to `device`. Exits the process on failure.
    """
    print(f"Loading base components from {base_model}...")
    try:
        tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
        scheduler = DDPMScheduler.from_pretrained(base_model, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(
            base_model, subfolder="vae", torch_dtype=torch.float32
        ).to(device)
        text_encoder = CLIPTextModel.from_pretrained(
            base_model, subfolder="text_encoder", torch_dtype=torch.float32
        ).to(device)
        print("Base components loaded.")
        return tokenizer, scheduler, vae, text_encoder
    except Exception as e:
        print(f"ERROR: Failed to load base components from {base_model}. Check path/name.")
        print(f"Error details: {e}")
        sys.exit(1)


def _build_mamba_unet(base_model, mamba_kwargs, dtype):
    """Build the UNet skeleton from config and swap self-attention for Mamba.

    Returns the (weightless) modified UNet cast to `dtype`; exits on failure.
    """
    print(f"Creating UNet structure from {base_model} config...")
    try:
        unet_config = UNet2DConditionModel.load_config(base_model, subfolder="unet")
        # BUGFIX: `from_config` does not accept `torch_dtype`; construct in the
        # default dtype, then cast explicitly to the requested inference dtype.
        unet = UNet2DConditionModel.from_config(unet_config).to(dtype=dtype)
        print("Base UNet structure created.")
    except Exception as e:
        print(f"ERROR: Failed to create UNet structure from config {base_model}.")
        print(f"Error details: {e}")
        sys.exit(1)

    print("Replacing UNet Self-Attention with Mamba blocks (using provided parameters)...")
    try:
        unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
        print("UNet structure modified with Mamba blocks.")
        return unet
    except Exception as e:
        print("ERROR: Failed during UNet modification with Mamba blocks.")
        print(f"Error details: {e}")
        sys.exit(1)


def _load_finetuned_weights(unet, weights_dir):
    """Load the fine-tuned state dict into `unet` in place.

    Prefers the safetensors file, falls back to the .bin file. Exits on any
    failure, including a missing directory or a state-dict/architecture
    mismatch (e.g. wrong Mamba parameters).
    """
    print(f"Attempting to load fine-tuned UNet weights from: {weights_dir}")
    if not weights_dir.is_dir():
        print(f"ERROR: UNet weights directory not found: {weights_dir}")
        print("Please ensure '--checkpoint_dir' points to the correct checkpoint folder (e.g., checkpoint-5000)")
        print("and '--unet_subfolder' is correct (likely 'unet_mamba').")
        sys.exit(1)
    try:
        print(f"Loading state dict from {weights_dir}...")
        # Check for safetensors first, then bin.
        safe_path = weights_dir / "diffusion_pytorch_model.safetensors"
        bin_path = weights_dir / "diffusion_pytorch_model.bin"
        if safe_path.exists():
            from safetensors.torch import load_file
            state_dict = load_file(safe_path, device="cpu")
            print(f"Loaded state dict from {safe_path}")
        elif bin_path.exists():
            state_dict = torch.load(bin_path, map_location="cpu")
            print(f"Loaded state dict from {bin_path}")
        else:
            raise FileNotFoundError(f"Neither safetensors nor bin file found in {weights_dir}")
        # strict=True surfaces any key mismatch between the rebuilt Mamba
        # architecture and the saved checkpoint.
        load_result = unet.load_state_dict(state_dict, strict=True)
        print(f"UNet state dict loaded successfully. Load result: {load_result}")
        del state_dict  # free host memory before moving the model to device
        print("Fine-tuned UNet weights loaded.")
    except Exception as e:
        print(f"ERROR: Failed to load UNet weights from {weights_dir}.")
        print("Make sure the directory exists and contains the model weights ('diffusion_pytorch_model.safetensors' or '.bin').")
        print("Also ensure Mamba parameters match those used during training.")
        print(f"Error details: {e}")
        sys.exit(1)


def _create_pipeline(unet, tokenizer, scheduler, vae, text_encoder):
    """Assemble a StableDiffusionPipeline around the modified UNet."""
    print("Creating Stable Diffusion Pipeline with modified UNet...")
    try:
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,  # the modified and weight-loaded UNet
            scheduler=scheduler,
            safety_checker=None,  # disabled during training, keep disabled
            feature_extractor=None,
            requires_safety_checker=False,
        )
        # Components were already moved to device individually; no pipeline.to() needed.
        print("Pipeline created successfully.")
        return pipeline
    except Exception as e:
        print("ERROR: Failed to create Stable Diffusion Pipeline.")
        print(f"Error details: {e}")
        sys.exit(1)


def _generate_image(pipeline, args, generator, dtype):
    """Run the pipeline once and return the generated PIL image."""
    print(f"Generating image for prompt: '{args.prompt}'...")
    try:
        with torch.no_grad():  # inference only — no autograd graph
            # Autocast only when running reduced precision; float32 runs natively.
            with torch.autocast(
                device_type=args.device.split(":")[0],
                dtype=dtype,
                enabled=(dtype != torch.float32),
            ):
                result = pipeline(
                    prompt=args.prompt,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    # Add negative prompt if needed: negative_prompt="..."
                )
        print("Image generation complete.")
        return result.images[0]
    except Exception as e:
        print("ERROR: Image generation failed.")
        print(f"Error details: {e}")
        sys.exit(1)


def _save_image(image, output_path):
    """Write the image to `output_path`, creating parent directories as needed."""
    try:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        image.save(output_path)
        print(f"Image saved successfully to: {output_path}")
    except Exception as e:
        print(f"ERROR: Failed to save image to {output_path}.")
        print(f"Error details: {e}")
        sys.exit(1)


def main():
    """Entry point: rebuild the Mamba UNet, load weights, and generate one image."""
    args = parse_args()
    _print_config(args)

    device = torch.device(args.device)
    pipeline_torch_dtype = torch.float32 if args.pipeline_dtype == "float32" else torch.float16

    # Optional deterministic seeding.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    # Must match the values used during training or the state dict won't load.
    mamba_kwargs = {
        "d_state": args.mamba_d_state,
        "d_conv": args.mamba_d_conv,
        "expand": args.mamba_expand,
    }
    print("Prepared Mamba kwargs for UNet replacement.")

    tokenizer, scheduler, vae, text_encoder = _load_base_components(args.base_model, device)
    unet = _build_mamba_unet(args.base_model, mamba_kwargs, pipeline_torch_dtype)
    _load_finetuned_weights(unet, Path(args.checkpoint_dir) / args.unet_subfolder)

    # Cast the whole UNet — including the freshly created Mamba blocks, which
    # may not be in the pipeline dtype — and move it to the target device.
    unet = unet.to(device=device, dtype=pipeline_torch_dtype)
    unet.eval()
    print("UNet moved to device and set to eval mode.")

    pipeline = _create_pipeline(unet, tokenizer, scheduler, vae, text_encoder)
    image = _generate_image(pipeline, args, generator, pipeline_torch_dtype)
    _save_image(image, args.output_path)


if __name__ == "__main__":
    main()