| import gradio as gr |
| import numpy as np |
| import random |
| import torch |
| from champ_flame_model import ChampFlameModel |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from diffusers import AutoencoderKL, DDPMScheduler |
| from unet2dConditionFineTune import UNet2DConditionModel |
| from mutual_self_attention import ReferenceAttentionControl |
| from guidance_encoder import GuidanceEncoder |
| from pipeline_stable_diffusion import StableDiffusionPipeline |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Hugging Face Hub repository that holds every model component loaded below.
MODEL_PATH = "ValerianFourel/RealisticEmotionStableDiffusion"
# Prefer GPU when available; all models and generators are placed on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Largest 32-bit signed integer — upper bound for user-selectable seeds.
MAX_SEED = np.iinfo(np.int32).max
|
|
| |
def init_pipeline():
    """Build and return the emotion-conditioned Stable Diffusion pipeline.

    Loads the tokenizer, text encoder, VAE and UNet from ``MODEL_PATH``,
    loads one ``GuidanceEncoder`` checkpoint per guidance modality, wires a
    reference-attention writer around the UNet, and assembles the final
    ``StableDiffusionPipeline`` on ``DEVICE``.

    Returns:
        StableDiffusionPipeline: ready-to-call pipeline moved to ``DEVICE``.
    """
    print("Initializing pipeline...")

    tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATH, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(MODEL_PATH, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(MODEL_PATH, subfolder="vae")
    reference_unet = UNet2DConditionModel.from_pretrained(MODEL_PATH, subfolder="unet")

    guidance_encoder_group = {}
    guids = ["alignment", "depth", "flame"]

    for guidance_type in guids:
        encoder = GuidanceEncoder(
            guidance_embedding_channels=320,
            guidance_input_channels=3,
            block_out_channels=[16, 32, 96, 256],
        )

        # Download and load weights INSIDE the loop so every encoder gets
        # its own checkpoint, not just the final guidance_type.
        state_dict_path = hf_hub_download(
            repo_id=MODEL_PATH,
            filename=f"guidance_encoder/{guidance_type}_encoder_pytorch_model.bin",
            repo_type="model",
        )
        state_dict = torch.load(state_dict_path, map_location=DEVICE)

        # Strip a DistributedDataParallel "module." prefix if present.
        # Use startswith on the first key (a bare substring test would also
        # match "module." appearing mid-key), strip at most one occurrence
        # per key, and guard against an empty state dict.
        if state_dict and next(iter(state_dict)).startswith("module."):
            state_dict = {
                k.replace("module.", "", 1): v for k, v in state_dict.items()
            }

        encoder.load_state_dict(state_dict)
        encoder.to(DEVICE)
        encoder.eval()
        guidance_encoder_group[guidance_type] = encoder

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )

    # NOTE(review): the model handle is not used after construction here —
    # presumably ChampFlameModel/ReferenceAttentionControl register hooks on
    # reference_unet as a side effect; confirm before removing.
    model = ChampFlameModel(
        reference_unet,
        reference_control_writer,
        guidance_encoder_group,
    )

    # resume_download was removed: it is deprecated in huggingface_hub
    # (downloads always resume now) and only emitted a FutureWarning.
    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_PATH,
        text_encoder=text_encoder,
        vae=vae,
        unet=reference_unet,
        safety_checker=None,
        requires_safety_checker=False,
        custom_pipeline=None,
        use_safetensors=True,
        local_files_only=False,
    )

    pipeline = pipeline.to(DEVICE)

    return pipeline
|
|
| |
| pipe = init_pipeline() |
|
|
def clean_string(value):
    """Scrub quality-degrading words from a prompt.

    Drops "blurred", "grainy" and "blurry", upgrades "low-quality" to
    "high-quality", then collapses the doubled whitespace left behind by
    the removals (the original version left "a  photo"-style gaps).

    Args:
        value: raw prompt text.

    Returns:
        The cleaned prompt with single spaces and no leading/trailing blanks.
    """
    cleaned = (
        value.replace("blurred", "")
        .replace("grainy", "")
        .replace("blurry", "")
        .replace("low-quality", "high-quality")
    )
    # split() with no args splits on runs of any whitespace, so this both
    # trims the ends and collapses internal gaps in one pass.
    return " ".join(cleaned.split())
|
|
def infer(
    prompt,
    emotion,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for *prompt* conditioned on *emotion*.

    Args:
        prompt: user prompt text; cleaned via clean_string().
        emotion: emotion label prepended to the prompt.
        negative_prompt: negative prompt forwarded to the pipeline.
        seed: RNG seed; ignored when randomize_seed is True.
        randomize_seed: pick a fresh random seed in [0, MAX_SEED].
        guidance_scale: classifier-free guidance scale.
        num_inference_steps: number of denoising steps.
        progress: Gradio progress tracker (mirrors tqdm output in the UI).

    Returns:
        Tuple of (generated PIL image, seed actually used) so the UI can
        display the seed even when it was randomized.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio sliders may deliver the value as a float; manual_seed needs int.
    seed = int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    full_prompt = f"{emotion} , {clean_string(prompt)}"

    # Inference only — no gradients needed, saves memory on GPU.
    with torch.no_grad():
        image = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

    return image, seed
|
|
# Example (prompt, emotion) pairs for gr.Examples.  The emotion values must
# be exact members of the dropdown's choices list — the previous lowercase
# values ("happiness", "anger", "neutral") were not, so selecting an example
# put the dropdown into an invalid state.
examples = [
    ["A portrait of a young woman", "Happy"],
    ["A close-up of a man's face", "Anger"],
    ["A professional headshot", "Neutral"],
]
|
|
| emotions = ["Happy", "Sad", "Anger", "Fear", "Disgust", "Surprise", "Neutral","Contempt"] |
|
|
# Minimal page styling: center the main column and cap its width at 640px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
|
|
# UI layout: prompt row (text + emotion dropdown + button), result image,
# and an accordion of advanced sampling controls.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Realistic Emotion Stable Diffusion")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            emotion = gr.Dropdown(
                choices=emotions,
                label="Emotion",
                # FIX: the default must be a member of `choices`; the
                # previous lowercase "neutral" was not in the list.
                value="Neutral",
                container=False,
            )

            run_button = gr.Button("Generate", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                value="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=9.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=50,
                    maximum=500,
                    step=10,
                    value=300,
                )

        gr.Examples(examples=examples, inputs=[prompt, emotion])

    # Run inference on button click or when the prompt textbox is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            emotion,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
|
|
# Launch the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|