File size: 2,004 Bytes
4d16190
 
 
 
 
 
 
 
da03eea
 
4d16190
2e3190d
da03eea
be0b220
793b863
4d16190
2fc94f8
 
4d16190
 
592a5b4
 
 
 
 
793b863
0559bb0
 
 
 
 
131320a
73c9ba5
50e67cf
 
da03eea
2e3190d
6994397
da03eea
0de7f44
2e3190d
4d16190
da03eea
 
 
0559bb0
ae53881
 
 
 
 
 
 
da03eea
 
 
 
793b863
2fc94f8
 
 
 
 
 
 
 
 
ea6c06b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# -*- coding: UTF-8 -*-
"""
@Time : 28/05/2025 11:56
@Author : xiaoguangliang
@File : stable_diffusion_inference.py
@Project : Faice_text2face
"""
import torch
import random
import numpy as np
from diffusers import StableDiffusionPipeline
from accelerate import Accelerator
import gradio as gr
import spaces
from loguru import logger

from utils import timer

model_path = 'Ngene787/Faice_text2face'

# Decide CUDA availability once so the Accelerator precision, the model dtype
# and the device placement below always agree with each other.
use_cuda = torch.cuda.is_available()

# fp16 mixed precision is only requested when CUDA is present: accelerate
# rejects fp16 on a CPU-only machine. Without CUDA (e.g. on MPS, which the
# original setup already ran without fp16) fall back to accelerate's default.
accelerator = Accelerator(mixed_precision="fp16" if use_cuda else None,
                          gradient_accumulation_steps=1)

logger.info("Loading model ...")
device = "cuda" if use_cuda else "cpu"
# Half precision roughly halves GPU memory; float16 is poorly supported on CPU,
# so keep float32 there.
torch_dtype = torch.float16 if use_cuda else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch_dtype,
                                               low_cpu_mem_usage=True,
                                               # requires_safety_checker=False
                                               )
pipe = pipe.to(device)

# NOTE(review): a DiffusionPipeline is not an nn.Module/optimizer/dataloader,
# so accelerate most likely returns it unchanged here — confirm before relying
# on accelerator features (e.g. autocast) for the pipeline.
pipe = accelerator.prepare(pipe)
# Enable memory-efficient attention (requires xformers to be installed)
# pipe.enable_xformers_memory_efficient_attention()


# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max


@spaces.GPU(duration=65)
def inference_sd(prompt,
                 negative_prompt="",
                 seed=0,
                 randomize_seed=False,
                 guidance_scale=7.5,
                 num_inference_steps=20,
                 progress=gr.Progress(track_tqdm=True), ):
    """Generate one image from a text prompt with the loaded Stable Diffusion pipeline.

    Decorated with ``spaces.GPU(duration=65)`` so that, on Hugging Face
    ZeroGPU hardware, a GPU is attached for the duration of the call.

    Args:
        prompt: Text description of the image to generate.
        negative_prompt: Text describing content the image should avoid.
        seed: Seed for the random generator; replaced when ``randomize_seed``
            is true. Numeric values (including floats from UI widgets) are
            converted to ``int``.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Classifier-free guidance strength passed to the pipeline.
        num_inference_steps: Number of denoising steps; converted to ``int``.
        progress: Gradio progress tracker. It is not used directly;
            ``track_tqdm=True`` lets Gradio mirror the pipeline's tqdm bar.

    Returns:
        The first image of the pipeline output (a PIL image with the
        pipeline's default output type).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI widgets may deliver numbers as floats; torch.Generator.manual_seed and
    # the scheduler's timestep setup both require integers.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    # A CPU generator gives the same result for a given seed regardless of the
    # device the pipeline runs on.
    generator = torch.Generator().manual_seed(seed)

    logger.info('Generating image ...')
    # Record the seed so randomly seeded outputs can be reproduced later.
    logger.info("Seed: {}", seed)
    with timer("inference"):
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            eta=0.0,  # only used by DDIM-style schedulers; ignored by others
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    return image