# -*- coding: utf-8 -*-
"""Character Lab: a Gradio app that generates SDXL images from a reference image.

The module pre-loads every checkpoint listed in ``CIVITAI_MODELS`` (with an
IP-Adapter attached) on a background thread, then serves a single-page UI
whose "Generate" action runs ``generate_gpu``.
"""
import os
import random
import tempfile
import threading
import traceback
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPVisionModelWithProjection

# --- Configuration ---
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

HF_TOKEN = os.environ.get("HF_TOKEN", "")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEFAULT_ACCENT = "#7c3aed"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1920

# Checkpoints offered in the UI, with the CFG/steps applied when one is picked.
CIVITAI_MODELS = {
    "waiNSFW Illustrious": {
        "filename": "waiIllustriousSDXL_v160.safetensors",
        "cfg": 6.0,
        "steps": 25,
        "version": "v16.0",
    },
    "wai Realmix": {
        "filename": "waiREALMIX_v11.safetensors",
        "cfg": 7.0,
        "steps": 30,
        "version": "v11",
    },
}
MODEL_NAMES = list(CIVITAI_MODELS.keys())

# Preset label -> (width, height).
SDXL_RATIOS = {
    "1:1 Square": (1024, 1024),
    "3:4 Portrait": (896, 1152),
    "4:3 Landscape": (1152, 896),
    "9:16 Mobile": (768, 1344),
    "16:9 Desktop": (1344, 768),
}

ANIME_STUDIO_URL = "https://huggingface.co/spaces/harumaa/Anime-XL-Studio-And-Qwen-Edit"
QWEN_SPACE_URL = os.environ.get("QWEN_SPACE_URL", "https://huggingface.co/spaces/harumaa/Qwen-Image-Editor-SFW-NSFW")
UPSCALER_URL = os.environ.get("UPSCALER_URL", "https://huggingface.co/spaces/harumaa/AI-Image-Upscaler-4x")

# model name -> ready pipeline; filled by download_and_load_all().
_char_cache = {}


# --- Helper Functions ---
def fix_vae_dtype(vae, dtype=torch.float16):
    """Cast *vae* to *dtype*, disable its ``force_upcast`` flag, and return it."""
    vae = vae.to(dtype=dtype)
    vae.config.force_upcast = False
    return vae


