# -*- coding: UTF-8 -*-
"""
Class-guided image generation with the Faice class-guidance model.

Loads ``CCDDPMPipeline`` from ``model_path`` once at import time and exposes
``inference_class_guidance``, a ``spaces.GPU``-decorated function that
generates one image for a gender class label.

@Time : 30/05/2025 19:24
@Author : xiaoguangliang
@File : class_guidance_inference.py
@Project : Faice_text2face
"""
import torch
import random
import numpy as np
from inference_models.ccddpm_pipeline import CCDDPMPipeline
from accelerate import Accelerator
import gradio as gr
import spaces
from loguru import logger

from utils import timer

model_path = 'Ngene787/Faice_class_guidance'

# Only request fp16 mixed precision when CUDA is available: Accelerate may
# refuse fp16 on CPU-only hosts. ``None`` is Accelerate's own default, which is
# what MPS hosts already received (no ``mixed_precision`` argument).
accelerator = Accelerator(
    mixed_precision="fp16" if torch.cuda.is_available() else None,
    gradient_accumulation_steps=1,
)

logger.info("Loading model ...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on CUDA; everything else runs in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = CCDDPMPipeline.from_pretrained(model_path,
                                      torch_dtype=torch_dtype,
                                      low_cpu_mem_usage=True
                                      )
pipe = pipe.to(device)
pipe = accelerator.prepare(pipe)

# Enable memory-efficient attention
# pipe.enable_xformers_memory_efficient_attention()

# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

# Labels accepted by ``inference_class_guidance``: "Male" maps to class id 1,
# any other value maps to class id 0.
GENDER_CHOICES = [
    "Female",
    "Male"
]


@spaces.GPU(duration=65)
def inference_class_guidance(label_name: str,
                             seed: int = 0,
                             randomize_seed: bool = False,
                             num_inference_steps: int = 20,
                             progress=gr.Progress(track_tqdm=True),
                             ):
    """Generate a single image from the class-conditional pipeline.

    Args:
        label_name: Class label. ``"Male"`` selects class id 1; any other
            value (e.g. ``"Female"``) selects class id 0.
        seed: Seed for the torch generator. Coerced to ``int`` because UI
            widgets may deliver it as a float, which
            ``torch.Generator.manual_seed`` does not accept.
        randomize_seed: When true, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]``; the seed actually used is logged.
        num_inference_steps: Number of denoising steps passed to the
            pipeline; coerced to ``int`` for the same reason as ``seed``.
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            it is not referenced in the body.

    Returns:
        The first element of the pipeline output's ``images`` (presumably a
        PIL image -- confirm against ``CCDDPMPipeline``).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    # A CPU generator is used regardless of ``device``; presumably the
    # pipeline moves sampled noise to the model device -- confirm.
    generator = torch.Generator().manual_seed(seed)
    label_id = 1 if label_name == "Male" else 0

    logger.info('Generating image ...')
    # The seed is not returned to the caller, so record it for reproducibility.
    logger.info("Using seed {}", seed)
    batch_size = 1
    with timer("inference"):
        class_labels = torch.full(
            (batch_size,), label_id, dtype=torch.long, device=device
        )
        # All-zero context of shape (batch, 1, cross_attention_dim); presumably
        # the model conditions only on ``class_labels`` -- confirm. Built in
        # the pipeline's dtype (float16 on CUDA) so cross-attention does not
        # see mixed float32/float16 operands.
        encoder_hidden_states = torch.zeros(
            batch_size,
            1,
            pipe.unet.config.cross_attention_dim,
            device=device,
            dtype=torch_dtype,
        )
        image = pipe(
            batch_size=batch_size,
            generator=generator,
            num_inference_steps=num_inference_steps,
            class_labels=class_labels,
            encoder_hidden_states=encoder_hidden_states,
        ).images[0]

    return image