# Faice_text2face / inference_models / stable_diffusion_inference.py
# Source: Ngene787's repository — commit 4d24ceb ("refactor: change structure")
# -*- coding: UTF-8 -*-
"""
@Time : 28/05/2025 11:56
@Author : xiaoguangliang
@File : stable_diffusion_inference.py
@Project : Faice_text2face
"""
import torch
import random
import numpy as np
from diffusers import StableDiffusionPipeline
from accelerate import Accelerator
import gradio as gr
import spaces
from loguru import logger
from utils import timer
# Hugging Face model repo containing the fine-tuned text-to-face weights.
model_path = 'Ngene787/Faice_text2face'

# fp16 mixed precision is not supported on Apple MPS, so only request it
# when MPS is unavailable.
if torch.backends.mps.is_available():
    accelerator = Accelerator(gradient_accumulation_steps=1)
else:
    accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)

logger.info("Loading model ...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves VRAM usage on CUDA; CPU inference requires fp32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    # requires_safety_checker=False
)
pipe = pipe.to(device)
pipe = accelerator.prepare(pipe)

# Enable memory-efficient attention
# pipe.enable_xformers_memory_efficient_attention()

# Upper bound for randomized seeds (fits in a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
@spaces.GPU(duration=65)
def inference_sd(prompt,
                 negative_prompt="",
                 seed=0,
                 randomize_seed=False,
                 guidance_scale=7.5,
                 num_inference_steps=20,
                 progress=gr.Progress(track_tqdm=True), ):
    """Generate one face image from a text prompt with the Stable Diffusion pipe.

    Args:
        prompt: Text description of the desired face.
        negative_prompt: Concepts the model should avoid.
        seed: RNG seed for reproducible generation.
        randomize_seed: If True, replace ``seed`` with a random value
            in [0, MAX_SEED].
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker mirroring the pipeline's tqdm bar.

    Returns:
        The first generated image (``pipe(...).images[0]``).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # CPU-side generator: seeding behaviour is identical regardless of the
    # device the pipeline runs on.
    generator = torch.Generator().manual_seed(seed)

    logger.info('Generating image ...')
    with timer("inference"):
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            eta=0.0,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    return image