File size: 1,782 Bytes
ae53881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9647f84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: UTF-8 -*-
"""
@Time : 30/05/2025 19:24
@Author : xiaoguangliang
@File : unconditional_diffusion_inference.py
@Project : Faice_text2face
"""
import torch
import random
import numpy as np
from diffusers import DDPMPipeline
from accelerate import Accelerator
import gradio as gr
import spaces
import PIL.Image
from loguru import logger

from utils import timer

model_path = 'Ngene787/Faice_unconditional_diffusion'

# Decide once whether a CUDA GPU is present. The mixed-precision mode, the
# target device and the weight dtype below are all derived from this single
# flag so they cannot disagree. (Previously mixed precision was keyed on MPS,
# so a CPU-only host without MPS requested fp16 from Accelerate while the
# pipeline itself was loaded as float32 on CPU.)
_use_cuda = torch.cuda.is_available()

if _use_cuda:
    accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)
else:
    accelerator = Accelerator(gradient_accumulation_steps=1)

logger.info("Loading model ...")
device = "cuda" if _use_cuda else "cpu"
# Half-precision weights only when running on CUDA; float32 everywhere else.
torch_dtype = torch.float16 if _use_cuda else torch.float32

pipe = DDPMPipeline.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)
pipe = pipe.to(device)

# NOTE(review): a DDPMPipeline is not an nn.Module / optimizer / dataloader,
# so Accelerator.prepare presumably hands it back unchanged — confirm whether
# this call is needed at all.
pipe = accelerator.prepare(pipe)
# Optional (needs the xformers package): memory-efficient attention.
# pipe.enable_xformers_memory_efficient_attention()


# Upper bound for user-facing / randomly drawn seeds (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max


@spaces.GPU(duration=65)
def inference_unconditional(seed,
                            randomize_seed=False,
                            num_inference_steps=20,
                            progress=gr.Progress(track_tqdm=True)):
    """Generate a single image with the unconditional DDPM pipeline.

    Args:
        seed: Seed for the torch random generator. Ignored when
            ``randomize_seed`` is true. Cast to ``int`` because
            ``torch.Generator.manual_seed`` rejects floats, and UI widgets may
            deliver numbers as floats.
        randomize_seed: If true, draw a fresh seed uniformly from
            ``[0, MAX_SEED]`` (inclusive) instead of using ``seed``.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker. It is never referenced in the body;
            ``track_tqdm=True`` lets Gradio follow the pipeline's tqdm bar.

    Returns:
        The first (and, with ``batch_size=1``, only) image from the pipeline
        output, requested with ``output_type="np"``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Normalise numeric UI inputs: manual_seed requires an integer.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    # CPU generator: the pipeline draws its noise from it whatever the
    # pipeline device is.
    generator = torch.Generator().manual_seed(seed)

    # Log the effective seed so a randomized result can be reproduced later.
    logger.info("Seed: {}", seed)
    logger.info('Generating image ...')
    with timer("inference"):
        output = pipe(
            batch_size=1,
            generator=generator,
            num_inference_steps=num_inference_steps,
            output_type="np",
        )
    return output.images[0]