Spaces:
Running on Zero
Running on Zero
| # -*- coding: utf-8 -*- | |
| import os | |
| import random | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import spaces | |
| import threading | |
| from PIL import Image | |
| from pathlib import Path | |
| from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler | |
| from huggingface_hub import hf_hub_download | |
| from transformers import CLIPVisionModelWithProjection | |
| from compel import Compel, ReturnedEmbeddingsType | |
# --- Configuration ---
# Checkpoints are stored in a local "models" folder, created at import time.
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Hugging Face token taken from the environment; empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEFAULT_ACCENT = "#7c3aed"

# Upper bounds used by the seed and width/height sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1920

# Per-model checkpoint filename plus the cfg/steps presets applied when the
# model is selected in the UI.
CIVITAI_MODELS = {
    "waiNSFW Illustrious": {
        "filename": "waiIllustriousSDXL_v160.safetensors",
        "cfg": 6.0,
        "steps": 25,
        "version": "v16.0",
    },
    "wai Realmix": {
        "filename": "waiREALMIX_v11.safetensors",
        "cfg": 7.0,
        "steps": 30,
        "version": "v11",
    },
}
MODEL_NAMES = list(CIVITAI_MODELS)

# Preset label -> (width, height) in pixels.
SDXL_RATIOS = {
    "1:1 Square": (1024, 1024),
    "3:4 Portrait": (896, 1152),
    "4:3 Landscape": (1152, 896),
    "9:16 Mobile": (768, 1344),
    "16:9 Desktop": (1344, 768),
}

# Links shown in the navigation banner; two of them can be overridden through
# environment variables.
ANIME_STUDIO_URL = "https://huggingface.co/spaces/harumaa/Anime-XL-Studio-And-Qwen-Edit"
QWEN_SPACE_URL = os.getenv("QWEN_SPACE_URL", "https://huggingface.co/spaces/harumaa/Qwen-Image-Editor-SFW-NSFW")
UPSCALER_URL = os.getenv("UPSCALER_URL", "https://huggingface.co/spaces/harumaa/AI-Image-Upscaler-4x")

# Model name -> loaded pipeline, filled in by the preload routine.
_char_cache = dict()
| # --- Helper Functions --- | |
def fix_vae_dtype(vae, dtype=torch.float16):
    """Cast *vae* to *dtype* and turn off its ``force_upcast`` config flag.

    Args:
        vae: Object exposing ``.to(dtype=...)`` and a ``.config`` attribute.
        dtype: Target dtype; defaults to ``torch.float16``.

    Returns:
        The object returned by ``vae.to(dtype=dtype)``, with
        ``config.force_upcast`` set to ``False``.
    """
    converted = vae.to(dtype=dtype)
    setattr(converted.config, "force_upcast", False)
    return converted
