DB2169 commited on
Commit
43454fc
·
verified ·
1 Parent(s): aac0073

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, json, base64
2
+ from typing import List, Dict, Any, Optional
3
+ from PIL import Image
4
+ import torch
5
+ import gradio as gr
6
+ from huggingface_hub import snapshot_download
7
+ from diffusers import (
8
+ StableDiffusionXLPipeline,
9
+ StableDiffusionPipeline,
10
+ DPMSolverMultistepScheduler,
11
+ EulerAncestralDiscreteScheduler,
12
+ EulerDiscreteScheduler,
13
+ DDIMScheduler,
14
+ LMSDiscreteScheduler,
15
+ PNDMScheduler,
16
+ )
17
+
18
+ # ---- Config via Space Secrets / Environment ----
19
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora")
20
+ CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors")
21
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ dtype = torch.float16 if device == "cuda" else torch.float32
25
+
26
+ SCHEDULERS = {
27
+ "default": None,
28
+ "euler_a": EulerAncestralDiscreteScheduler,
29
+ "euler": EulerDiscreteScheduler,
30
+ "ddim": DDIMScheduler,
31
+ "lms": LMSDiscreteScheduler,
32
+ "pndm": PNDMScheduler,
33
+ "dpmpp_2m": DPMSolverMultistepScheduler,
34
+ }
35
+
36
+ # Globals populated at startup
37
+ pipe = None
38
+ IS_SDXL = True
39
+ LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
40
+
41
+ def bootstrap_model():
42
+ global pipe, IS_SDXL, LORA_MANIFEST
43
+
44
+ # Download your model repo (includes base .safetensors and loras.json)
45
+ repo_dir = snapshot_download(repo_id=MODEL_REPO_ID, token=HF_TOKEN, local_dir="/home/user/model", ignore_patterns=["*.md"])
46
+ ckpt_path = os.path.join(repo_dir, CHECKPOINT_FILENAME)
47
+ if not os.path.exists(ckpt_path):
48
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
49
+
50
+ # Load SDXL first; fallback to SD 1.x/2.x single-file
51
+ try:
52
+ pipe_local = StableDiffusionXLPipeline.from_single_file(
53
+ ckpt_path, torch_dtype=dtype, use_safetensors=True, add_watermarker=False
54
+ )
55
+ IS_SDXL = True
56
+ except Exception:
57
+ pipe_local = StableDiffusionPipeline.from_single_file(
58
+ ckpt_path, torch_dtype=dtype, use_safetensors=True
59
+ )
60
+ IS_SDXL = False
61
+
62
+ if hasattr(pipe_local, "enable_attention_slicing"):
63
+ pipe_local.enable_attention_slicing("max")
64
+ if hasattr(pipe_local, "enable_vae_slicing"):
65
+ pipe_local.enable_vae_slicing()
66
+ if hasattr(pipe_local, "set_progress_bar_config"):
67
+ pipe_local.set_progress_bar_config(disable=True)
68
+ pipe_local.to(device)
69
+
70
+ # Load LoRA manifest (optional)
71
+ man_path = os.path.join(repo_dir, "loras.json")
72
+ if os.path.exists(man_path):
73
+ try:
74
+ with open(man_path, "r", encoding="utf-8") as f:
75
+ LORA_MANIFEST = json.load(f)
76
+ except Exception:
77
+ LORA_MANIFEST = {}
78
+ else:
79
+ LORA_MANIFEST = {}
80
+
81
+ return pipe_local
82
+
83
+ def apply_loras(selected: List[str], scale: float):
84
+ if not selected or scale <= 0:
85
+ return
86
+ for name in selected:
87
+ meta = LORA_MANIFEST.get(name)
88
+ if not meta:
89
+ continue
90
+ try:
91
+ if "path" in meta:
92
+ # If you later store LoRA files inside the model repo
93
+ pipe.load_lora_weights(os.path.join("/home/user/model", meta["path"]), adapter_name=name)
94
+ else:
95
+ pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
96
+ except Exception as e:
97
+ print(f"[WARN] LoRA load failed for {name}: {e}")
98
+ try:
99
+ pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
100
+ except Exception as e:
101
+ print(f"[WARN] set_adapters failed: {e}")
102
+
103
+ def txt2img(
104
+ prompt: str,
105
+ negative: str,
106
+ width: int,
107
+ height: int,
108
+ steps: int,
109
+ guidance: float,
110
+ images: int,
111
+ seed: Optional[int],
112
+ scheduler: str,
113
+ loras: List[str],
114
+ lora_scale: float,
115
+ fuse_lora: bool,
116
+ ):
117
+ if scheduler and scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
118
+ try:
119
+ pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
120
+ except Exception as e:
121
+ print(f"[WARN] Scheduler switch failed: {e}")
122
+
123
+ apply_loras(loras, lora_scale)
124
+ if fuse_lora and loras:
125
+ try:
126
+ pipe.fuse_lora(lora_scale=float(lora_scale))
127
+ except Exception as e:
128
+ print(f"[WARN] fuse_lora failed: {e}")
129
+
130
+ generator = torch.Generator(device=device).manual_seed(int(seed)) if seed not in (None, "") else None
131
+ kwargs: Dict[str, Any] = dict(
132
+ prompt=prompt or "",
133
+ negative_prompt=negative or None,
134
+ width=int(width),
135
+ height=int(height),
136
+ num_inference_steps=int(steps),
137
+ guidance_scale=float(guidance),
138
+ num_images_per_prompt=int(images),
139
+ generator=generator,
140
+ )
141
+ out = pipe(**kwargs)
142
+ return out.images
143
+
144
+ def warmup():
145
+ # Tiny, fast pass to initialize CUDA kernels and weight graphs
146
+ try:
147
+ _ = txt2img(
148
+ "warmup", "", 512 if IS_SDXL else 512, 512 if IS_SDXL else 512,
149
+ 5, 5.0, 1, 1234, "default", [], 0.0, False
150
+ )
151
+ except Exception as e:
152
+ print(f"[WARN] warmup failed: {e}")
153
+
154
+ # Build UI
155
+ with gr.Blocks(title="SDXL Space (runs Diffusers directly)") as demo:
156
+ gr.Markdown("### SDXL text‑to‑image (single‑file checkpoint) with optional LoRAs")
157
+ with gr.Row():
158
+ prompt = gr.Textbox(label="Prompt", lines=3)
159
+ negative = gr.Textbox(label="Negative Prompt", lines=3)
160
+ with gr.Row():
161
+ width = gr.Slider(256, 1536, 1024, step=64, label="Width")
162
+ height = gr.Slider(256, 1536, 1024, step=64, label="Height")
163
+ with gr.Row():
164
+ steps = gr.Slider(5, 80, 30, step=1, label="Steps")
165
+ guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
166
+ images = gr.Slider(1, 4, 1, step=1, label="Images")
167
+ with gr.Row():
168
+ seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
169
+ scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
170
+
171
+ # LoRA controls: multi-select from loras.json
172
+ lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
173
+ lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
174
+ fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
175
+
176
+ btn = gr.Button("Generate", variant="primary")
177
+ gallery = gr.Gallery(columns=4, height=420)
178
+
179
+ def _startup():
180
+ global pipe
181
+ pipe = bootstrap_model()
182
+ # Fill LoRA choices after manifest loads
183
+ return gr.CheckboxGroup.update(choices=list(LORA_MANIFEST.keys()))
184
+ demo.load(_startup, outputs=[lora_names])
185
+ demo.load(lambda: warmup(), inputs=None, outputs=None)
186
+
187
+ btn.click(
188
+ txt2img,
189
+ inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
190
+ outputs=[gallery],
191
+ api_name="txt2img",
192
+ )
193
+
194
+ # Tune queue for Spaces (concurrency + backlog)
195
+ demo.queue(max_size=32, concurrency_count=2).launch()