import random

import gradio as gr
import numpy as np
import torch
from champ_flame_model import ChampFlameModel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from unet2dConditionFineTune import UNet2DConditionModel
from mutual_self_attention import ReferenceAttentionControl
from guidance_encoder import GuidanceEncoder
from pipeline_stable_diffusion import StableDiffusionPipeline
from huggingface_hub import hf_hub_download

# Global constants
MODEL_PATH = "ValerianFourel/RealisticEmotionStableDiffusion"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max


# Initialize the pipeline once at startup
def init_pipeline():
    print("Initializing pipeline...")
    tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATH, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(MODEL_PATH, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(MODEL_PATH, subfolder="vae")
    reference_unet = UNet2DConditionModel.from_pretrained(MODEL_PATH, subfolder="unet")

    # Initialize one guidance encoder per guidance modality
    guidance_encoder_group = {}
    guids = ["alignment", "depth", "flame"]
    for guidance_type in guids:
        guidance_encoder_group[guidance_type] = GuidanceEncoder(
            guidance_embedding_channels=320,
            guidance_input_channels=3,
            block_out_channels=[16, 32, 96, 256],
        )
        # Download the encoder weights from the Hub using hf_hub_download
        state_dict_path = hf_hub_download(
            repo_id=MODEL_PATH,
            filename=f"guidance_encoder/{guidance_type}_encoder_pytorch_model.bin",
            repo_type="model",
        )
        state_dict = torch.load(state_dict_path, map_location=DEVICE)
        # Strip the "module." prefix left over from DataParallel checkpoints
        if next(iter(state_dict)).startswith("module."):
            state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
        guidance_encoder_group[guidance_type].load_state_dict(state_dict)
        guidance_encoder_group[guidance_type].to(DEVICE)
        guidance_encoder_group[guidance_type].eval()

    # Registering the writer hooks the reference attention onto reference_unet,
    # which is the UNet the pipeline below uses
    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    model = ChampFlameModel(
        reference_unet,
        reference_control_writer,
        guidance_encoder_group,
    )

    # Assemble the pipeline from the components loaded above
    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_PATH,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=reference_unet,
        safety_checker=None,  # Optional: disable the safety checker if not needed
        requires_safety_checker=False,
        use_safetensors=True,  # Support the safetensors weight format
        local_files_only=False,
    )

    # Move everything to the target device
    return pipeline.to(DEVICE)


# Initialize the pipeline globally
pipe = init_pipeline()


def clean_string(value):
    """Strip quality-degrading words from the prompt and prefer high quality."""
    return (
        value.replace("blurred", "")
        .replace("grainy", "")
        .replace("blurry", "")
        .replace("low-quality", "high-quality")
        .strip()
    )


def infer(
    prompt,
    emotion,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Prepend the selected emotion to the cleaned prompt
    full_prompt = f"{emotion}, {clean_string(prompt)}"

    with torch.no_grad():
        image = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

    return image, seed


# Example emotions must match the dropdown choices below
examples = [
    ["A portrait of a young woman", "Happy"],
    ["A close-up of a man's face", "Anger"],
    ["A professional headshot", "Neutral"],
]

emotions = [
    "Happy",
    "Sad",
    "Anger",
    "Fear",
    "Disgust",
    "Surprise",
    "Neutral",
    "Contempt",
]
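# A minimal sketch of exercising `infer` directly (outside the Gradio UI),
# e.g. as a smoke test after the pipeline loads. The prompt, step count, and
# output filename here are illustrative assumptions, not values from the app:
#
#     image, used_seed = infer(
#         prompt="A portrait of a young woman",
#         emotion="Happy",
#         negative_prompt="",
#         seed=0,
#         randomize_seed=False,
#         guidance_scale=9.0,
#         num_inference_steps=50,
#     )
#     image.save("smoke_test.png")  # hypothetical output path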
"Anger", "Fear", "Disgust", "Surprise", "Neutral","Contempt"] css = """ #col-container { margin: 0 auto; max-width: 640px; } """ with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown("# Realistic Emotion Stable Diffusion") with gr.Row(): prompt = gr.Text( label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, ) emotion = gr.Dropdown( choices=emotions, label="Emotion", value="neutral", container=False, ) run_button = gr.Button("Generate", scale=0, variant="primary") result = gr.Image(label="Result", show_label=False) with gr.Accordion("Advanced Settings", open=False): negative_prompt = gr.Text( label="Negative prompt", max_lines=1, value="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", ) seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) with gr.Row(): guidance_scale = gr.Slider( label="Guidance scale", minimum=1.0, maximum=20.0, step=0.5, value=9.0, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=50, maximum=500, step=10, value=300, ) gr.Examples(examples=examples, inputs=[prompt, emotion]) gr.on( triggers=[run_button.click, prompt.submit], fn=infer, inputs=[ prompt, emotion, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, ], outputs=[result, seed], ) if __name__ == "__main__": demo.launch()