| import os, uuid, gradio as gr, torch |
| from pathlib import Path |
| from diffusers import StableDiffusionXLPipeline |
| from huggingface_hub import snapshot_download |
|
|
| |
# Choose where the pipeline runs. fp16 is used only when CUDA is available;
# the CPU fallback keeps full float32 precision.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32
|
|
| |
# Base SDXL checkpoint on the Hugging Face Hub.
BASE_REPO = "SG161222/RealVisXL_V5.0"

print(f"⏳ Loading SDXL base on {DEVICE} …")
# The fp16 weight variant is requested regardless of device; torch_dtype=DTYPE
# decides the in-memory precision (float32 when DEVICE is "cpu").
# add_watermarker=False turns off the pipeline's invisible watermark step.
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE_REPO,
    use_safetensors=True,
    variant="fp16",
    torch_dtype=DTYPE,
    add_watermarker=False,
)
pipe = pipe.to(DEVICE)
|
|
|
|
|
|
| |
# Hub repo holding the LoRA checkpoints and the filename pattern of the epoch
# snapshots this app offers. The token comes from the HF_TOKEN environment
# variable (None when unset).
LORA_REPO = "pdbdb/lora-act"
LORA_PATTERN = "ACT_textencoded_locon_more_epochs-*.safetensors"
snapshot_path = Path(
    snapshot_download(
        repo_id=LORA_REPO,
        allow_patterns=LORA_PATTERN,
        token=os.getenv("HF_TOKEN"),
    )
)

# Dropdown label -> local file path; "none" means run the base model only.
# Glob with the same pattern used for the download: a cached snapshot folder
# can also contain files fetched earlier with other patterns. Sort so the
# dropdown order is stable instead of following filesystem listing order.
LORA_MAP = {
    "none": None,
    **{p.name: str(p) for p in sorted(snapshot_path.rglob(LORA_PATTERN))},
}
LORA_LABELS = list(LORA_MAP)
|
|
# Negative prompt sent with every generation, assembled from individual terms
# so the list is easy to scan and edit. ("out of frame" appears twice.)
NEG_PROMPT = ", ".join(
    (
        "out of frame",
        "lowres",
        "text",
        "error",
        "cropped",
        "worst quality",
        "low quality",
        "jpeg artifacts",
        "ugly",
        "duplicate",
        "morbid",
        "mutilated",
        "out of frame",
        "extra fingers",
        "mutated hands",
        "poorly drawn hands",
        "poorly drawn face",
        "mutation",
        "deformed",
        "blurry",
        "dehydrated",
        "bad anatomy",
        "bad proportions",
        "extra limbs",
        "cloned face",
        "disfigured",
        "gross proportions",
        "malformed limbs",
        "missing arms",
        "missing legs",
        "extra arms",
        "extra legs",
        "fused fingers",
        "too many fingers",
        "long neck",
        "username",
        "watermark",
        "signature",
    )
)
|
|
def _build_prompt(prompt):
    """Wrap the user's text with the fixed prefix and style tags.

    "brndsht, " is always prepended (presumably the LoRA's trigger token --
    TODO confirm). "photo portret, " is added as well unless the user text
    already contains "portret" (case-insensitive). The same photographic style
    suffix is appended in both cases.
    """
    suffix = (
        "outdoor sunlit setting, bright natural light, vivid colors, "
        "shallow depth-of-field, candid lifestyle vibe, high-resolution DSLR photo"
    )
    if "portret" in prompt.lower():
        prefix = "brndsht, "
    else:
        prefix = "brndsht, photo portret, "
    return f"{prefix}{prompt}, {suffix}"


@torch.inference_mode()
def generate(prompt, lora_label, lora_weight, guidance, steps, seed):
    """Render one image with the shared SDXL pipeline (Gradio click handler).

    Args:
        prompt: User text, expanded by ``_build_prompt``.
        lora_label: Key of ``LORA_MAP``. "none" or an empty selection (None)
            runs the base model without any LoRA.
        lora_weight: Scale applied to the loaded LoRA adapter.
        guidance: Passed to the pipeline as ``guidance_scale``.
        steps: Passed to the pipeline as ``num_inference_steps``.
        seed: Seed for the ``torch.Generator`` on ``DEVICE``.

    Returns:
        The first entry of the pipeline's ``.images`` output.
    """
    torch.cuda.empty_cache()
    # Remove whatever adapter a previous request loaded so LoRAs never stack.
    pipe.unload_lora_weights()
    full_prompt = _build_prompt(prompt)

    if lora_label and lora_label != "none":
        adapter_name = f"live_{uuid.uuid4().hex[:6]}"
        pipe.load_lora_weights(LORA_MAP[lora_label], adapter_name=adapter_name)
        # load_lora_weights takes no per-adapter strength argument (a `weight=`
        # kwarg is silently ignored); the scale is applied via set_adapters.
        pipe.set_adapters([adapter_name], adapter_weights=[float(lora_weight)])

    # Create the generator on the pipeline's device; a hard-coded "cuda"
    # generator crashes on the CPU fallback selected via DEVICE.
    generator = torch.Generator(DEVICE).manual_seed(int(seed))
    result = pipe(
        full_prompt,
        negative_prompt=NEG_PROMPT,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=generator,
    )
    return result.images[0]
|
|
# Gradio front end: a prompt row, a row of generation controls, and an image
# output, wired to generate().
with gr.Blocks() as demo:
    gr.Markdown("## SDXL LoRa ACT Playground")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", scale=4)
        run_btn = gr.Button("Generate", variant="primary")
    with gr.Row():
        lora = gr.Dropdown(label="LoRA checkpoint", choices=LORA_LABELS, value="none")
        # Slider arguments are passed by name: Gradio 3+ declares `step`
        # keyword-only, so a fourth positional argument raises TypeError.
        weight = gr.Slider(minimum=0, maximum=1.2, value=1.0, step=0.05, label="LoRA weight")
        guide = gr.Slider(minimum=2, maximum=15, value=6.5, step=0.5, label="Guidance scale")
        steps = gr.Slider(minimum=20, maximum=150, value=40, step=5, label="Steps")
        seed = gr.Number(value=42, label="Seed", precision=0)
    out = gr.Image(height=512)
    # Input order must match generate()'s parameter order.
    run_btn.click(
        generate,
        inputs=[prompt, lora, weight, guide, steps, seed],
        outputs=out,
    )

demo.launch()
|
|