from config import RunConfig
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
import torch.nn as nn


def load_stable_diffusion_model(config: RunConfig) -> StableDiffusionPipeline:
    """Load a pretrained Stable Diffusion pipeline and place it on the CPU.

    Args:
        config: Run configuration. Only ``config.sd_2_1`` is read here: when
            truthy the "stabilityai/stable-diffusion-2-1-base" checkpoint is
            loaded, otherwise "CompVis/stable-diffusion-v1-4".

    Returns:
        The ``StableDiffusionPipeline`` returned by ``from_pretrained(...)``
        after ``.to(...)`` has been called with a CPU device.
    """
    # Pick the checkpoint identifier from the configuration flag.
    checkpoint_id = (
        "stabilityai/stable-diffusion-2-1-base"
        if config.sd_2_1
        else "CompVis/stable-diffusion-v1-4"
    )

    # The target device is fixed to the CPU in this loader.
    cpu = torch.device("cpu")

    # NOTE: an earlier, disabled variant of this line used
    # StableCountingPipeline.from_pretrained(...) instead of the stock pipeline.
    pipeline = StableDiffusionPipeline.from_pretrained(checkpoint_id)
    return pipeline.to(cpu)