character-lab / app.py
harumaa's picture
Update app.py
a3a2469 verified
# -*- coding: utf-8 -*-
import os
import random
import torch
import numpy as np
import gradio as gr
import spaces
import threading
from PIL import Image
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import hf_hub_download
from transformers import CLIPVisionModelWithProjection
from compel import Compel, ReturnedEmbeddingsType
# --- Configuration ---
# Folder that receives downloaded checkpoints; created as soon as the module loads.
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Hub access token read from the environment (empty string when unset).
HF_TOKEN = os.environ.get("HF_TOKEN", "")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Accent colour handed to build_css() when the UI is constructed.
DEFAULT_ACCENT = "#7c3aed"

# Upper bounds for the seed slider and for the width/height sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1920

# Per-model checkpoint file name plus the CFG / step values that the model
# buttons push into the sliders.
CIVITAI_MODELS = {
    "waiNSFW Illustrious": {
        "filename": "waiIllustriousSDXL_v160.safetensors",
        "cfg": 6.0,
        "steps": 25,
        "version": "v16.0",
    },
    "wai Realmix": {
        "filename": "waiREALMIX_v11.safetensors",
        "cfg": 7.0,
        "steps": 30,
        "version": "v11",
    },
}
MODEL_NAMES = list(CIVITAI_MODELS)

# Resolution presets offered by the quick-pick radio: label -> (width, height).
SDXL_RATIOS = {
    "1:1 Square": (1024, 1024),
    "3:4 Portrait": (896, 1152),
    "4:3 Landscape": (1152, 896),
    "9:16 Mobile": (768, 1344),
    "16:9 Desktop": (1344, 768),
}

# Links rendered in the banner at the top of the page; two can be overridden
# through environment variables.
ANIME_STUDIO_URL = "https://huggingface.co/spaces/harumaa/Anime-XL-Studio-And-Qwen-Edit"
QWEN_SPACE_URL = os.environ.get(
    "QWEN_SPACE_URL",
    "https://huggingface.co/spaces/harumaa/Qwen-Image-Editor-SFW-NSFW",
)
UPSCALER_URL = os.environ.get(
    "UPSCALER_URL",
    "https://huggingface.co/spaces/harumaa/AI-Image-Upscaler-4x",
)

# model name -> ready pipeline; filled by download_and_load_all(), read by generate_gpu().
_char_cache = {}
# --- Helper Functions ---
def fix_vae_dtype(vae, dtype=torch.float16):
    """Cast *vae* to *dtype* and clear its ``force_upcast`` config flag.

    Args:
        vae: Object exposing ``.to(dtype=...)`` and a ``.config`` attribute.
        dtype: Target dtype passed to ``.to``; defaults to ``torch.float16``.

    Returns:
        Whatever ``vae.to(dtype=dtype)`` returned, with
        ``config.force_upcast`` set to ``False`` on it.
    """
    converted = vae.to(dtype=dtype)
    # Presumably keeps the pipeline from upcasting the VAE again at decode time — TODO confirm.
    converted.config.force_upcast = False
    return converted