def _snap_dimension(pixels):
    """Clamp *pixels* to MAX_IMAGE_SIZE, round down to a multiple of 64, floor at 512."""
    return max(512, (min(pixels, MAX_IMAGE_SIZE) // 64) * 64)


def auto_res(input_image):
    """Derive width/height slider values from the uploaded reference image.

    Returns four ``gr.update`` objects, in order: width, height, status text,
    and a reset (``None``) for the quick-ratio radio. With no image, or when
    the image cannot be read, width and height fall back to 1024.
    """
    if input_image is None:
        return (
            gr.update(value=1024),
            gr.update(value=1024),
            gr.update(value="Resolution auto-detect ready"),
            gr.update(value=None),
        )
    try:
        if isinstance(input_image, Image.Image):
            width, height = input_image.width, input_image.height
        else:
            # Only the dimensions are needed, so release the file right away.
            with Image.open(input_image) as opened:
                width, height = opened.width, opened.height
        w = _snap_dimension(width)
        h = _snap_dimension(height)
        return (
            gr.update(value=w),
            gr.update(value=h),
            gr.update(value=f"Auto-detected from image: {w} x {h}"),
            gr.update(value=None),
        )
    except Exception as exc:
        # Best-effort: an unreadable image must not break the UI, but log why.
        print(f"[auto_res] Could not read image: {exc}")
        return (
            gr.update(value=1024),
            gr.update(value=1024),
            gr.update(value="Error reading image"),
            gr.update(value=None),
        )


def set_quick_ratio(ratio_label):
    """Return updates (width, height, status) for the chosen preset label.

    Unknown labels (including ``None``) fall back to 1024 x 1024.
    """
    w, h = SDXL_RATIOS.get(ratio_label, (1024, 1024))
    return (
        gr.update(value=w),
        gr.update(value=h),
        gr.update(value=f"Resolution set by preset: {w} x {h}"),
    )


def encode_prompts(pipe, prompt, neg):
    """Encode *prompt* and *neg* with Compel using both text encoders of *pipe*.

    Returns ``(cond, pooled, neg_cond, neg_pooled)``. The pooled tensors come
    from the second encoder. The conditioning tensors are chosen so that their
    last dimension equals ``pipe.unet.config.cross_attention_dim``:

    1. the second encoder's output alone, if it already matches;
    2. otherwise both encoders' outputs concatenated on the last axis,
       truncated or zero-padded on that axis when the width still differs.
    """
    c1 = Compel(
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=False,
    )
    c2 = Compel(
        tokenizer=pipe.tokenizer_2,
        text_encoder=pipe.text_encoder_2,
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=True,
    )

    cd1 = c1(prompt)
    cd2, pool = c2(prompt)
    nc1 = c1(neg)
    nc2, npool = c2(neg)

    # Positive and negative embeddings must share a sequence length.
    cd1, nc1 = c1.pad_conditioning_tensors_to_same_length([cd1, nc1])
    cd2, nc2 = c2.pad_conditioning_tensors_to_same_length([cd2, nc2])

    expected_dim = pipe.unet.config.cross_attention_dim
    if cd2.shape[-1] == expected_dim:
        return cd2, pool, nc2, npool

    cond_cat = torch.cat([cd1, cd2], dim=-1)
    ncond_cat = torch.cat([nc1, nc2], dim=-1)
    if cond_cat.shape[-1] == expected_dim:
        return cond_cat, pool, ncond_cat, npool
    if cond_cat.shape[-1] > expected_dim:
        return cond_cat[..., :expected_dim], pool, ncond_cat[..., :expected_dim], npool

    pad = expected_dim - cond_cat.shape[-1]
    return (
        torch.nn.functional.pad(cond_cat, (0, pad)),
        pool,
        torch.nn.functional.pad(ncond_cat, (0, pad)),
        npool,
    )


# --- Preloading Logic ---
def download_and_load_all():
    """Download every configured checkpoint and build its pipeline into ``_char_cache``.

    One image encoder is loaded and shared by all pipelines. A failure for one
    model is logged and does not stop the remaining models from loading.
    """
    print("🚀 [Startup] Initializing Character Studio pre-load...")

    print("⚙️ [Pipeline] Loading ViT-H Image Encoder...")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch.float16,
    )

    for model_name, info in CIVITAI_MODELS.items():
        try:
            dest = MODELS_DIR / info["filename"]
            if not dest.exists():
                hf_hub_download(
                    repo_id="harumaa/private-models",
                    filename=info["filename"],
                    local_dir=str(MODELS_DIR),
                    token=HF_TOKEN or None,
                )

            if model_name not in _char_cache:
                print(f"⚙️ [Pipeline] Building {model_name} with IP-Adapter Plus...")
                pipe = StableDiffusionXLPipeline.from_single_file(
                    str(dest),
                    torch_dtype=torch.float16,
                    use_safetensors=True,
                    image_encoder=image_encoder,
                )
                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

                # Loading the High-Detail PLUS model
                pipe.load_ip_adapter(
                    "h94/IP-Adapter",
                    subfolder="sdxl_models",
                    weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
                )
                _char_cache[model_name] = pipe
        except Exception as e:
            print(f"❌ [Error] Failed to load {model_name}: {e}")

    print("✅ [Startup] All models pre-downloaded and pipelines ready.")


# --- Inference ---
def _save_for_download(image, seed):
    """Write *image* to a PNG in the system temp directory and return its path."""
    with tempfile.NamedTemporaryFile(
        prefix=f"character_{seed}_", suffix=".png", delete=False
    ) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name


@spaces.GPU(duration=60)
def generate_gpu(model_name, ref_img, prompt, neg, seed, randomize, w, h, ip_scale, cfg, steps, progress=gr.Progress(track_tqdm=False)):
    """Generate one image and stream UI updates while doing so.

    Generator yielding ``(image, seed, status_update, download_button_update)``
    tuples. ``image`` is ``None`` until the final successful yield. On success
    the download button is made visible and given the saved PNG as its value.
    """
    if ref_img is None:
        yield None, seed, gr.update(value="⚠️ Error: Please upload a Character Reference image!"), gr.update(visible=False)
        return

    # Slider values may arrive as floats; the seed and step count must be ints.
    seed = random.randint(0, MAX_SEED) if randomize else int(seed)
    steps = int(steps)

    yield None, seed, gr.update(value="Waiting for GPU..."), gr.update(visible=False, value=None)

    pipe = _char_cache.get(model_name)
    if pipe is None:
        yield None, seed, gr.update(value=f"Model {model_name} not loaded."), gr.update(visible=False)
        return

    try:
        # GPU Prep, VAE Lock, and Watermark removal (The Anti-Deep-Fry Armor)
        pipe.to("cuda")
        pipe.vae = fix_vae_dtype(pipe.vae)
        pipe.watermark = None

        pipe.set_ip_adapter_scale(float(ip_scale))
        generator = torch.Generator("cuda").manual_seed(seed)

        final_prompt = prompt or ""
        if "masterpiece" not in final_prompt.lower():
            final_prompt = f"{final_prompt}, 1girl, solo, masterpiece, best quality, very aesthetic, absurdres"

        base_neg = "lowres, (bad), text, error, missing, extra, worst quality, jpeg artifacts, bad anatomy, watermark, unfinished, displeasing, oldest, early, signature, extra digits, fewer digits, bad hands"
        final_neg = f"{neg}, {base_neg}" if neg else base_neg

        yield None, seed, gr.update(value=f"Generating {w}x{h}..."), gr.update(visible=False)

        def cb(p, i, t, kw):
            # Report progress after each denoising step; kwargs pass through untouched.
            progress((i + 1) / steps, desc=f"Step {i+1}/{steps}")
            return kw

        cond, pool, ncond, npool = encode_prompts(pipe, final_prompt, final_neg)

        image = pipe(
            prompt_embeds=cond,
            pooled_prompt_embeds=pool,
            negative_prompt_embeds=ncond,
            negative_pooled_prompt_embeds=npool,
            ip_adapter_image=ref_img,
            guidance_scale=float(cfg),
            num_inference_steps=steps,
            width=int(w),
            height=int(h),
            generator=generator,
            callback_on_step_end=cb,
        ).images[0]

        # The download button needs a file to serve; a failed write must not
        # discard the generated image, so it only keeps the button hidden.
        try:
            download_path = _save_for_download(image, seed)
        except OSError as save_err:
            print(f"[generate_gpu] Could not save download file: {save_err}")
            download_path = None

        yield (
            image,
            seed,
            gr.update(value="✅ Done!"),
            gr.update(visible=download_path is not None, value=download_path),
        )
    except Exception as e:
        traceback.print_exc()
        torch.cuda.empty_cache()
        yield None, seed, gr.update(value=f"❌ Crash: {str(e)[:150]}"), gr.update(visible=False)


# --- UI Setup ---
def build_css(accent=DEFAULT_ACCENT):
    """Return the app stylesheet with every accent colour derived from *accent*."""
    return f"""
:root{{--accent:{accent};--accent-dim:color-mix(in srgb,{accent} 20%,transparent);--accent-dark:color-mix(in srgb,{accent} 72%,black);--accent-soft:color-mix(in srgb,{accent} 10%,transparent);}}
#col-container{{margin:0 auto;max-width:980px;}}
footer{{display:none!important;}}
.lbl{{font-weight:700;font-size:11px;letter-spacing:.07em;text-transform:uppercase;color:var(--accent);margin:10px 0 3px 2px;}}
.suite-banner{{display:flex;gap:10px;margin:12px 0 10px 0;flex-wrap:wrap;}}
.suite-banner a{{flex:1;min-width:150px;text-align:center;background:var(--accent-soft);border:1.5px solid color-mix(in srgb,{accent} 35%,transparent);color:var(--accent)!important;font-weight:700;font-size:13px;padding:10px 14px;border-radius:10px;text-decoration:none!important;transition:all .18s ease;white-space:nowrap;}}
.suite-banner a:hover{{background:var(--accent);color:#fff!important;transform:translateY(-1px);box-shadow:0 0 14px var(--accent-dim);}}
.suite-banner a.active{{background:var(--accent);color:#fff!important;box-shadow:0 0 16px var(--accent-dim);pointer-events:none;}}
.res-status{{font-size:12px;font-weight:600;color:var(--accent);padding:4px 10px;border-radius:6px;background:var(--accent-dim);margin-bottom:8px;min-height:26px;}}
button.primary{{background:var(--accent)!important;border-color:var(--accent)!important;transition:all .18s ease;}}
button.primary:hover{{background:var(--accent-dark)!important;box-shadow:0 0 14px var(--accent-dim)!important;transform:translateY(-1px);}}
.model-btn{{font-size:14px!important;padding:12px 8px!important;border-radius:12px!important;font-weight:600!important;transition:all .18s ease!important;border:2px solid transparent!important;}}
.model-btn:hover{{border-color:var(--accent)!important;transform:translateY(-2px)!important;box-shadow:0 4px 16px var(--accent-dim)!important;}}
.stop-btn{{background:#c0392b!important;border-color:#c0392b!important;color:#fff!important;}}
#ch-out{{border:2px solid var(--accent)!important;border-radius:12px!important;box-shadow:0 0 20px var(--accent-dim);}}
.gen-status{{font-size:13px;font-weight:600;padding:6px 10px;border-radius:8px;background:var(--accent-dim);color:var(--accent);margin-bottom:6px;min-height:32px;}}
"""


def make_model_selector(idx):
    """Return a click handler that selects ``MODEL_NAMES[idx]``.

    The handler returns the model name plus updates setting the CFG and steps
    sliders to that model's configured values.
    """
    name = MODEL_NAMES[idx]

    def _fn():
        return (
            name,
            gr.update(value=CIVITAI_MODELS[name]["cfg"]),
            gr.update(value=CIVITAI_MODELS[name]["steps"]),
        )

    return _fn


with gr.Blocks(title="Character Lab", css=build_css(DEFAULT_ACCENT), theme=gr.themes.Soft(primary_hue="purple")) as demo:
    with gr.Column(elem_id="col-container"):
        # NOTE(review): this HTML body is empty here, yet `.suite-banner` CSS
        # and the *_URL constants exist; confirm whether banner markup belongs here.
        gr.HTML(f'''
''')
        gr.Markdown("# Character Lab\n**High-detail character consistency from a single image.**")

        with gr.Row():
            ch_btn_m0 = gr.Button(f"{MODEL_NAMES[0]}", variant="primary", scale=1, elem_classes="model-btn")
            ch_btn_m1 = gr.Button(f"{MODEL_NAMES[1]}", variant="secondary", scale=1, elem_classes="model-btn")

        # Hidden holder for the currently selected model name.
        ch_model = gr.Textbox(value=MODEL_NAMES[0], visible=False)

        with gr.Row():
            with gr.Column(scale=1):
                ch_input = gr.Image(label="Full Character Reference", type="pil", height=380)
                ch_res_status = gr.Markdown("Resolution auto-detect ready", elem_classes="res-status")
                with gr.Row():
                    ch_w = gr.Slider(minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
                    ch_h = gr.Slider(minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024, label="Height")

            with gr.Column(scale=1):
                ch_result = gr.Image(show_label=False, elem_id="ch-out", height=500)
                ch_dl_btn = gr.DownloadButton(label="Download", size="sm", visible=False, elem_classes="dl-btn")

        with gr.Row():
            ch_prompt = gr.Text(show_label=False, max_lines=3, container=False, placeholder="Describe the character, outfit, and scene...")
            ch_run = gr.Button("Generate", variant="primary", scale=0)
            ch_stop = gr.Button("Cancel", size="sm", scale=0, elem_classes="stop-btn")

        with gr.Row():
            ch_quick_ratio = gr.Radio(choices=list(SDXL_RATIOS.keys()), label="Quick Pick Resolution Presets", container=True)

        ch_gen_status = gr.Markdown("Models pre-loading in background...", elem_classes="gen-status")

        with gr.Accordion("Advanced Settings", open=False):
            ch_neg = gr.Text(label="Negative Prompt", max_lines=2, value="lowres, bad anatomy, bad hands, text, watermark, worst quality")

            # The Sweet Spot for IP-Adapter Plus on Anime models
            ch_ip_scale = gr.Slider(label="Consistency Strength (IP-Scale)", minimum=0.1, maximum=1.0, step=0.05, value=0.2)

            with gr.Row():
                ch_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                ch_rand = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                ch_cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, step=0.5, value=7.0)
                ch_steps = gr.Slider(label="Steps", minimum=10, maximum=60, step=1, value=30)

    # --- Event wiring ---
    for i, btn in enumerate([ch_btn_m0, ch_btn_m1]):
        btn.click(fn=make_model_selector(i), inputs=[], outputs=[ch_model, ch_cfg, ch_steps])

    ch_input.change(fn=auto_res, inputs=[ch_input], outputs=[ch_w, ch_h, ch_res_status, ch_quick_ratio])
    ch_quick_ratio.change(fn=set_quick_ratio, inputs=[ch_quick_ratio], outputs=[ch_w, ch_h, ch_res_status])

    ch_event = gr.on(
        triggers=[ch_run.click, ch_prompt.submit],
        fn=generate_gpu,
        inputs=[ch_model, ch_input, ch_prompt, ch_neg, ch_seed, ch_rand, ch_w, ch_h, ch_ip_scale, ch_cfg, ch_steps],
        outputs=[ch_result, ch_seed, ch_gen_status, ch_dl_btn],
    )
    ch_stop.click(fn=None, cancels=[ch_event], queue=False)

# Start pre-loading at import time so the UI can come up while models load.
threading.Thread(target=download_and_load_all, daemon=True).start()

if __name__ == "__main__":
    demo.queue(max_size=20, default_concurrency_limit=1).launch(ssr_mode=False)