Faice_text2face / inference_models /class_guidance_inference.py
Ngene787's picture
refactor: change structure
4d24ceb
# -*- coding: UTF-8 -*-
"""
@Time : 30/05/2025 19:24
@Author : xiaoguangliang
@File : class_guidance_inference.py
@Project : Faice_text2face
"""
import torch
import random
import numpy as np
from inference_models.ccddpm_pipeline import CCDDPMPipeline
from accelerate import Accelerator
import gradio as gr
import spaces
from loguru import logger
from utils import timer
# Identifier handed to CCDDPMPipeline.from_pretrained below.
model_path = 'Ngene787/Faice_class_guidance'

# fp16 mixed precision is requested unless the MPS backend is available.
accelerator = (
    Accelerator(gradient_accumulation_steps=1)
    if torch.backends.mps.is_available()
    else Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)
)

logger.info("Loading model ...")

# Half precision when CUDA is available, full precision on CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the pipeline, move it to the chosen device, then pass it through
# the accelerator.
pipe = accelerator.prepare(
    CCDDPMPipeline.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    ).to(device)
)

# Memory-efficient attention (xformers) is currently left disabled:
# pipe.enable_xformers_memory_efficient_attention()
# Upper bound (inclusive) for randomly drawn seeds: the largest signed
# 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Label names offered for class guidance; the inference function maps
# "Male" to class id 1 and everything else to class id 0.
GENDER_CHOICES = ["Female", "Male"]
@spaces.GPU(duration=65)
def inference_class_guidance(label_name,
                             seed=0,
                             randomize_seed=False,
                             num_inference_steps=20,
                             progress=gr.Progress(track_tqdm=True), ):
    """Generate a single image conditioned on a gender class label.

    Args:
        label_name: Class to generate. "Male" selects class id 1; any
            other value selects class id 0.
        seed: Seed for the torch random generator. Replaced by a random
            value when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        num_inference_steps: Number of inference steps passed to the
            pipeline.
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not referenced in the body; it is declared so the UI can
            pick it up.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # UI widgets may deliver numbers as floats (e.g. 42.0);
    # torch.Generator.manual_seed requires a real int, and the step count
    # is coerced the same way.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)
    label_id = 1 if label_name == "Male" else 0

    logger.info('Generating image ...')
    batch_size = 1
    with timer("inference"):
        class_labels = torch.full(
            (batch_size,), label_id, dtype=torch.long, device=device
        )
        # All-zero conditioning tensor. It is created in the dtype the
        # pipeline was loaded with (float16 on CUDA) so it does not clash
        # with half-precision weights; torch.zeros would otherwise default
        # to float32.
        encoder_hidden_states = torch.zeros(
            batch_size,
            1,
            pipe.unet.config.cross_attention_dim,
            dtype=torch_dtype,
            device=device,
        )
        image = pipe(
            batch_size=batch_size,
            generator=generator,
            num_inference_steps=num_inference_steps,
            class_labels=class_labels,
            encoder_hidden_states=encoder_hidden_states,
        ).images[0]
    return image