def auto_res(input_image):
    """Derive width/height slider values from an uploaded reference image.

    Each side is capped at ``MAX_IMAGE_SIZE``, rounded down to a multiple of
    64 and raised to at least 512.

    Args:
        input_image: A PIL image, something ``Image.open`` accepts, or ``None``.

    Returns:
        Four ``gr.update`` objects: width, height, status text, and a fourth
        update whose value is always ``None``. With no image, or when reading
        it fails, width and height fall back to 1024.
    """
    def _snap(side):
        # Cap, align down to the 64-pixel grid, then enforce the 512 minimum.
        return max(512, (min(side, MAX_IMAGE_SIZE) // 64) * 64)

    def _updates(width, height, message):
        return (
            gr.update(value=width),
            gr.update(value=height),
            gr.update(value=message),
            gr.update(value=None),
        )

    if input_image is None:
        return _updates(1024, 1024, "Resolution auto-detect ready")

    try:
        if isinstance(input_image, Image.Image):
            picture = input_image
        else:
            picture = Image.open(input_image)
        width = _snap(picture.width)
        height = _snap(picture.height)
        return _updates(width, height, f"Auto-detected from image: {width} x {height}")
    except Exception:
        return _updates(1024, 1024, "Error reading image")
def set_quick_ratio(ratio_label):
    """Translate a preset label into width/height slider updates.

    Args:
        ratio_label: A key of ``SDXL_RATIOS``; unknown labels fall back to
            1024 x 1024.

    Returns:
        Three ``gr.update`` objects: width, height and a status message.
    """
    try:
        width, height = SDXL_RATIOS[ratio_label]
    except KeyError:
        width, height = 1024, 1024
    status = f"Resolution set by preset: {width} x {height}"
    return gr.update(value=width), gr.update(value=height), gr.update(value=status)
def encode_prompts(pipe, prompt, neg):
    """Encode the positive and negative prompts with Compel.

    One Compel instance is built per tokenizer/text-encoder pair of *pipe*;
    only the second one is asked for pooled embeddings. The positive and
    negative tensors from each instance are padded to a common length.

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``tokenizer_2``,
            ``text_encoder``, ``text_encoder_2`` and ``unet``.
        prompt: Positive prompt text.
        neg: Negative prompt text.

    Returns:
        ``(cond, pooled, neg_cond, neg_pooled)``. The last dimension of the
        two conditioning tensors equals
        ``pipe.unet.config.cross_attention_dim``: the second encoder's output
        is used alone if it already matches, otherwise both encoders' outputs
        are concatenated and then truncated or zero-padded as needed.
    """
    embed_kind = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
    compel_first = Compel(
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        returned_embeddings_type=embed_kind,
        requires_pooled=False,
    )
    compel_second = Compel(
        tokenizer=pipe.tokenizer_2,
        text_encoder=pipe.text_encoder_2,
        returned_embeddings_type=embed_kind,
        requires_pooled=True,
    )

    pos_first = compel_first(prompt)
    pos_second, pos_pooled = compel_second(prompt)
    neg_first = compel_first(neg)
    neg_second, neg_pooled = compel_second(neg)

    pos_first, neg_first = compel_first.pad_conditioning_tensors_to_same_length([pos_first, neg_first])
    pos_second, neg_second = compel_second.pad_conditioning_tensors_to_same_length([pos_second, neg_second])

    target_dim = pipe.unet.config.cross_attention_dim

    # The second encoder alone may already have the width the UNet expects.
    if pos_second.shape[-1] == target_dim:
        return pos_second, pos_pooled, neg_second, neg_pooled

    pos_joined = torch.cat([pos_first, pos_second], dim=-1)
    neg_joined = torch.cat([neg_first, neg_second], dim=-1)
    joined_dim = pos_joined.shape[-1]

    if joined_dim > target_dim:
        # Too wide: keep only the leading target_dim features.
        return pos_joined[..., :target_dim], pos_pooled, neg_joined[..., :target_dim], neg_pooled

    if joined_dim < target_dim:
        # Too narrow: zero-pad the feature dimension on the right.
        shortfall = target_dim - joined_dim
        padded_pos = torch.nn.functional.pad(pos_joined, (0, shortfall))
        padded_neg = torch.nn.functional.pad(neg_joined, (0, shortfall))
        return (padded_pos, pos_pooled, padded_neg, neg_pooled)

    return pos_joined, pos_pooled, neg_joined, neg_pooled
| # --- Preloading Logic --- | |
def download_and_load_all():
    """Download every configured checkpoint and build its pipeline.

    Loads the image encoder from ``h94/IP-Adapter`` once and shares it across
    pipelines. For each entry of ``CIVITAI_MODELS`` it then:

    1. downloads the checkpoint into ``MODELS_DIR`` if the file is missing,
    2. builds a ``StableDiffusionXLPipeline`` from that single file,
    3. swaps in an Euler-ancestral scheduler and loads the IP-Adapter Plus
       weights,
    4. stores the pipeline in ``_char_cache`` under the model name.

    A failure for one model is printed and does not stop the remaining
    models. The closing log line names any models that failed instead of
    always claiming success.
    """
    print("🚀 [Startup] Initializing Character Studio pre-load...")
    print("⚙️ [Pipeline] Loading ViT-H Image Encoder...")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch.float16
    )

    failed = []
    for model_name, info in CIVITAI_MODELS.items():
        try:
            dest = MODELS_DIR / info["filename"]
            if not dest.exists():
                hf_hub_download(repo_id="harumaa/private-models", filename=info["filename"], local_dir=str(MODELS_DIR), token=HF_TOKEN or None)
            if model_name not in _char_cache:
                print(f"⚙️ [Pipeline] Building {model_name} with IP-Adapter Plus...")
                pipe = StableDiffusionXLPipeline.from_single_file(
                    str(dest), torch_dtype=torch.float16, use_safetensors=True, image_encoder=image_encoder
                )
                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
                # Loading the High-Detail PLUS model
                pipe.load_ip_adapter(
                    "h94/IP-Adapter",
                    subfolder="sdxl_models",
                    weight_name="ip-adapter-plus_sdxl_vit-h.safetensors"
                )
                # Publish only after the pipeline is fully configured, so
                # readers of the cache never see a half-built entry.
                _char_cache[model_name] = pipe
        except Exception as e:
            # Best effort: keep going so one bad checkpoint does not block the others.
            failed.append(model_name)
            print(f"❌ [Error] Failed to load {model_name}: {e}")

    if failed:
        print(f"⚠️ [Startup] Pre-load finished with failures: {', '.join(failed)}")
    else:
        print("✅ [Startup] All models pre-downloaded and pipelines ready.")
| # --- Inference --- | |
def generate_gpu(model_name, ref_img, prompt, neg, seed, randomize, w, h, ip_scale, cfg, steps, progress=gr.Progress(track_tqdm=False)):
    """Run one IP-Adapter guided generation and stream UI updates.

    This is a generator. Every ``yield`` is a 4-tuple
    ``(image, seed, status_update, download_button_update)``.

    Args:
        model_name: Key into ``_char_cache`` selecting a preloaded pipeline.
        ref_img: Reference image passed to the pipeline as
            ``ip_adapter_image``; ``None`` aborts with an error status.
        prompt: Positive prompt. Quality tags are appended unless the prompt
            already contains "masterpiece" (case-insensitive).
        neg: Extra negative prompt text, placed before the built-in negatives.
        seed: Seed used when ``randomize`` is false.
        randomize: When true, a new seed in ``[0, MAX_SEED]`` is drawn.
        w: Output width in pixels.
        h: Output height in pixels.
        ip_scale: IP-Adapter scale.
        cfg: Guidance scale.
        steps: Number of inference steps.
        progress: Gradio progress tracker, updated after every step.

    Yields:
        Intermediate status tuples with ``image`` set to ``None``, then either
        the finished image (with the download button pointing at a saved PNG)
        or an error status. Exceptions raised during generation are caught,
        printed, and reported through the status text.
    """
    # NOTE(review): `spaces` is imported at the top of this file, but no
    # `@spaces.GPU` decorator is applied to this function in the visible
    # source. Confirm whether the hosting runtime requires one.
    import tempfile  # function-scope import: only needed for the download file

    if ref_img is None:
        yield None, seed, gr.update(value="⚠️ Error: Please upload a Character Reference image!"), gr.update(visible=False)
        return

    if randomize:
        seed = random.randint(0, MAX_SEED)
    yield None, seed, gr.update(value="Waiting for GPU..."), gr.update(visible=False, value=None)

    pipe = _char_cache.get(model_name)
    if pipe is None:
        yield None, seed, gr.update(value=f"Model {model_name} not loaded."), gr.update(visible=False)
        return

    try:
        # Seed and step count come from UI sliders; cast so a float value
        # cannot reach manual_seed() or num_inference_steps.
        seed_int = int(seed)
        num_steps = int(steps)

        # GPU Prep, VAE Lock, and Watermark removal (The Anti-Deep-Fry Armor)
        pipe.to("cuda")
        pipe.vae = fix_vae_dtype(pipe.vae)
        pipe.watermark = None
        pipe.set_ip_adapter_scale(float(ip_scale))
        generator = torch.Generator("cuda").manual_seed(seed_int)

        quality_tags = "1girl, solo, masterpiece, best quality, very aesthetic, absurdres"
        final_prompt = prompt or ""
        if "masterpiece" not in final_prompt.lower():
            # Avoid a leading ", " when the user left the prompt empty.
            final_prompt = f"{final_prompt}, {quality_tags}" if final_prompt.strip() else quality_tags

        base_neg = "lowres, (bad), text, error, missing, extra, worst quality, jpeg artifacts, bad anatomy, watermark, unfinished, displeasing, oldest, early, signature, extra digits, fewer digits, bad hands"
        final_neg = f"{neg}, {base_neg}" if neg else base_neg

        yield None, seed, gr.update(value=f"Generating {w}x{h}..."), gr.update(visible=False)

        def cb(p, i, t, kw):
            # Step-end callback: report progress and hand the kwargs back unchanged.
            progress((i + 1) / num_steps, desc=f"Step {i + 1}/{num_steps}")
            return kw

        cond, pool, ncond, npool = encode_prompts(pipe, final_prompt, final_neg)
        image = pipe(
            prompt_embeds=cond,
            pooled_prompt_embeds=pool,
            negative_prompt_embeds=ncond,
            negative_pooled_prompt_embeds=npool,
            ip_adapter_image=ref_img,
            guidance_scale=cfg,
            num_inference_steps=num_steps,
            width=int(w),
            height=int(h),
            generator=generator,
            callback_on_step_end=cb
        ).images[0]

        # The download button was reset to value=None above, so it needs an
        # actual file before being shown. delete=False keeps the PNG on disk
        # after the handle closes; these temp files are not cleaned up here.
        with tempfile.NamedTemporaryFile(prefix="character_lab_", suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
            download_path = tmp.name

        yield image, seed, gr.update(value="✅ Done!"), gr.update(visible=True, value=download_path)
    except Exception as e:
        import traceback
        traceback.print_exc()
        torch.cuda.empty_cache()
        yield None, seed, gr.update(value=f"❌ Crash: {str(e)[:150]}"), gr.update(visible=False)
| # --- UI Setup --- | |
def build_css(accent="#7c3aed"):
    """Return the app stylesheet with *accent* as the theme colour.

    Args:
        accent: CSS colour value written into the ``:root`` custom properties
            and the banner link border.

    Returns:
        The stylesheet as one string, one rule per line, with a leading and a
        trailing newline.
    """
    # Only the first and the sixth rule interpolate the colour directly; the
    # rest refer to it through the CSS custom properties.
    rules = (
        f":root{{--accent:{accent};--accent-dim:color-mix(in srgb,{accent} 20%,transparent);--accent-dark:color-mix(in srgb,{accent} 72%,black);--accent-soft:color-mix(in srgb,{accent} 10%,transparent);}}",
        "#col-container{margin:0 auto;max-width:980px;}",
        "footer{display:none!important;}",
        ".lbl{font-weight:700;font-size:11px;letter-spacing:.07em;text-transform:uppercase;color:var(--accent);margin:10px 0 3px 2px;}",
        ".suite-banner{display:flex;gap:10px;margin:12px 0 10px 0;flex-wrap:wrap;}",
        f".suite-banner a{{flex:1;min-width:150px;text-align:center;background:var(--accent-soft);border:1.5px solid color-mix(in srgb,{accent} 35%,transparent);color:var(--accent)!important;font-weight:700;font-size:13px;padding:10px 14px;border-radius:10px;text-decoration:none!important;transition:all .18s ease;white-space:nowrap;}}",
        ".suite-banner a:hover{background:var(--accent);color:#fff!important;transform:translateY(-1px);box-shadow:0 0 14px var(--accent-dim);}",
        ".suite-banner a.active{background:var(--accent);color:#fff!important;box-shadow:0 0 16px var(--accent-dim);pointer-events:none;}",
        ".res-status{font-size:12px;font-weight:600;color:var(--accent);padding:4px 10px;border-radius:6px;background:var(--accent-dim);margin-bottom:8px;min-height:26px;}",
        "button.primary{background:var(--accent)!important;border-color:var(--accent)!important;transition:all .18s ease;}",
        "button.primary:hover{background:var(--accent-dark)!important;box-shadow:0 0 14px var(--accent-dim)!important;transform:translateY(-1px);}",
        ".model-btn{font-size:14px!important;padding:12px 8px!important;border-radius:12px!important;font-weight:600!important;transition:all .18s ease!important;border:2px solid transparent!important;}",
        ".model-btn:hover{border-color:var(--accent)!important;transform:translateY(-2px)!important;box-shadow:0 4px 16px var(--accent-dim)!important;}",
        ".stop-btn{background:#c0392b!important;border-color:#c0392b!important;color:#fff!important;}",
        "#ch-out{border:2px solid var(--accent)!important;border-radius:12px!important;box-shadow:0 0 20px var(--accent-dim);}",
        ".gen-status{font-size:13px;font-weight:600;padding:6px 10px;border-radius:8px;background:var(--accent-dim);color:var(--accent);margin-bottom:6px;min-height:32px;}",
    )
    return "\n" + "\n".join(rules) + "\n"
def make_model_selector(idx):
    """Build the click handler for the model button at position *idx*.

    Args:
        idx: Index into ``MODEL_NAMES``.

    Returns:
        A zero-argument function returning ``(model_name, cfg_update,
        steps_update)``, where the two updates carry that model's ``cfg`` and
        ``steps`` presets from ``CIVITAI_MODELS``.
    """
    # Resolve the name now so each handler is bound to its own model.
    name = MODEL_NAMES[idx]

    def _fn():
        preset = CIVITAI_MODELS[name]
        cfg_update = gr.update(value=preset["cfg"])
        steps_update = gr.update(value=preset["steps"])
        return (name, cfg_update, steps_update)

    return _fn
# --- UI definition ---
# Builds the `demo` Blocks app titled "Character Lab", styled with
# build_css(DEFAULT_ACCENT) and the Soft theme with a purple primary hue.
# NOTE(review): the original indentation is not visible in this copy of the
# file, so the exact nesting of rows/columns below cannot be confirmed here.
| with gr.Blocks(title="Character Lab", css=build_css(DEFAULT_ACCENT), theme=gr.themes.Soft(primary_hue="purple")) as demo: | |
| with gr.Column(elem_id="col-container"): | |
| gr.HTML(f''' | |
| <div class="suite-banner"> | |
| <a href="{ANIME_STUDIO_URL}">Anime Studio</a> | |
| <a class="active" href="#">Character Lab</a> | |
| <a href="{QWEN_SPACE_URL}" target="_blank">Qwen Editor</a> | |
| <a href="{UPSCALER_URL}" target="_blank">Upscaler</a> | |
| </div> | |
| ''') | |
| gr.Markdown("# Character Lab\n**High-detail character consistency from a single image.**") | |
# Model selector: one button for each of the first two MODEL_NAMES entries,
# plus a hidden textbox that holds the selected model name (initially the first).
| with gr.Row(): | |
| ch_btn_m0 = gr.Button(f"{MODEL_NAMES[0]}", variant="primary", scale=1, elem_classes="model-btn") | |
| ch_btn_m1 = gr.Button(f"{MODEL_NAMES[1]}", variant="secondary", scale=1, elem_classes="model-btn") | |
| ch_model = gr.Textbox(value=MODEL_NAMES[0], visible=False) | |
# Reference image input with width/height sliders (512..MAX_IMAGE_SIZE in
# steps of 64), and the result image with an initially hidden download button.
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| ch_input = gr.Image(label="Full Character Reference", type="pil", height=380) | |
| ch_res_status = gr.Markdown("Resolution auto-detect ready", elem_classes="res-status") | |
| with gr.Row(): | |
| ch_w = gr.Slider(minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024, label="Width") | |
| ch_h = gr.Slider(minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024, label="Height") | |
| with gr.Column(scale=1): | |
| ch_result = gr.Image(show_label=False, elem_id="ch-out", height=500) | |
| ch_dl_btn = gr.DownloadButton(label="Download", size="sm", visible=False, elem_classes="dl-btn") | |
# Prompt box with the Generate and Cancel buttons.
| with gr.Row(): | |
| ch_prompt = gr.Text(show_label=False, max_lines=3, container=False, placeholder="Describe the character, outfit, and scene...") | |
| ch_run = gr.Button("Generate", variant="primary", scale=0) | |
| ch_stop = gr.Button("Cancel", size="sm", scale=0, elem_classes="stop-btn") | |
# Resolution presets (the keys of SDXL_RATIOS) and the generation status line.
| with gr.Row(): | |
| ch_quick_ratio = gr.Radio(choices=list(SDXL_RATIOS.keys()), label="Quick Pick Resolution Presets", container=True) | |
| ch_gen_status = gr.Markdown("Models pre-loading in background...", elem_classes="gen-status") | |
# Advanced settings: negative prompt, IP-Adapter scale, seed, CFG and steps.
| with gr.Accordion("Advanced Settings", open=False): | |
| ch_neg = gr.Text(label="Negative Prompt", max_lines=2, value="lowres, bad anatomy, bad hands, text, watermark, worst quality") | |
| # The Sweet Spot for IP-Adapter Plus on Anime models | |
| ch_ip_scale = gr.Slider(label="Consistency Strength (IP-Scale)", minimum=0.1, maximum=1.0, step=0.05, value=0.2) | |
| with gr.Row(): | |
| ch_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0) | |
| ch_rand = gr.Checkbox(label="Randomize seed", value=True) | |
| with gr.Row(): | |
| ch_cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, step=0.5, value=7.0) | |
| ch_steps = gr.Slider(label="Steps", minimum=10, maximum=60, step=1, value=30) | |
# --- Event wiring ---
# Each model button writes its model name and that model's cfg/steps presets
# into ch_model, ch_cfg and ch_steps.
| for i, btn in enumerate([ch_btn_m0, ch_btn_m1]): | |
| btn.click(fn=make_model_selector(i), inputs=[], outputs=[ch_model, ch_cfg, ch_steps]) | |
# A change of the reference image recomputes width/height via auto_res, which
# also sends value=None to the preset radio; picking a preset sets width/height.
| ch_input.change(fn=auto_res, inputs=[ch_input], outputs=[ch_w, ch_h, ch_res_status, ch_quick_ratio]) | |
| ch_quick_ratio.change(fn=set_quick_ratio, inputs=[ch_quick_ratio], outputs=[ch_w, ch_h, ch_res_status]) | |
# Clicking Generate or submitting the prompt runs generate_gpu; its outputs
# feed the result image, seed slider, status line and download button.
| ch_event = gr.on( | |
| triggers=[ch_run.click, ch_prompt.submit], | |
| fn=generate_gpu, | |
| inputs=[ch_model, ch_input, ch_prompt, ch_neg, ch_seed, ch_rand, ch_w, ch_h, ch_ip_scale, ch_cfg, ch_steps], | |
| outputs=[ch_result, ch_seed, ch_gen_status, ch_dl_btn] | |
| ) | |
# The Cancel button runs no function of its own; it only cancels ch_event.
| ch_stop.click(fn=None, cancels=[ch_event], queue=False) | |
# Start downloading and building the pipelines in a daemon thread at import
# time, so the UI can come up while the models are still loading.
_preload_thread = threading.Thread(target=download_and_load_all, daemon=True)
_preload_thread.start()

if __name__ == "__main__":
    # Queue holds at most 20 pending requests and runs one event at a time by default.
    _queued_app = demo.queue(max_size=20, default_concurrency_limit=1)
    _queued_app.launch(ssr_mode=False)