|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
import torch |
|
|
from diffusers import ( |
|
|
StableDiffusionPipeline, |
|
|
UNet2DConditionModel, |
|
|
AutoencoderKL, |
|
|
DDPMScheduler, |
|
|
) |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba |
|
|
print("Successfully imported Mamba utilities from msd_utils.py") |
|
|
except ImportError as e: |
|
|
print(f"ERROR: Could not import from msd_utils.py. Make sure it's in the same directory.") |
|
|
print(f"Import Error: {e}") |
|
|
sys.exit(1) |
|
|
except Exception as e: |
|
|
print(f"ERROR: An unexpected error occurred while importing msd_utils.py: {e}") |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for Mamba-UNet Stable Diffusion image generation.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the normal CLI behaviour.

    Returns:
        argparse.Namespace holding the base-model / checkpoint locations, the
        generation settings (prompt, steps, guidance, seed, device, dtype), the
        output path and the Mamba hyper-parameters.
    """
    parser = argparse.ArgumentParser(description="Generate images using a fine-tuned Stable Diffusion Mamba UNet checkpoint.")

    parser.add_argument(
        "--base_model", type=str, default="runwayml/stable-diffusion-v1-5",
        help="Path or Hub ID of the base Stable Diffusion model used for training (e.g., 'runwayml/stable-diffusion-v1-5')."
    )
    parser.add_argument(
        "--checkpoint_dir", type=str, required=True,
        help="Path to the specific checkpoint directory (e.g., 'sd-mamba-mscoco-urltext-5k-run1/checkpoint-5000')."
    )
    parser.add_argument(
        "--unet_subfolder", type=str, default="unet_mamba",
        help="Name of the subfolder within the checkpoint directory containing the saved UNet weights."
    )
    parser.add_argument(
        "--prompt", type=str, default="A photo of an astronaut riding a horse on the moon",
        help="Text prompt for image generation."
    )
    parser.add_argument(
        "--output_path", type=str, default="generated_image_mamba.png",
        help="Path to save the generated image."
    )
    parser.add_argument(
        # Evaluated once when the parser is built: prefer CUDA when it is present.
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for generation ('cuda' or 'cpu')."
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Optional random seed for reproducibility."
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=30,
        help="Number of denoising steps."
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.5,
        help="Scale for classifier-free guidance."
    )

    # Mamba hyper-parameters. They must match the values used in training.
    # They previously combined ``required=True`` with a ``default``, which made
    # the default unreachable. They are now optional, and the declared defaults
    # apply when omitted. A mismatch with the checkpoint is expected to surface
    # as a shape error when the weights are loaded strictly.
    parser.add_argument(
        "--mamba_d_state", type=int, default=16,
        help="Mamba ssm state dimension used during training."
    )
    parser.add_argument(
        "--mamba_d_conv", type=int, default=4,
        help="Mamba ssm convolution dimension used during training."
    )
    parser.add_argument(
        "--mamba_expand", type=int, default=2,
        help="Mamba ssm expansion factor used during training."
    )

    parser.add_argument(
        "--pipeline_dtype", type=str, default="float32", choices=["float32", "float16"],
        help="Run pipeline inference in float32 or float16. float32 is generally more stable."
    )

    return parser.parse_args(argv)
|
|
|
|
|
def _print_config(args):
    """Print the parsed command-line configuration as a readable block."""
    print("--- Configuration ---")
    print(f"Base Model: {args.base_model}")
    print(f"Checkpoint Dir: {args.checkpoint_dir}")
    print(f"UNet Subfolder: {args.unet_subfolder}")
    print(f"Prompt: '{args.prompt}'")
    print(f"Output Path: {args.output_path}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")
    print(f"Inference Steps: {args.num_inference_steps}")
    print(f"Guidance Scale: {args.guidance_scale}")
    print(f"Pipeline dtype: {args.pipeline_dtype}")
    print(f"Mamba Params: d_state={args.mamba_d_state}, d_conv={args.mamba_d_conv}, expand={args.mamba_expand}")
    print("--------------------")


def _load_unet_state_dict(unet_weights_dir):
    """Load a UNet state dict from ``unet_weights_dir`` onto the CPU.

    Prefers ``diffusion_pytorch_model.safetensors`` and falls back to
    ``diffusion_pytorch_model.bin``.

    Args:
        unet_weights_dir: ``pathlib.Path`` of the directory holding the weights.

    Returns:
        The loaded state dict (parameter name -> tensor).

    Raises:
        FileNotFoundError: if neither weight file exists in the directory.
    """
    state_dict_path_safe = unet_weights_dir / "diffusion_pytorch_model.safetensors"
    state_dict_path_bin = unet_weights_dir / "diffusion_pytorch_model.bin"

    if state_dict_path_safe.exists():
        # Imported lazily so the safetensors package is only needed for that format.
        from safetensors.torch import load_file
        state_dict = load_file(state_dict_path_safe, device="cpu")
        print(f"Loaded state dict from {state_dict_path_safe}")
        return state_dict

    if state_dict_path_bin.exists():
        # .bin checkpoints are pickles. weights_only=True restricts unpickling to
        # tensors and plain containers, so a tampered checkpoint cannot run code.
        state_dict = torch.load(state_dict_path_bin, map_location="cpu", weights_only=True)
        print(f"Loaded state dict from {state_dict_path_bin}")
        return state_dict

    raise FileNotFoundError(f"Neither safetensors nor bin file found in {unet_weights_dir}")


def main():
    """Generate one image with a fine-tuned Mamba-UNet Stable Diffusion checkpoint.

    Steps:

    1. Parse the CLI arguments.
    2. Load the base tokenizer, scheduler, VAE and text encoder (fp32).
    3. Rebuild the UNet from the base model's config.
    4. Replace its self-attention with Mamba blocks.
    5. Strictly load the fine-tuned UNet weights from the checkpoint.
    6. Run the pipeline on ``--prompt`` and save the resulting image.

    Every failure prints a diagnostic and exits with status 1.
    """
    args = parse_args()
    _print_config(args)

    device = torch.device(args.device)
    pipeline_torch_dtype = torch.float32 if args.pipeline_dtype == "float32" else torch.float16
    # All weights stay fp32. float16 is realised only through autocast at generation time.
    use_autocast = pipeline_torch_dtype != torch.float32

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    mamba_kwargs = {
        "d_state": args.mamba_d_state,
        "d_conv": args.mamba_d_conv,
        "expand": args.mamba_expand,
    }
    print("Prepared Mamba kwargs for UNet replacement.")

    # --- Base (non-UNet) components, loaded in fp32 ---
    print(f"Loading base components from {args.base_model}...")
    try:
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model, subfolder="tokenizer")
        scheduler = DDPMScheduler.from_pretrained(args.base_model, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.base_model, subfolder="vae", torch_dtype=torch.float32).to(device)
        text_encoder = CLIPTextModel.from_pretrained(
            args.base_model, subfolder="text_encoder", torch_dtype=torch.float32
        ).to(device)
        print("Base components loaded.")
    except Exception as e:
        print(f"ERROR: Failed to load base components from {args.base_model}. Check path/name.")
        print(f"Error details: {e}")
        sys.exit(1)

    # --- UNet skeleton built from the base config (weights come from the checkpoint) ---
    # No torch_dtype here. diffusers' ``from_config`` does not apply it (it is a
    # ``from_pretrained`` option), so passing it only suggested an fp16 UNet that was
    # never built. The UNet stays fp32 like the VAE / text encoder, and float16
    # inference is handled by autocast below.
    print(f"Creating UNet structure from {args.base_model} config...")
    try:
        unet_config = UNet2DConditionModel.load_config(args.base_model, subfolder="unet")
        unet = UNet2DConditionModel.from_config(unet_config)
        print("Base UNet structure created.")
    except Exception as e:
        print(f"ERROR: Failed to create UNet structure from config {args.base_model}.")
        print(f"Error details: {e}")
        sys.exit(1)

    # The module layout must match the training-time layout exactly, otherwise the
    # strict state-dict load below fails.
    print("Replacing UNet Self-Attention with Mamba blocks (using provided parameters)...")
    try:
        unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
        print("UNet structure modified with Mamba blocks.")
    except Exception as e:
        print("ERROR: Failed during UNet modification with Mamba blocks.")
        print(f"Error details: {e}")
        sys.exit(1)

    # --- Fine-tuned UNet weights ---
    unet_weights_dir = Path(args.checkpoint_dir) / args.unet_subfolder
    print(f"Attempting to load fine-tuned UNet weights from: {unet_weights_dir}")
    if not unet_weights_dir.is_dir():
        print(f"ERROR: UNet weights directory not found: {unet_weights_dir}")
        print("Please ensure '--checkpoint_dir' points to the correct checkpoint folder (e.g., checkpoint-5000)")
        print("and '--unet_subfolder' is correct (likely 'unet_mamba').")
        sys.exit(1)

    try:
        print(f"Loading state dict from {unet_weights_dir}...")
        unet_state_dict = _load_unet_state_dict(unet_weights_dir)
        # strict=True: any missing/unexpected key (e.g. wrong Mamba params) is an error.
        load_result = unet.load_state_dict(unet_state_dict, strict=True)
        print(f"UNet state dict loaded successfully. Load result: {load_result}")
        del unet_state_dict  # release the CPU copy before moving the model to the device
        print("Fine-tuned UNet weights loaded.")
    except Exception as e:
        print(f"ERROR: Failed to load UNet weights from {unet_weights_dir}.")
        print("Make sure the directory exists and contains the model weights ('diffusion_pytorch_model.safetensors' or '.bin').")
        print("Also ensure Mamba parameters match those used during training.")
        print(f"Error details: {e}")
        sys.exit(1)

    unet = unet.to(device)
    unet.eval()
    print("UNet moved to device and set to eval mode.")

    # --- Pipeline assembly (safety checker intentionally disabled) ---
    print("Creating Stable Diffusion Pipeline with modified UNet...")
    try:
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )
        print("Pipeline created successfully.")
    except Exception as e:
        print("ERROR: Failed to create Stable Diffusion Pipeline.")
        print(f"Error details: {e}")
        sys.exit(1)

    # --- Generation ---
    print(f"Generating image for prompt: '{args.prompt}'...")
    try:
        # device.type strips any index ("cuda:1" -> "cuda"), as autocast expects.
        with torch.no_grad(), torch.autocast(
            device_type=device.type, dtype=pipeline_torch_dtype, enabled=use_autocast
        ):
            result = pipeline(
                prompt=args.prompt,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
            )
        image = result.images[0]
        print("Image generation complete.")
    except Exception as e:
        print("ERROR: Image generation failed.")
        print(f"Error details: {e}")
        sys.exit(1)

    # --- Save ---
    try:
        Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)
        image.save(args.output_path)
        print(f"Image saved successfully to: {args.output_path}")
    except Exception as e:
        print(f"ERROR: Failed to save image to {args.output_path}.")
        print(f"Error details: {e}")
        sys.exit(1)
|
|
|
|
|
# Run the image-generation CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()