# Repository header (was raw scrape residue, breaks the module if left bare):
# author ValerianFourel, commit 4e3eb3d ("add Majuscule")
import gradio as gr
import numpy as np
import random
import torch
from champ_flame_model import ChampFlameModel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler
from unet2dConditionFineTune import UNet2DConditionModel
from mutual_self_attention import ReferenceAttentionControl
from guidance_encoder import GuidanceEncoder
from pipeline_stable_diffusion import StableDiffusionPipeline
from huggingface_hub import hf_hub_download
# Global constants
MODEL_PATH = "ValerianFourel/RealisticEmotionStableDiffusion"  # Hugging Face Hub repo id
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / randomizer
# Initialize pipeline once at startup
def init_pipeline():
    """Build and return the Stable Diffusion pipeline on DEVICE.

    Loads the tokenizer, text encoder, VAE and reference UNet from
    MODEL_PATH, restores one GuidanceEncoder per guidance modality from a
    checkpoint downloaded off the Hub, wires up the reference attention
    control, and assembles the final pipeline.

    Returns:
        StableDiffusionPipeline: the fully-initialized pipeline, moved to DEVICE.
    """
    print("Initializing pipeline...")
    tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATH, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(MODEL_PATH, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(MODEL_PATH, subfolder="vae")
    reference_unet = UNet2DConditionModel.from_pretrained(MODEL_PATH, subfolder="unet")

    # One guidance encoder per conditioning signal, each with its own
    # checkpoint file on the Hub.
    guidance_encoder_group = {}
    for guidance_type in ("alignment", "depth", "flame"):
        encoder = GuidanceEncoder(
            guidance_embedding_channels=320,
            guidance_input_channels=3,
            block_out_channels=[16, 32, 96, 256],
        )
        state_dict_path = hf_hub_download(
            repo_id=MODEL_PATH,
            filename=f"guidance_encoder/{guidance_type}_encoder_pytorch_model.bin",
            repo_type="model",
        )
        state_dict = torch.load(state_dict_path, map_location=DEVICE)
        # Strip the DistributedDataParallel "module." prefix if present.
        # BUG FIX: the original indexed list(keys())[0] (crashes on an empty
        # dict) and used a substring test, which matches "module." anywhere
        # in the key instead of only as a leading prefix.
        if state_dict and next(iter(state_dict)).startswith("module."):
            state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
        encoder.load_state_dict(state_dict)
        encoder.to(DEVICE)
        encoder.eval()
        guidance_encoder_group[guidance_type] = encoder

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    # NOTE(review): `model` is never referenced afterwards, but constructing
    # ChampFlameModel / ReferenceAttentionControl may hook or mutate
    # reference_unet in place, so the construction is kept — confirm before
    # removing.
    model = ChampFlameModel(
        reference_unet,
        reference_control_writer,
        guidance_encoder_group,
    )

    # Load the pipeline with the components prepared above.
    # BUG FIX: the tokenizer loaded above was previously never passed in,
    # forcing from_pretrained to re-load its own copy.
    pipeline = StableDiffusionPipeline.from_pretrained(
        MODEL_PATH,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=reference_unet,
        safety_checker=None,  # Optional: disable safety checker if not needed
        requires_safety_checker=False,
        custom_pipeline=None,
        use_safetensors=True,  # Support safetensors checkpoint format
        local_files_only=False,
        resume_download=True,  # Resume interrupted downloads (deprecated upstream; harmless)
    )
    # Move to device
    return pipeline.to(DEVICE)
# Initialize the pipeline once at module import so every request reuses it
# instead of rebuilding the models per call.
pipe = init_pipeline()
def clean_string(value):
    """Sanitize a user prompt before generation.

    Drops the quality-degrading words "blurred", "grainy" and "blurry",
    rewrites "low-quality" to "high-quality", and collapses runs of
    whitespace.

    BUG FIX: the original only stripped the ends, so interior removals left
    doubled spaces (e.g. "a blurry cat" -> "a  cat"); whitespace is now
    normalized throughout.

    Args:
        value: the raw prompt string.

    Returns:
        The cleaned prompt string.
    """
    cleaned = (
        value.replace("blurred", "")
        .replace("grainy", "")
        .replace("blurry", "")
        .replace("low-quality", "high-quality")
    )
    # split()/join collapses any run of whitespace and trims both ends.
    return " ".join(cleaned.split())
def infer(
    prompt,
    emotion,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for the given prompt and emotion.

    Returns:
        tuple: (PIL image, seed actually used) — the seed is returned so the
        UI can display it when it was randomized.
    """
    # Replace the slider seed with a fresh one when requested.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    rng = torch.Generator(device=DEVICE).manual_seed(seed)

    # Prefix the sanitized prompt with the selected emotion keyword.
    full_prompt = f"{emotion} , {clean_string(prompt)}"

    with torch.no_grad():
        result = pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=rng,
        )
    return result.images[0], seed
# Example (prompt, emotion) pairs pre-filled into the UI.
# BUG FIX: the example emotions previously used "happiness", "anger" and
# "neutral", none of which are members of the dropdown's `choices`, so
# clicking an example fed an invalid value into the Dropdown component.
examples = [
    ["A portrait of a young woman", "Happy"],
    ["A close-up of a man's face", "Anger"],
    ["A professional headshot", "Neutral"],
]
# Emotion labels offered by the dropdown (capitalized, per "add Majuscule").
emotions = ["Happy", "Sad", "Anger", "Fear", "Disgust", "Surprise", "Neutral", "Contempt"]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# UI definition. Widgets feed `infer` via the gr.on() wiring at the bottom.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Realistic Emotion Stable Diffusion")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            emotion = gr.Dropdown(
                choices=emotions,
                label="Emotion",
                # BUG FIX: was "neutral", which is not in `choices`
                # (all choices are capitalized), so the dropdown started
                # with an invalid value.
                value="Neutral",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                value="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=9.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=50,
                    maximum=500,
                    step=10,
                    value=300,
                )
        gr.Examples(examples=examples, inputs=[prompt, emotion])
    # Trigger generation on button click or pressing Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            emotion,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()