File size: 3,401 Bytes
1b2af48
5947edc
 
1b2af48
5947edc
1b2af48
 
 
6dd88e5
1b2af48
 
 
6dd88e5
1b2af48
 
 
 
 
 
5947edc
6dd88e5
1b2af48
 
 
5947edc
1b2af48
 
53feead
1b2af48
 
 
 
5947edc
 
 
 
 
 
 
 
53feead
 
 
 
 
5947edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53feead
5947edc
 
 
 
 
 
 
f3f208a
5947edc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os, uuid, gradio as gr, torch
from pathlib import Path
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import snapshot_download     # ← auth-aware helper

# ---------------- runtime & dtype ----------------
# Half precision is used only on CUDA; CPU inference stays in float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# ---------------- base model ---------------------
BASE_REPO = "SG161222/RealVisXL_V5.0"  # full SDXL checkpoint
print(f"⏳ Loading SDXL base on {DEVICE} …")
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE_REPO,
    torch_dtype=DTYPE,
    variant="fp16",          # request the fp16 weight files of the repo
    use_safetensors=True,
    add_watermarker=False,   # disable the pipeline's output watermarker
)
# Move the shared pipeline (used by generate()) onto the chosen device.
pipe = pipe.to(DEVICE)



# ---------------- LoRA zoo -----------------------
LORA_REPO = "pdbdb/lora-act"
# Fetch a local snapshot containing only the ACT LoRA checkpoint files.
snapshot_path = Path(
    snapshot_download(
        repo_id=LORA_REPO,
        allow_patterns="ACT_textencoded_locon_more_epochs-*.safetensors",  # only the ACT LoRA checkpoints
        token=os.getenv("HF_TOKEN"),  # None = anonymous access (public repo)
    )
)
# Dropdown label (file name) -> local file path; "none" runs the base model only
# and is kept first. rglob() yields files in arbitrary filesystem order, so the
# checkpoints are sorted to give a stable, reproducible dropdown ordering.
LORA_MAP = {
    "none": None,
    **{p.name: str(p) for p in sorted(snapshot_path.rglob("*.safetensors"))},
}
LORA_LABELS = list(LORA_MAP)

# Negative prompt passed to every pipeline call made by generate().
# NOTE(review): "out of frame" appears twice in this list — likely an accidental
# duplicate; confirm before deduplicating, since it alters the conditioning.
NEG_PROMPT = "out of frame, lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"

# Photographic style tags appended to every user prompt.
_PROMPT_SUFFIX = "outdoor sunlit setting, bright natural light, vivid colors, shallow depth-of-field, candid lifestyle vibe, high-resolution DSLR photo"


def _build_prompt(user_prompt):
    """Wrap *user_prompt* with the leading trigger tags and the fixed style suffix."""
    # "brndsht" / "portret" are kept verbatim — they look like LoRA trigger
    # tokens rather than typos. NOTE(review): confirm against training captions.
    lead = "brndsht" if "portret" in user_prompt.lower() else "brndsht, photo portret"
    return f"{lead}, {user_prompt}, {_PROMPT_SUFFIX}"


@torch.inference_mode()
def generate(prompt, lora_label, lora_weight, guidance, steps, seed):
    """Render one SDXL image for the Gradio UI.

    Args:
        prompt: user text; wrapped with trigger/style tags by ``_build_prompt``.
        lora_label: key into ``LORA_MAP``; ``"none"`` runs the base model only.
        lora_weight: strength applied to the selected LoRA adapter.
        guidance: classifier-free guidance scale passed to the pipeline.
        steps: number of inference steps.
        seed: integer seed for the sampling generator.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    torch.cuda.empty_cache()
    # The pipeline is module-global and reused across requests: drop whatever
    # LoRA the previous call attached so adapters never stack up.
    pipe.unload_lora_weights()

    if lora_label != "none":
        adapter = f"live_{uuid.uuid4().hex[:6]}"
        pipe.load_lora_weights(LORA_MAP[lora_label], adapter_name=adapter)
        # load_lora_weights() has no ``weight`` parameter, so passing it there
        # never applied the slider value; set the adapter strength explicitly.
        pipe.set_adapters([adapter], adapter_weights=[float(lora_weight)])

    # The generator must match the pipeline's device; a hard-coded "cuda"
    # generator fails on CPU-only hosts.
    g = torch.Generator(DEVICE).manual_seed(int(seed))
    return pipe(
        _build_prompt(prompt),
        negative_prompt=NEG_PROMPT,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=g,
    ).images[0]

with gr.Blocks() as demo:
    gr.Markdown("## SDXL LoRa ACT Playground")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", scale=4)
        run_btn = gr.Button("Generate", variant="primary")
    with gr.Row():
        lora = gr.Dropdown(label="LoRA checkpoint", choices=LORA_LABELS, value="none")
        # gr.Slider(minimum, maximum, value, ...): ``step`` is passed by keyword
        # because it is keyword-only in Gradio's Slider signature; the keyword
        # form is also accepted by versions where it is positional.
        weight = gr.Slider(0, 1.2, value=1.0, step=0.05, label="LoRA weight")
        guide = gr.Slider(2, 15, value=6.5, step=0.5, label="Guidance scale")
        steps = gr.Slider(20, 150, value=40, step=5, label="Steps")
        seed = gr.Number(42, label="Seed", precision=0)
    out = gr.Image(height=512)
    # Input order must match generate(prompt, lora_label, lora_weight, guidance, steps, seed).
    run_btn.click(generate, [prompt, lora, weight, guide, steps, seed], out)

demo.launch()