MorphGuard / scripts /download_stable_diffusion.py
juanquy's picture
Initial clean commit of modular MorphGuard
2978bba
Raw
History Blame Contribute Delete
3.76 kB
#!/usr/bin/env python3
"""Helper script to download Stable Diffusion and LDM models for MorphGuard"""
import os
import argparse
import torch
from diffusers import StableDiffusionImg2ImgPipeline, DiffusionPipeline
def _from_pretrained_with_fallback(pipeline_cls, model_id, kwargs):
    """Load ``pipeline_cls`` from ``model_id``, retrying without ``variant`` on failure.

    Not every Hub repository publishes an ``fp16`` weight variant. Depending on
    the installed ``diffusers`` version, requesting a missing variant may raise
    instead of silently falling back. So if a load that asked for a variant
    fails, it is retried with the default weight files while keeping any
    requested ``torch_dtype``.

    Args:
        pipeline_cls: A diffusers pipeline class exposing ``from_pretrained``.
        model_id: Hub repository id (or local path) to load.
        kwargs: Extra keyword arguments for ``from_pretrained``; not mutated.

    Returns:
        The loaded pipeline object.

    Raises:
        Whatever ``from_pretrained`` raises when no variant was requested, or
        on the fallback attempt.
    """
    try:
        return pipeline_cls.from_pretrained(model_id, **kwargs)
    except (OSError, ValueError) as err:
        if "variant" not in kwargs:
            raise
        print(f" - '{kwargs['variant']}' weights unavailable for {model_id} ({err}); "
              "retrying with the default weight files")
        fallback_kwargs = {key: value for key, value in kwargs.items() if key != "variant"}
        return pipeline_cls.from_pretrained(model_id, **fallback_kwargs)


def main():
    """Download the MorphGuard diffusion models and save them under ``models/``.

    Command-line flags:
        --use-cpu            force CPU even if CUDA is available (disables fp16)
        --no-safety-checker  load Stable Diffusion without its safety checker
        --no-fp16            keep fp32 weights even on a CUDA machine

    Each model is fetched with ``from_pretrained`` and written to its local
    directory with ``save_pretrained``. If one model fails, the error is
    reported and the remaining models are still attempted.

    Returns:
        0 if every model was saved, 1 if any download failed (usable as a
        process exit status).
    """
    parser = argparse.ArgumentParser(description="Download Stable Diffusion models for MorphGuard")
    parser.add_argument("--use-cpu", action="store_true", help="Force CPU usage (not recommended)")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker for faster operation")
    parser.add_argument("--no-fp16", action="store_true", help="Disable fp16 even when running on GPU")
    args = parser.parse_args()

    # Determine device and precision; fp16 is only ever used on CUDA.
    if args.use_cpu:
        device = "cpu"
        print("WARNING: Forcing CPU usage. This will be very slow for inference.")
        use_fp16 = False
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        use_fp16 = device == "cuda" and not args.no_fp16
        if device == "cpu":
            print("WARNING: CUDA not available. Downloads will be configured for CPU but inference will be slow.")
        elif use_fp16:
            print(f"Using {device} with fp16 precision for model downloads (faster, lower memory usage)")
        else:
            print(f"Using {device} with fp32 precision (higher accuracy but more memory usage)")

    # Models to download. "has_safety_checker" controls whether
    # --no-safety-checker is forwarded; the unconditional LDM pipeline has no
    # safety_checker component, so passing it there only triggers a warning.
    models = [
        {
            "name": "Stable Diffusion v1.4",
            "path": "models/diffusion/stable-diffusion-v1-4",
            "model_id": "CompVis/stable-diffusion-v1-4",
            "pipeline": StableDiffusionImg2ImgPipeline,
            "has_safety_checker": True,
        },
        {
            "name": "LDM CelebA-HQ",
            "path": "models/latent-diffusion/ldm-celebahq-256",
            "model_id": "CompVis/ldm-celebahq-256",
            "pipeline": DiffusionPipeline,
            "has_safety_checker": False,
        },
    ]

    saved = []
    failed = []
    for model in models:
        print(f"Downloading {model['name']} to {model['path']}...")

        kwargs = {}
        if args.no_safety_checker and model["has_safety_checker"]:
            kwargs["safety_checker"] = None
        if use_fp16:
            # Use fp16 on CUDA (faster, lower memory usage).
            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"
            print(f" - Using fp16 precision for {model['name']}")

        try:
            pipeline = _from_pretrained_with_fallback(model["pipeline"], model["model_id"], kwargs)
            # Save straight from CPU: save_pretrained does not record device
            # placement, so moving to the GPU first only consumed VRAM.
            os.makedirs(model["path"], exist_ok=True)
            pipeline.save_pretrained(model["path"])
            # Release the weights before the next model is loaded.
            del pipeline
        except Exception as e:  # report and continue with the remaining models
            failed.append(model)
            print(f"❌ Error downloading {model['name']}: {e}")
        else:
            saved.append(model)
            print(f"✅ Successfully downloaded {model['name']}")

    if failed:
        print(f"\nSetup incomplete: {len(failed)} of {len(models)} model(s) failed to download:")
        for model in failed:
            print(f" - {model['name']} ({model['model_id']})")
        if saved:
            print("Models that were downloaded successfully:")
    else:
        print("\nSetup complete! Models were downloaded to:")
    for model in saved:
        print(f" - {model['path']}")

    if device == "cuda":
        print("\nYour GPU has been detected and models configured for optimal performance.")
        if use_fp16:
            print("Using fp16 precision which offers:")
            print(" - Up to 2-3x faster inference")
            print(" - Half the memory usage")
            print(" - Ability to run larger models or batch sizes")
    elif args.use_cpu:
        print("\nCPU usage was forced with --use-cpu; omit it on a CUDA machine for better performance.")
    else:
        print("\nNo GPU detected. Consider running on a machine with CUDA support for better performance.")

    return 1 if failed else 0
if __name__ == "__main__":
    # Run the downloader and use whatever main() returns as the exit status
    # (a None return exits with status 0, the same as falling off the script).
    raise SystemExit(main())