DB2169 committed on
Commit 5e98324 · verified · 1 Parent(s): 6f1a16e

Update app.py

Files changed (1)
  1. app.py +71 -65
app.py CHANGED
@@ -1,9 +1,9 @@
-import os, io, json, base64
+import os, io, json
 from typing import List, Dict, Any, Optional
 from PIL import Image
 import torch
 import gradio as gr
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download  # pulls the model repo at startup
 from diffusers import (
     StableDiffusionXLPipeline,
     StableDiffusionPipeline,
@@ -15,11 +15,12 @@ from diffusers import (
     PNDMScheduler,
 )
 
-# ---- Config via Space Secrets / Environment ----
-MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora")
-CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors")
-HF_TOKEN = os.getenv("HF_TOKEN", None)
+# -------- Configuration (set these in Space Secrets for private repos) --------
+MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora")  # your model repo id
+CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors")  # exact base ckpt filename
+HF_TOKEN = os.getenv("HF_TOKEN", None)  # optional if the repo is public
 
+# -------- Runtime defaults --------
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if device == "cuda" else torch.float32
 
@@ -33,64 +34,71 @@ SCHEDULERS = {
     "dpmpp_2m": DPMSolverMultistepScheduler,
 }
 
-# Globals populated at startup
+# Globals filled on startup
 pipe = None
 IS_SDXL = True
 LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
+REPO_DIR = "/home/user/model"  # cached snapshot location in Spaces
 
+# -------- Model bootstrap --------
 def bootstrap_model():
     global pipe, IS_SDXL, LORA_MANIFEST
-
-    # Download your model repo (includes base .safetensors and loras.json)
-    repo_dir = snapshot_download(repo_id=MODEL_REPO_ID, token=HF_TOKEN, local_dir="/home/user/model", ignore_patterns=["*.md"])
-    ckpt_path = os.path.join(repo_dir, CHECKPOINT_FILENAME)
+    # Download/copy all repo files locally (weights + manifest)
+    local_dir = snapshot_download(
+        repo_id=MODEL_REPO_ID,
+        token=HF_TOKEN,
+        local_dir=REPO_DIR,
+        ignore_patterns=["*.md"],
+    )  # downloads the model repo into the container cache
+
+    ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
     if not os.path.exists(ckpt_path):
         raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
 
-    # Load SDXL first; fallback to SD 1.x/2.x single-file
+    # Try SDXL single-file first, then fall back to SD 1.x/2.x single-file
     try:
-        pipe_local = StableDiffusionXLPipeline.from_single_file(
+        _pipe = StableDiffusionXLPipeline.from_single_file(
             ckpt_path, torch_dtype=dtype, use_safetensors=True, add_watermarker=False
-        )
-        IS_SDXL = True
+        )  # SDXL loader
+        sdxl = True
     except Exception:
-        pipe_local = StableDiffusionPipeline.from_single_file(
+        _pipe = StableDiffusionPipeline.from_single_file(
             ckpt_path, torch_dtype=dtype, use_safetensors=True
-        )
-        IS_SDXL = False
-
-    if hasattr(pipe_local, "enable_attention_slicing"):
-        pipe_local.enable_attention_slicing("max")
-    if hasattr(pipe_local, "enable_vae_slicing"):
-        pipe_local.enable_vae_slicing()
-    if hasattr(pipe_local, "set_progress_bar_config"):
-        pipe_local.set_progress_bar_config(disable=True)
-    pipe_local.to(device)
-
-    # Load LoRA manifest (optional)
-    man_path = os.path.join(repo_dir, "loras.json")
+        )  # SD 1.x/2.x fallback
+        sdxl = False
+
+    if hasattr(_pipe, "enable_attention_slicing"):
+        _pipe.enable_attention_slicing("max")
+    if hasattr(_pipe, "enable_vae_slicing"):
+        _pipe.enable_vae_slicing()
+    if hasattr(_pipe, "set_progress_bar_config"):
+        _pipe.set_progress_bar_config(disable=True)
+    _pipe.to(device)
+
+    # Load LoRA manifest if present
+    man_path = os.path.join(local_dir, "loras.json")
+    manifest = {}
     if os.path.exists(man_path):
         try:
             with open(man_path, "r", encoding="utf-8") as f:
-                LORA_MANIFEST = json.load(f)
-        except Exception:
-            LORA_MANIFEST = {}
-    else:
-        LORA_MANIFEST = {}
+                manifest = json.load(f)
+        except Exception as e:
+            print(f"[WARN] Failed to parse loras.json: {e}")
 
-    return pipe_local
+    # Publish via return; _startup assigns the globals
+    return _pipe, sdxl, manifest
 
 def apply_loras(selected: List[str], scale: float):
     if not selected or scale <= 0:
         return
+    # Each selected LoRA must exist in the manifest; supports repo/weight_name or a local 'path'
     for name in selected:
         meta = LORA_MANIFEST.get(name)
         if not meta:
             continue
         try:
             if "path" in meta:
-                # If you later store LoRA files inside the model repo
-                pipe.load_lora_weights(os.path.join("/home/user/model", meta["path"]), adapter_name=name)
+                pipe.load_lora_weights(os.path.join(REPO_DIR, meta["path"]), adapter_name=name)
             else:
                 pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
         except Exception as e:
@@ -114,12 +122,14 @@ def txt2img(
     lora_scale: float,
     fuse_lora: bool,
 ):
-    if scheduler and scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
+    # Scheduler swap
+    if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
         try:
             pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
         except Exception as e:
             print(f"[WARN] Scheduler switch failed: {e}")
 
+    # Apply LoRAs
     apply_loras(loras, lora_scale)
     if fuse_lora and loras:
         try:
@@ -127,7 +137,9 @@ def txt2img(
         except Exception as e:
             print(f"[WARN] fuse_lora failed: {e}")
 
+    # Determinism
     generator = torch.Generator(device=device).manual_seed(int(seed)) if seed not in (None, "") else None
+
     kwargs: Dict[str, Any] = dict(
         prompt=prompt or "",
         negative_prompt=negative or None,
@@ -142,18 +154,15 @@
     return out.images
 
 def warmup():
-    # Tiny, fast pass to initialize CUDA kernels and weight graphs
+    # Small, fast call to initialize kernels/graphs so the first user request is fast
     try:
-        _ = txt2img(
-            "warmup", "", 512 if IS_SDXL else 512, 512 if IS_SDXL else 512,
-            5, 5.0, 1, 1234, "default", [], 0.0, False
-        )
+        _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
     except Exception as e:
-        print(f"[WARN] warmup failed: {e}")
+        print(f"[WARN] Warmup failed: {e}")
 
-# Build UI
-with gr.Blocks(title="SDXL Space (runs Diffusers directly)") as demo:
-    gr.Markdown("### SDXL text‑to‑image (single‑file checkpoint) with optional LoRAs")
+# --------------------------- Build the UI inside Blocks ---------------------------
+with gr.Blocks(title="SDXL Space (single-file, LoRA-ready)") as demo:  # Blocks context required for events
+    gr.Markdown("### SDXL text‑to‑image (single‑file checkpoint) with optional LoRAs")  # UI heading
     with gr.Row():
         prompt = gr.Textbox(label="Prompt", lines=3)
         negative = gr.Textbox(label="Negative Prompt", lines=3)
@@ -168,36 +177,33 @@ with gr.Blocks(title="SDXL Space (runs Diffusers directly)") as demo:
         seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
         scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
 
-    # LoRA controls: multi-select from loras.json
-    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
+    # LoRA multi-select populated after the manifest loads
+    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json)")
     lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
     fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
 
     btn = gr.Button("Generate", variant="primary")
     gallery = gr.Gallery(columns=4, height=420)
 
+    # Startup loader (runs at app load)
    def _startup():
-        global pipe
-        pipe = bootstrap_model()
-        # Fill LoRA choices after manifest loads
+        global pipe, IS_SDXL, LORA_MANIFEST
+        pipe, IS_SDXL, LORA_MANIFEST = bootstrap_model()
         return gr.CheckboxGroup.update(choices=list(LORA_MANIFEST.keys()))
-    demo.load(_startup, outputs=[lora_names])
-    demo.load(lambda: warmup(), inputs=None, outputs=None)
+    demo.load(_startup, outputs=[lora_names])  # fill the LoRA list once the model is ready
 
+    # Warm-up pass after model load for a snappy first request
+    demo.load(lambda: warmup(), inputs=None, outputs=None)
+
+    # Wire the button click inside Blocks, with per-event concurrency control
     btn.click(
         txt2img,
         inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
         outputs=[gallery],
         api_name="txt2img",
-    )
-
-    # Tune queue for Spaces (concurrency + backlog)
-    btn.click(
-        txt2img,
-        inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
-        outputs=[gallery],
-        api_name="txt2img",
-        concurrency_limit=1,  # keep 1 GPU job at a time
-        concurrency_id="gpu_queue"  # share queue if you add more GPU events later
-    )
-    demo.queue(max_size=32, default_concurrency_limit=1).launch()
+        concurrency_limit=1,         # one GPU job at a time for SDXL
+        concurrency_id="gpu_queue",  # shared queue id if more GPU events are added
+    )  # per-event queue parameters in Gradio 4.x
+
+    # Global queue config (no deprecated args)
+    demo.queue(max_size=32, default_concurrency_limit=1).launch()  # supported queue pattern in Gradio 4.x
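
For reference, bootstrap_model and apply_loras expect loras.json to map an adapter name to either a file path inside the downloaded model repo or a Hub repo plus weight filename. A minimal sketch of such a manifest; the adapter names, paths, and repo id below are hypothetical:

{
  "detail_tweaker": { "path": "loras/detail_tweaker.safetensors" },
  "pixel_style": { "repo": "someuser/pixel-style-lora", "weight_name": "pixel_style.safetensors" }
}

Entries with "path" resolve against REPO_DIR; entries without it fall back to pipe.load_lora_weights(repo, weight_name=...).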
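Because the click handler is registered with api_name="txt2img", the endpoint can also be called programmatically once the Space is running. A minimal sketch using gradio_client; the Space id is a placeholder, and the positional arguments follow the btn.click input order (prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse):

from gradio_client import Client

client = Client("DB2169/your-space-name")  # hypothetical Space id
images = client.predict(
    "a neon-lit alley in the rain, cinematic",  # prompt
    "blurry, low quality",                      # negative prompt
    1024, 1024,                                 # width, height
    28,                                         # steps
    6.0,                                        # guidance scale
    1,                                          # number of images
    1234,                                       # seed (blank = random)
    "dpmpp_2m",                                 # scheduler key from SCHEDULERS
    [],                                         # LoRA names from loras.json
    0.7,                                        # LoRA scale
    False,                                      # fuse LoRA
    api_name="/txt2img",
)

With concurrency_limit=1 on the event, queued calls run one GPU job at a time.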