def auto_res(input_image):
    """Suggest a generation width/height from the uploaded reference image.

    The suggestion keeps the image's aspect ratio: when the longer side
    exceeds ``MAX_IMAGE_SIZE`` both sides are shrunk by the same factor
    (previously each side was clamped on its own, so e.g. 4000x2000 became a
    square 1920x1920). Each side is then snapped down to a multiple of 64 and
    floored at 512.

    Args:
        input_image: A ``PIL.Image.Image``, anything ``Image.open`` accepts,
            or ``None`` when the image was cleared.

    Returns:
        A 4-tuple of ``gr.update`` objects: width, height, status text, and a
        ``value=None`` update for the fourth wired output.
    """
    if input_image is None:
        return (
            gr.update(value=1024),
            gr.update(value=1024),
            gr.update(value="Resolution auto-detect ready"),
            gr.update(value=None),
        )

    try:
        if isinstance(input_image, Image.Image):
            src_w, src_h = input_image.width, input_image.height
        else:
            # Image.open keeps the file handle open; the context manager releases it.
            with Image.open(input_image) as opened:
                src_w, src_h = opened.width, opened.height
    except Exception:
        # UI boundary: any unreadable input is reported as a status message
        # instead of surfacing as a handler error.
        return (
            gr.update(value=1024),
            gr.update(value=1024),
            gr.update(value="Error reading image"),
            gr.update(value=None),
        )

    longest = max(src_w, src_h)
    if longest > MAX_IMAGE_SIZE:
        # Integer arithmetic so the longer side lands exactly on MAX_IMAGE_SIZE
        # (a float scale factor could round it one 64-px step short).
        src_w = src_w * MAX_IMAGE_SIZE // longest
        src_h = src_h * MAX_IMAGE_SIZE // longest

    w = max(512, src_w // 64 * 64)
    h = max(512, src_h // 64 * 64)
    return (
        gr.update(value=w),
        gr.update(value=h),
        gr.update(value=f"Auto-detected from image: {w} x {h}"),
        gr.update(value=None),
    )
def set_quick_ratio(ratio_label):
    """Apply a resolution preset chosen in the quick-pick radio.

    Args:
        ratio_label: A key of ``SDXL_RATIOS``, or ``None`` / an unknown label
            when the radio was cleared.

    Returns:
        A 3-tuple of ``gr.update`` objects for width, height and status text.
        For a cleared or unknown label all three are empty ``gr.update()``
        no-ops.
    """
    preset = SDXL_RATIOS.get(ratio_label)
    if preset is None:
        # The radio's change event also fires when another handler resets it
        # to None (auto_res does so after an upload). Falling back to
        # 1024x1024 here would overwrite the size that handler just set, so
        # leave the sliders and status untouched.
        return gr.update(), gr.update(), gr.update()

    w, h = preset
    return (
        gr.update(value=w),
        gr.update(value=h),
        gr.update(value=f"Resolution set by preset: {w} x {h}"),
    )
def encode_prompts(pipe, prompt, neg):
    """Turn the positive and negative prompt into embeddings for *pipe*.

    Each of the pipeline's two tokenizer/text-encoder pairs gets its own
    ``Compel`` instance; only the second is asked for pooled output. The
    resulting tensors are fitted to ``pipe.unet.config.cross_attention_dim``.

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``tokenizer_2``,
            ``text_encoder``, ``text_encoder_2`` and ``unet``.
        prompt: Positive prompt text.
        neg: Negative prompt text.

    Returns:
        ``(cond, pooled, neg_cond, neg_pooled)``.
    """
    embed_kind = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
    first = Compel(
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        returned_embeddings_type=embed_kind,
        requires_pooled=False,
    )
    second = Compel(
        tokenizer=pipe.tokenizer_2,
        text_encoder=pipe.text_encoder_2,
        returned_embeddings_type=embed_kind,
        requires_pooled=True,
    )

    pos_a = first(prompt)
    pos_b, pos_pooled = second(prompt)
    neg_a = first(neg)
    neg_b, neg_pooled = second(neg)

    # Positive and negative embeddings must share a sequence length per encoder.
    pos_a, neg_a = first.pad_conditioning_tensors_to_same_length([pos_a, neg_a])
    pos_b, neg_b = second.pad_conditioning_tensors_to_same_length([pos_b, neg_b])

    target_dim = pipe.unet.config.cross_attention_dim

    # The second encoder alone already has the width the UNet expects.
    if pos_b.shape[-1] == target_dim:
        return pos_b, pos_pooled, neg_b, neg_pooled

    # Otherwise join both encoders along the feature axis.
    pos_joined = torch.cat([pos_a, pos_b], dim=-1)
    neg_joined = torch.cat([neg_a, neg_b], dim=-1)
    joined_dim = pos_joined.shape[-1]

    if joined_dim == target_dim:
        return pos_joined, pos_pooled, neg_joined, neg_pooled

    if joined_dim > target_dim:
        # Too wide: keep only the leading target_dim features.
        return pos_joined[..., :target_dim], pos_pooled, neg_joined[..., :target_dim], neg_pooled

    # Too narrow: zero-pad the feature axis on the right.
    shortfall = target_dim - joined_dim
    padded_pos = torch.nn.functional.pad(pos_joined, (0, shortfall))
    padded_neg = torch.nn.functional.pad(neg_joined, (0, shortfall))
    return (padded_pos, pos_pooled, padded_neg, neg_pooled)
# --- Preloading Logic ---
def download_and_load_all():
    """Download every configured checkpoint and build its pipeline into ``_char_cache``.

    Loads the shared image encoder once, then for each entry of
    ``CIVITAI_MODELS`` downloads the checkpoint into ``MODELS_DIR`` if it is
    not already there, builds an SDXL pipeline from it, swaps in an
    Euler-ancestral scheduler, attaches the IP-Adapter Plus weights and stores
    the pipeline under the model's name.

    A failure for one model is logged and does not stop the others. The final
    log line says whether everything loaded or names the models that failed.
    Returns ``None``; results are published through ``_char_cache``.
    """
    print("🚀 [Startup] Initializing Character Studio pre-load...")
    print("⚙️ [Pipeline] Loading ViT-H Image Encoder...")
    try:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=torch.float16,
        )
    except Exception as e:
        # Every pipeline needs the encoder, so nothing can be built without it.
        # Report it like the per-model failures instead of letting the
        # background thread die with an unhandled exception.
        print(f"❌ [Error] Failed to load image encoder: {e}")
        return

    failed = []
    for model_name, info in CIVITAI_MODELS.items():
        try:
            dest = MODELS_DIR / info["filename"]
            if not dest.exists():
                hf_hub_download(
                    repo_id="harumaa/private-models",
                    filename=info["filename"],
                    local_dir=str(MODELS_DIR),
                    token=HF_TOKEN or None,
                )
            if model_name not in _char_cache:
                print(f"⚙️ [Pipeline] Building {model_name} with IP-Adapter Plus...")
                pipe = StableDiffusionXLPipeline.from_single_file(
                    str(dest),
                    torch_dtype=torch.float16,
                    use_safetensors=True,
                    image_encoder=image_encoder,
                )
                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
                # Loading the High-Detail PLUS model
                pipe.load_ip_adapter(
                    "h94/IP-Adapter",
                    subfolder="sdxl_models",
                    weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
                )
                # Publish only after the pipeline is fully built, so readers
                # of the cache never see a half-initialised entry.
                _char_cache[model_name] = pipe
        except Exception as e:
            failed.append(model_name)
            print(f"❌ [Error] Failed to load {model_name}: {e}")

    if failed:
        print(f"⚠️ [Startup] Pre-load finished, but these models failed: {', '.join(failed)}")
    else:
        print("✅ [Startup] All models pre-downloaded and pipelines ready.")
# --- Inference ---
@spaces.GPU(duration=60)
def generate_gpu(model_name, ref_img, prompt, neg, seed, randomize, w, h, ip_scale, cfg, steps, progress=gr.Progress(track_tqdm=False)):
    """Generate one image guided by a reference picture, streaming UI updates.

    This is a generator. Every ``yield`` is a 4-tuple
    ``(image, seed, status_update, download_button_update)``.

    Args:
        model_name: Key into ``_char_cache`` selecting the pipeline.
        ref_img: Reference image passed to the pipeline as ``ip_adapter_image``;
            ``None`` aborts with an error status.
        prompt: Positive prompt; quality tags are appended unless the text
            already contains "masterpiece".
        neg: Extra negative prompt, prepended to the built-in negative list.
        seed: Seed used when ``randomize`` is false.
        randomize: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        w, h: Output width and height.
        ip_scale: IP-Adapter scale.
        cfg: Guidance scale.
        steps: Number of inference steps.
        progress: Gradio progress tracker, advanced once per step.

    Yields:
        Intermediate status tuples, then either the finished image together
        with a download button pointing at a saved PNG, or an error status.
    """
    import tempfile
    import traceback

    if ref_img is None:
        yield None, seed, gr.update(value="⚠️ Error: Please upload a Character Reference image!"), gr.update(visible=False)
        return

    # torch's manual_seed and num_inference_steps need real ints; cast
    # defensively in case the slider values arrive as floats.
    seed = random.randint(0, MAX_SEED) if randomize else int(seed)
    steps = int(steps)
    yield None, seed, gr.update(value="Waiting for GPU..."), gr.update(visible=False, value=None)

    pipe = _char_cache.get(model_name)
    if pipe is None:
        yield None, seed, gr.update(value=f"Model {model_name} not loaded."), gr.update(visible=False)
        return

    try:
        # GPU Prep, VAE Lock, and Watermark removal (The Anti-Deep-Fry Armor)
        pipe.to("cuda")
        pipe.vae = fix_vae_dtype(pipe.vae)
        pipe.watermark = None
        pipe.set_ip_adapter_scale(float(ip_scale))
        generator = torch.Generator("cuda").manual_seed(seed)

        final_prompt = prompt or ""
        if "masterpiece" not in final_prompt.lower():
            final_prompt = f"{final_prompt}, 1girl, solo, masterpiece, best quality, very aesthetic, absurdres"
        base_neg = "lowres, (bad), text, error, missing, extra, worst quality, jpeg artifacts, bad anatomy, watermark, unfinished, displeasing, oldest, early, signature, extra digits, fewer digits, bad hands"
        final_neg = f"{neg}, {base_neg}" if neg else base_neg

        yield None, seed, gr.update(value=f"Generating {w}x{h}..."), gr.update(visible=False)

        def cb(p, i, t, kw):
            # Per-step callback: advance the progress bar and hand the kwargs back unchanged.
            progress((i + 1) / steps, desc=f"Step {i+1}/{steps}")
            return kw

        cond, pool, ncond, npool = encode_prompts(pipe, final_prompt, final_neg)
        image = pipe(
            prompt_embeds=cond,
            pooled_prompt_embeds=pool,
            negative_prompt_embeds=ncond,
            negative_pooled_prompt_embeds=npool,
            ip_adapter_image=ref_img,
            guidance_scale=cfg,
            num_inference_steps=steps,
            width=int(w),
            height=int(h),
            generator=generator,
            callback_on_step_end=cb,
        ).images[0]

        # The download button needs a file to serve. Showing it without a
        # value (as before) gave the user a button that downloads nothing.
        dl_update = gr.update(visible=False)
        try:
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                image.save(tmp, format="PNG")
            dl_update = gr.update(visible=True, value=tmp.name)
        except OSError as save_err:
            # Saving is a convenience; still show the image if the disk write fails.
            print(f"⚠️ [Warn] Could not save download copy: {save_err}")

        yield image, seed, gr.update(value="✅ Done!"), dl_update
    except Exception as e:
        traceback.print_exc()
        torch.cuda.empty_cache()
        yield None, seed, gr.update(value=f"❌ Crash: {str(e)[:150]}"), gr.update(visible=False)
# --- UI Setup ---
def build_css(accent="#7c3aed"):
    """Return the page stylesheet with *accent* substituted as the theme colour.

    The accent value is inserted verbatim wherever the template names it;
    every other rule is fixed text.
    """
    # Doubled braces are literal CSS braces; {accent} is the only placeholder.
    template = """
:root{{--accent:{accent};--accent-dim:color-mix(in srgb,{accent} 20%,transparent);--accent-dark:color-mix(in srgb,{accent} 72%,black);--accent-soft:color-mix(in srgb,{accent} 10%,transparent);}}
#col-container{{margin:0 auto;max-width:980px;}}
footer{{display:none!important;}}
.lbl{{font-weight:700;font-size:11px;letter-spacing:.07em;text-transform:uppercase;color:var(--accent);margin:10px 0 3px 2px;}}
.suite-banner{{display:flex;gap:10px;margin:12px 0 10px 0;flex-wrap:wrap;}}
.suite-banner a{{flex:1;min-width:150px;text-align:center;background:var(--accent-soft);border:1.5px solid color-mix(in srgb,{accent} 35%,transparent);color:var(--accent)!important;font-weight:700;font-size:13px;padding:10px 14px;border-radius:10px;text-decoration:none!important;transition:all .18s ease;white-space:nowrap;}}
.suite-banner a:hover{{background:var(--accent);color:#fff!important;transform:translateY(-1px);box-shadow:0 0 14px var(--accent-dim);}}
.suite-banner a.active{{background:var(--accent);color:#fff!important;box-shadow:0 0 16px var(--accent-dim);pointer-events:none;}}
.res-status{{font-size:12px;font-weight:600;color:var(--accent);padding:4px 10px;border-radius:6px;background:var(--accent-dim);margin-bottom:8px;min-height:26px;}}
button.primary{{background:var(--accent)!important;border-color:var(--accent)!important;transition:all .18s ease;}}
button.primary:hover{{background:var(--accent-dark)!important;box-shadow:0 0 14px var(--accent-dim)!important;transform:translateY(-1px);}}
.model-btn{{font-size:14px!important;padding:12px 8px!important;border-radius:12px!important;font-weight:600!important;transition:all .18s ease!important;border:2px solid transparent!important;}}
.model-btn:hover{{border-color:var(--accent)!important;transform:translateY(-2px)!important;box-shadow:0 4px 16px var(--accent-dim)!important;}}
.stop-btn{{background:#c0392b!important;border-color:#c0392b!important;color:#fff!important;}}
#ch-out{{border:2px solid var(--accent)!important;border-radius:12px!important;box-shadow:0 0 20px var(--accent-dim);}}
.gen-status{{font-size:13px;font-weight:600;padding:6px 10px;border-radius:8px;background:var(--accent-dim);color:var(--accent);margin-bottom:6px;min-height:32px;}}
"""
    return template.format(accent=accent)
def make_model_selector(idx):
    """Build a zero-argument click handler that selects ``MODEL_NAMES[idx]``.

    The model name is resolved once, when the handler is created. The handler
    returns ``(name, cfg_update, steps_update)`` with the CFG and step values
    taken from that model's ``CIVITAI_MODELS`` entry.
    """
    selected = MODEL_NAMES[idx]

    def _fn():
        preset = CIVITAI_MODELS[selected]
        cfg_update = gr.update(value=preset["cfg"])
        steps_update = gr.update(value=preset["steps"])
        return (selected, cfg_update, steps_update)

    return _fn
with gr.Blocks(
    title="Character Lab",
    css=build_css(DEFAULT_ACCENT),
    theme=gr.themes.Soft(primary_hue="purple"),
) as demo:
    with gr.Column(elem_id="col-container"):
        # Banner linking the sibling Spaces; the current app is marked active.
        gr.HTML(f'''
<div class="suite-banner">
<a href="{ANIME_STUDIO_URL}">Anime Studio</a>
<a class="active" href="#">Character Lab</a>
<a href="{QWEN_SPACE_URL}" target="_blank">Qwen Editor</a>
<a href="{UPSCALER_URL}" target="_blank">Upscaler</a>
</div>
''')
        gr.Markdown("# Character Lab\n**High-detail character consistency from a single image.**")

        # Model picker: one button per model, selection kept in a hidden textbox.
        with gr.Row():
            ch_btn_m0 = gr.Button(MODEL_NAMES[0], variant="primary", scale=1, elem_classes="model-btn")
            ch_btn_m1 = gr.Button(MODEL_NAMES[1], variant="secondary", scale=1, elem_classes="model-btn")
        ch_model = gr.Textbox(value=MODEL_NAMES[0], visible=False)

        with gr.Row():
            # Left column: reference image and target resolution.
            with gr.Column(scale=1):
                ch_input = gr.Image(label="Full Character Reference", type="pil", height=380)
                ch_res_status = gr.Markdown("Resolution auto-detect ready", elem_classes="res-status")
                with gr.Row():
                    ch_w = gr.Slider(minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
                    ch_h = gr.Slider(minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
            # Right column: generated image and its download button.
            with gr.Column(scale=1):
                ch_result = gr.Image(show_label=False, elem_id="ch-out", height=500)
                ch_dl_btn = gr.DownloadButton(label="Download", size="sm", visible=False, elem_classes="dl-btn")

        # Prompt line with run / cancel controls.
        with gr.Row():
            ch_prompt = gr.Text(
                show_label=False,
                max_lines=3,
                container=False,
                placeholder="Describe the character, outfit, and scene...",
            )
            ch_run = gr.Button("Generate", variant="primary", scale=0)
            ch_stop = gr.Button("Cancel", size="sm", scale=0, elem_classes="stop-btn")

        with gr.Row():
            ch_quick_ratio = gr.Radio(
                choices=list(SDXL_RATIOS),
                label="Quick Pick Resolution Presets",
                container=True,
            )
        ch_gen_status = gr.Markdown("Models pre-loading in background...", elem_classes="gen-status")

        with gr.Accordion("Advanced Settings", open=False):
            ch_neg = gr.Text(
                label="Negative Prompt",
                max_lines=2,
                value="lowres, bad anatomy, bad hands, text, watermark, worst quality",
            )
            # The Sweet Spot for IP-Adapter Plus on Anime models
            ch_ip_scale = gr.Slider(
                label="Consistency Strength (IP-Scale)",
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.2,
            )
            with gr.Row():
                ch_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                ch_rand = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                ch_cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, step=0.5, value=7.0)
                ch_steps = gr.Slider(label="Steps", minimum=10, maximum=60, step=1, value=30)

    # --- Event wiring ---
    # Each model button writes its name plus that model's CFG/steps defaults.
    for model_idx, model_btn in enumerate((ch_btn_m0, ch_btn_m1)):
        model_btn.click(
            fn=make_model_selector(model_idx),
            inputs=[],
            outputs=[ch_model, ch_cfg, ch_steps],
        )

    # A new reference image proposes a resolution and clears the preset radio.
    ch_input.change(
        fn=auto_res,
        inputs=[ch_input],
        outputs=[ch_w, ch_h, ch_res_status, ch_quick_ratio],
    )
    ch_quick_ratio.change(
        fn=set_quick_ratio,
        inputs=[ch_quick_ratio],
        outputs=[ch_w, ch_h, ch_res_status],
    )

    # Generate on button click or on Enter in the prompt box.
    ch_event = gr.on(
        triggers=[ch_run.click, ch_prompt.submit],
        fn=generate_gpu,
        inputs=[
            ch_model, ch_input, ch_prompt, ch_neg, ch_seed, ch_rand,
            ch_w, ch_h, ch_ip_scale, ch_cfg, ch_steps,
        ],
        outputs=[ch_result, ch_seed, ch_gen_status, ch_dl_btn],
    )
    # Cancel aborts a running generation event.
    ch_stop.click(fn=None, cancels=[ch_event], queue=False)
# Start checkpoint downloads and pipeline builds in a daemon thread so the
# module finishes loading (and the UI can come up) without waiting for them.
_preload_worker = threading.Thread(target=download_and_load_all, daemon=True)
_preload_worker.start()

if __name__ == "__main__":
    # queue() returns the Blocks instance, so configuring and launching in
    # two statements is the same as chaining them.
    demo.queue(max_size=20, default_concurrency_limit=1)
    demo.launch(ssr_mode=False)