# MSD / msd_infer.py
# Provenance note: "Initial clean upload: checkpoint + scripts + PNG via LFS"
# (commit 5e7715d, author: root)
import argparse
import os
import sys
from pathlib import Path
import torch
from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
AutoencoderKL,
DDPMScheduler,
)
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
# --- Crucial: Import Mamba utilities ---
# Ensure msd_utils.py is in the same directory or Python path
try:
    from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba
    print("Successfully imported Mamba utilities from msd_utils.py")
except ImportError as e:
    # Fail fast: the script cannot run without the Mamba helper module.
    print(f"ERROR: Could not import from msd_utils.py. Make sure it's in the same directory.")
    print(f"Import Error: {e}")
    sys.exit(1)
except Exception as e:
    # Anything else raised while importing msd_utils (e.g. a failure inside
    # msd_utils' own imports) is also fatal.
    print(f"ERROR: An unexpected error occurred while importing msd_utils.py: {e}")
    sys.exit(1)
# --- End Mamba Import ---
def parse_args():
parser = argparse.ArgumentParser(description="Generate images using a fine-tuned Stable Diffusion Mamba UNet checkpoint.")
parser.add_argument(
"--base_model", type=str, default="runwayml/stable-diffusion-v1-5",
help="Path or Hub ID of the base Stable Diffusion model used for training (e.g., 'runwayml/stable-diffusion-v1-5')."
)
parser.add_argument(
"--checkpoint_dir", type=str, required=True,
help="Path to the specific checkpoint directory (e.g., 'sd-mamba-mscoco-urltext-5k-run1/checkpoint-5000')."
)
parser.add_argument(
"--unet_subfolder", type=str, default="unet_mamba",
help="Name of the subfolder within the checkpoint directory containing the saved UNet weights."
)
parser.add_argument(
"--prompt", type=str, default="A photo of an astronaut riding a horse on the moon",
help="Text prompt for image generation."
)
parser.add_argument(
"--output_path", type=str, default="generated_image_mamba.png",
help="Path to save the generated image."
)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to use for generation ('cuda' or 'cpu')."
)
parser.add_argument(
"--seed", type=int, default=None,
help="Optional random seed for reproducibility."
)
parser.add_argument(
"--num_inference_steps", type=int, default=30,
help="Number of denoising steps."
)
parser.add_argument(
"--guidance_scale", type=float, default=7.5,
help="Scale for classifier-free guidance."
)
# --- Mamba Parameters (MUST match training) ---
parser.add_argument(
"--mamba_d_state", type=int, default=16, required=True, # Require to ensure user provides it
help="Mamba ssm state dimension used during training."
)
parser.add_argument(
"--mamba_d_conv", type=int, default=4, required=True, # Require to ensure user provides it
help="Mamba ssm convolution dimension used during training."
)
parser.add_argument(
"--mamba_expand", type=int, default=2, required=True, # Require to ensure user provides it
help="Mamba ssm expansion factor used during training."
)
# --- End Mamba Parameters ---
parser.add_argument(
"--pipeline_dtype", type=str, default="float32", choices=["float32", "float16"],
help="Run pipeline inference in float32 or float16. float32 is generally more stable."
)
args = parser.parse_args()
return args
def main():
    """Load a Mamba-modified Stable Diffusion pipeline from a checkpoint and generate one image.

    Pipeline assembly order matters: the base UNet structure is created first,
    its self-attention is swapped for Mamba blocks (so the state-dict keys line
    up with the fine-tuned checkpoint), and only then are the trained weights
    loaded with strict=True. Any failure exits the process with status 1.
    """
    args = parse_args()
    # Echo the full configuration so a run is reproducible from its log alone.
    print(f"--- Configuration ---")
    print(f"Base Model: {args.base_model}")
    print(f"Checkpoint Dir: {args.checkpoint_dir}")
    print(f"UNet Subfolder: {args.unet_subfolder}")
    print(f"Prompt: '{args.prompt}'")
    print(f"Output Path: {args.output_path}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")
    print(f"Inference Steps: {args.num_inference_steps}")
    print(f"Guidance Scale: {args.guidance_scale}")
    print(f"Pipeline dtype: {args.pipeline_dtype}")
    print(f"Mamba Params: d_state={args.mamba_d_state}, d_conv={args.mamba_d_conv}, expand={args.mamba_expand}")
    print(f"--------------------")
    # Set device
    device = torch.device(args.device)
    pipeline_torch_dtype = torch.float32 if args.pipeline_dtype == "float32" else torch.float16
    # Set seed if provided (generator stays None for nondeterministic sampling)
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)
        print(f"Using random seed: {args.seed}")
    # Prepare Mamba kwargs dictionary — these MUST match the values used at
    # training time or the checkpoint's state dict will not fit the structure.
    mamba_kwargs = {
        'd_state': args.mamba_d_state,
        'd_conv': args.mamba_d_conv,
        'expand': args.mamba_expand,
    }
    print("Prepared Mamba kwargs for UNet replacement.")
    # --- 1. Load Base Components (Tokenizer, Scheduler, VAE, Text Encoder) ---
    print(f"Loading base components from {args.base_model}...")
    try:
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model, subfolder="tokenizer")
        scheduler = DDPMScheduler.from_pretrained(args.base_model, subfolder="scheduler")
        # Load VAE and Text Encoder in float32 for stability, move to device
        vae = AutoencoderKL.from_pretrained(args.base_model, subfolder="vae", torch_dtype=torch.float32).to(device)
        text_encoder = CLIPTextModel.from_pretrained(args.base_model, subfolder="text_encoder", torch_dtype=torch.float32).to(device)
        print("Base components loaded.")
    except Exception as e:
        print(f"ERROR: Failed to load base components from {args.base_model}. Check path/name.")
        print(f"Error details: {e}")
        sys.exit(1)
    # --- 2. Create Base UNet Structure ---
    # Only the config is used here — weights come from the checkpoint in step 4.
    print(f"Creating UNet structure from {args.base_model} config...")
    try:
        unet_config = UNet2DConditionModel.load_config(args.base_model, subfolder="unet")
        # NOTE(review): from_config may not honor torch_dtype the way
        # from_pretrained does in all diffusers versions — confirm the UNet
        # actually ends up in pipeline_torch_dtype rather than float32.
        unet = UNet2DConditionModel.from_config(unet_config, torch_dtype=pipeline_torch_dtype) # Use target dtype here
        print("Base UNet structure created.")
    except Exception as e:
        print(f"ERROR: Failed to create UNet structure from config {args.base_model}.")
        print(f"Error details: {e}")
        sys.exit(1)
    # --- 3. Modify UNet Structure with Mamba ---
    # Must happen BEFORE loading weights so the module tree matches the
    # checkpoint's parameter names.
    print(f"Replacing UNet Self-Attention with Mamba blocks (using provided parameters)...")
    try:
        unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
        print("UNet structure modified with Mamba blocks.")
    except Exception as e:
        print(f"ERROR: Failed during UNet modification with Mamba blocks.")
        print(f"Error details: {e}")
        sys.exit(1)
    # --- 4. Load Fine-tuned UNet Weights ---
    unet_weights_dir = Path(args.checkpoint_dir) / args.unet_subfolder
    print(f"Attempting to load fine-tuned UNet weights from: {unet_weights_dir}")
    if not unet_weights_dir.is_dir():
        print(f"ERROR: UNet weights directory not found: {unet_weights_dir}")
        print(f"Please ensure '--checkpoint_dir' points to the correct checkpoint folder (e.g., checkpoint-5000)")
        print(f"and '--unet_subfolder' is correct (likely 'unet_mamba').")
        sys.exit(1)
    try:
        # Load the state dict into the already modified unet structure
        print(f"Loading state dict from {unet_weights_dir}...")
        # Check for safetensors first, then bin
        state_dict_path_safe = unet_weights_dir / "diffusion_pytorch_model.safetensors"
        state_dict_path_bin = unet_weights_dir / "diffusion_pytorch_model.bin"
        if state_dict_path_safe.exists():
            from safetensors.torch import load_file
            unet_state_dict = load_file(state_dict_path_safe, device="cpu")
            print(f"Loaded state dict from {state_dict_path_safe}")
        elif state_dict_path_bin.exists():
            unet_state_dict = torch.load(state_dict_path_bin, map_location="cpu")
            print(f"Loaded state dict from {state_dict_path_bin}")
        else:
            raise FileNotFoundError(f"Neither safetensors nor bin file found in {unet_weights_dir}")
        # Load into the existing UNet object (which has the Mamba structure).
        # strict=True deliberately turns any key mismatch (wrong Mamba params,
        # wrong base model) into a hard error instead of a silent partial load.
        load_result = unet.load_state_dict(unet_state_dict, strict=True) # Use strict=True to catch mismatches
        print(f"UNet state dict loaded successfully. Load result: {load_result}")
        del unet_state_dict # Free memory
        print("Fine-tuned UNet weights loaded.")
    except Exception as e:
        print(f"ERROR: Failed to load UNet weights from {unet_weights_dir}.")
        print(f"Make sure the directory exists and contains the model weights ('diffusion_pytorch_model.safetensors' or '.bin').")
        print(f"Also ensure Mamba parameters match those used during training.")
        print(f"Error details: {e}")
        sys.exit(1)
    # Move UNet to device and set to eval mode
    unet = unet.to(device)
    unet.eval()
    print("UNet moved to device and set to eval mode.")
    # --- 5. Create Stable Diffusion Pipeline ---
    print("Creating Stable Diffusion Pipeline with modified UNet...")
    try:
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet, # Use the modified and loaded UNet
            scheduler=scheduler,
            safety_checker=None, # Disabled during training, keep disabled
            feature_extractor=None,
            requires_safety_checker=False,
        )
        # No need to move pipeline again if components are already on device
        # pipeline = pipeline.to(device) # Components already moved
        print("Pipeline created successfully.")
    except Exception as e:
        print(f"ERROR: Failed to create Stable Diffusion Pipeline.")
        print(f"Error details: {e}")
        sys.exit(1)
    # --- 6. Generate Image ---
    print(f"Generating image for prompt: '{args.prompt}'...")
    try:
        with torch.no_grad(): # Inference context
            # Run inference in the specified precision; autocast is a no-op
            # (enabled=False) when the pipeline dtype is float32.
            with torch.autocast(device_type=args.device.split(":")[0], dtype=pipeline_torch_dtype, enabled=(pipeline_torch_dtype != torch.float32)):
                result = pipeline(
                    prompt=args.prompt,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    # Add negative prompt if needed: negative_prompt="..."
                )
            image = result.images[0]
        print("Image generation complete.")
    except Exception as e:
        print(f"ERROR: Image generation failed.")
        print(f"Error details: {e}")
        sys.exit(1)
    # --- 7. Save Image ---
    try:
        output_dir = Path(args.output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True) # Ensure output directory exists
        image.save(args.output_path)
        print(f"Image saved successfully to: {args.output_path}")
    except Exception as e:
        print(f"ERROR: Failed to save image to {args.output_path}.")
        print(f"Error details: {e}")
        sys.exit(1)
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()