File size: 1,919 Bytes
aef244e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import spaces
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

from config import MODEL_ID, DEFAULT_GUIDANCE, DEFAULT_HEIGHT, DEFAULT_PROMPT, DEFAULT_WIDTH


def _load_pipeline() -> StableDiffusionPipeline:
    """Build and configure the Stable Diffusion pipeline for MODEL_ID.

    Loads the model in fp16 on CUDA (fp32 on CPU), swaps in the DPM-Solver
    multistep scheduler, disables the per-call progress bar, and opportunistically
    enables xformers memory-efficient attention.

    Returns:
        A ready-to-call StableDiffusionPipeline, moved to GPU when available.
    """
    # fp16 halves memory on GPU; CPU inference needs fp32.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )
    # DPM-Solver multistep gives good results in few inference steps.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
    pipe.set_progress_bar_config(disable=True)
    # xformers is an optional dependency: fall back to the default attention
    # implementation instead of crashing at startup when it is unavailable.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # Best-effort optimization only; the pipeline works without it.
        pass
    return pipe


# Build the shared pipeline once at import time so every GPU call reuses it.
pipe = _load_pipeline()


@spaces.GPU(duration=1500)
def compile_unet():
    """Ahead-of-time compile the pipeline's UNet with the ZeroGPU AoTI helpers.

    Runs one short warm-up generation under ``spaces.aoti_capture`` to record
    the concrete args/kwargs the UNet is invoked with, exports the UNet with
    ``torch.export.export`` using those inputs, and returns the compiled
    artifact produced by ``spaces.aoti_compile``.
    """
    warmup_kwargs = {
        "prompt": DEFAULT_PROMPT,
        "negative_prompt": None,
        "guidance_scale": DEFAULT_GUIDANCE,
        "num_inference_steps": 5,
        "width": DEFAULT_WIDTH,
        "height": DEFAULT_HEIGHT,
        "num_images_per_prompt": 1,
    }
    # Capture the exact tensor inputs the UNet sees during a real call.
    with spaces.aoti_capture(pipe.unet) as call:
        pipe(**warmup_kwargs)
    exported_unet = torch.export.export(pipe.unet, args=call.args, kwargs=call.kwargs)
    return spaces.aoti_compile(exported_unet)


# Compile the UNet once at startup and splice the AoT-compiled artifact back
# into the live pipeline; later generations run with the compiled UNet.
compiled_unet = compile_unet()
spaces.aoti_apply(compiled_unet, pipe.unet)


@spaces.GPU(duration=90)
def run_generation(
    prompt: str,
    negative_prompt: str | None,
    guidance_scale: float,
    num_inference_steps: int,
    width: int,
    height: int,
    num_images: int,
    generator: torch.Generator,
):
    """Run one text-to-image generation on the shared pipeline.

    All arguments are forwarded to the pipeline unchanged; ``num_images`` maps
    to ``num_images_per_prompt`` and ``generator`` seeds the sampler.

    Returns:
        The list of generated PIL images from the pipeline output.
    """
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "num_images_per_prompt": num_images,
        "generator": generator,
    }
    return pipe(**call_kwargs).images