File size: 598 Bytes
8f72b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from config import RunConfig
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
import torch.nn as nn

def load_stable_diffusion_model(config: RunConfig, *, device: "torch.device | str | None" = None):
    """Load a pretrained Stable Diffusion pipeline and move it to ``device``.

    Args:
        config: Run configuration. Only ``config.sd_2_1`` is read here:
            - truthy: loads ``"stabilityai/stable-diffusion-2-1-base"``;
            - falsy: loads ``"CompVis/stable-diffusion-v1-4"``.
        device: Target device for the pipeline, as a ``torch.device`` or a
            device string such as ``"cuda"``. Keyword-only. Defaults to CPU,
            which preserves the original hard-coded behavior for existing
            callers.

    Returns:
        The ``StableDiffusionPipeline`` produced by ``from_pretrained``,
        after ``.to(device)``.
    """
    if device is None:
        # Keep the historical default so callers that pass only `config`
        # behave exactly as before.
        device = torch.device("cpu")

    model_id = (
        "stabilityai/stable-diffusion-2-1-base"
        if config.sd_2_1
        else "CompVis/stable-diffusion-v1-4"
    )
    return StableDiffusionPipeline.from_pretrained(model_id).to(device)