DB2169 committed
Commit 9441bd6 · verified · 1 parent: 4f399ae

Update app.py

Files changed (1): app.py (+57 -45)
app.py CHANGED
@@ -3,8 +3,8 @@ from typing import List, Dict, Any, Optional
 from PIL import Image
 import torch
 import gradio as gr
-import spaces  # ZeroGPU: decorate GPU-bound functions
-from huggingface_hub import snapshot_download
+import spaces
+from huggingface_hub import snapshot_download, HfHubHTTPError
 from diffusers import (
     StableDiffusionXLPipeline,
     StableDiffusionPipeline,
@@ -16,15 +16,12 @@ from diffusers import (
     PNDMScheduler,
 )
 
-# -------- Configuration (set as Space Secrets if needed) --------
-MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora")  # your model repo id
-CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors")  # exact .safetensors name
-HF_TOKEN = os.getenv("HF_TOKEN", None)  # only required for private repos
-DO_WARMUP = os.getenv("WARMUP", "1") == "1"  # set to "0" to disable warmup
+MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "").strip()
+CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "").strip()
+HF_TOKEN = os.getenv("HF_TOKEN", None)
+DO_WARMUP = os.getenv("WARMUP", "1") == "1"
 
-# -------- Runtime defaults --------
-REPO_DIR = "/home/user/model"  # local cache mount for snapshot_download
-# Defer CUDA detection to GPU-run function for ZeroGPU; do not move to CUDA at import time
+REPO_DIR = "/home/user/model"
 
 SCHEDULERS = {
     "default": None,
@@ -36,24 +33,40 @@ SCHEDULERS = {
     "dpmpp_2m": DPMSolverMultistepScheduler,
 }
 
-# Globals populated on startup
 pipe = None
 IS_SDXL = True
 LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
+INIT_ERROR: Optional[str] = None  # expose bootstrap error to UI
 
-
-# -------- Model bootstrap (CPU) --------
 def bootstrap_model():
-    global pipe, IS_SDXL, LORA_MANIFEST
-    local_dir = snapshot_download(
-        repo_id=MODEL_REPO_ID,
-        token=HF_TOKEN,
-        local_dir=REPO_DIR,
-        ignore_patterns=["*.md"],
-    )
+    global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
+    INIT_ERROR = None
+    if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
+        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME environment variables."
+        print(f"[ERROR] {INIT_ERROR}")
+        return
+
+    try:
+        local_dir = snapshot_download(
+            repo_id=MODEL_REPO_ID,
+            token=HF_TOKEN,
+            local_dir=REPO_DIR,
+            ignore_patterns=["*.md"],
+        )
+    except HfHubHTTPError as e:
+        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
+        print(f"[ERROR] {INIT_ERROR}")
+        return
+    except Exception as e:
+        INIT_ERROR = f"Unexpected error while downloading repo: {e}"
+        print(f"[ERROR] {INIT_ERROR}")
+        return
+
     ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
     if not os.path.exists(ckpt_path):
-        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
+        INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
+        print(f"[ERROR] {INIT_ERROR}")
+        return
 
     try:
         _pipe = StableDiffusionXLPipeline.from_single_file(
@@ -61,12 +74,16 @@ def bootstrap_model():
         )
         sdxl = True
     except Exception:
-        _pipe = StableDiffusionPipeline.from_single_file(
-            ckpt_path, torch_dtype=torch.float16, use_safetensors=True
-        )
-        sdxl = False
+        try:
+            _pipe = StableDiffusionPipeline.from_single_file(
+                ckpt_path, torch_dtype=torch.float16, use_safetensors=True
+            )
+            sdxl = False
+        except Exception as e:
+            INIT_ERROR = f"Failed to load pipeline from {CHECKPOINT_FILENAME}: {e}"
+            print(f"[ERROR] {INIT_ERROR}")
+            return
 
-    # Keep on CPU until GPU-decorated call (ZeroGPU attaches GPU on demand)
     if hasattr(_pipe, "enable_attention_slicing"):
         _pipe.enable_attention_slicing("max")
     if hasattr(_pipe, "enable_vae_slicing"):
@@ -83,11 +100,12 @@ def bootstrap_model():
     except Exception as e:
         print(f"[WARN] Failed to parse loras.json: {e}")
 
+    # publish
+    global pipe, IS_SDXL, LORA_MANIFEST
     pipe = _pipe
     IS_SDXL = sdxl
     LORA_MANIFEST = manifest
 
-
 def apply_loras(selected: List[str], scale: float, repo_dir: str):
     if not selected or scale <= 0:
         return
@@ -107,8 +125,7 @@ def apply_loras(selected: List[str], scale: float, repo_dir: str):
     except Exception as e:
         print(f"[WARN] set_adapters failed: {e}")
 
-
-@spaces.GPU  # ZeroGPU: allocate/attach GPU for this function call
+@spaces.GPU
 def txt2img(
     prompt: str,
     negative: str,
@@ -123,19 +140,19 @@ def txt2img(
     lora_scale: float,
     fuse_lora: bool,
 ):
-    # Resolve device inside GPU context
+    if pipe is None:
+        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")
+
     local_device = "cuda" if torch.cuda.is_available() else "cpu"
     local_dtype = torch.float16 if local_device == "cuda" else torch.float32
     pipe.to(local_device)
 
-    # Scheduler swap
     if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
         try:
             pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
         except Exception as e:
             print(f"[WARN] Scheduler switch failed: {e}")
 
-    # LoRAs
     apply_loras(loras, lora_scale, REPO_DIR)
     if fuse_lora and loras:
         try:
@@ -143,7 +160,6 @@ def txt2img(
         except Exception as e:
             print(f"[WARN] fuse_lora failed: {e}")
 
-    # Determinism
     generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None
 
     kwargs: Dict[str, Any] = dict(
@@ -160,17 +176,14 @@ def txt2img(
     out = pipe(**kwargs)
     return out.images
 
-
 def warmup():
     try:
         _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
     except Exception as e:
         print(f"[WARN] Warmup failed: {e}")
 
-
-# --------------------------- Build UI ---------------------------
-with gr.Blocks(title="SDXL Space (ZeroGPU, single-file checkpoint, LoRA-ready)") as demo:
-    gr.Markdown("### SDXL text‑to‑image with single‑file checkpoint and optional LoRAs")
+with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
+    status = gr.Markdown("")  # show init status/errors
 
     with gr.Row():
         prompt = gr.Textbox(label="Prompt", lines=3)
@@ -193,21 +206,21 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file checkpoint, LoRA-ready)")
         lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
         fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
 
-    btn = gr.Button("Generate", variant="primary")
+    btn = gr.Button("Generate", variant="primary", interactive=False)  # locked until model loads
     gallery = gr.Gallery(columns=4, height=420)
 
-    # Load model + manifest, then populate LoRA choices
    def _startup():
         bootstrap_model()
-        return gr.CheckboxGroup.update(choices=list(LORA_MANIFEST.keys()))
+        if INIT_ERROR:
+            return gr.Markdown.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.CheckboxGroup.update(choices=[]), gr.Button.update(interactive=False)
+        msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
+        return gr.Markdown.update(value=msg), gr.CheckboxGroup.update(choices=list(LORA_MANIFEST.keys())), gr.Button.update(interactive=True)
 
-    demo.load(_startup, outputs=[lora_names])
+    demo.load(_startup, outputs=[status, lora_names, btn])
 
-    # Optional warmup (costs a tiny GPU run on first boot); set WARMUP=0 to skip
     if DO_WARMUP:
         demo.load(lambda: warmup(), inputs=None, outputs=None)
 
-    # Event binding inside Blocks; one GPU job at a time for SDXL
     btn.click(
         txt2img,
         inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
@@ -217,5 +230,4 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file checkpoint, LoRA-ready)")
         concurrency_id="gpu_queue",
     )
 
-    # Global queue limits for Gradio 4.x
     demo.queue(max_size=32, default_concurrency_limit=1).launch()
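
Note on configuration: this commit removes the hard-coded defaults, so MODEL_REPO_ID and CHECKPOINT_FILENAME must now be set as Space secrets/variables or the app refuses to initialize. A minimal local sanity check before deploying — a sketch assuming huggingface_hub is installed; the fallback values below are placeholders, not values from this repo:

import os
from huggingface_hub import HfApi

repo_id = os.getenv("MODEL_REPO_ID", "user/model")                # placeholder
filename = os.getenv("CHECKPOINT_FILENAME", "model.safetensors")  # placeholder

api = HfApi(token=os.getenv("HF_TOKEN"))
files = api.list_repo_files(repo_id)  # raises for a wrong repo id or missing access
print("checkpoint present:", filename in files)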
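
One compatibility caveat (an assumption about the pinned huggingface_hub version, not something this diff exercises): older releases expose HfHubHTTPError only under huggingface_hub.utils, so the new top-level import may fail there. A defensive form:

try:
    from huggingface_hub import HfHubHTTPError        # re-exported at top level in recent releases
except ImportError:
    from huggingface_hub.utils import HfHubHTTPError  # location in older releases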
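
A second caveat: demo.queue(default_concurrency_limit=...) is Gradio 4.x API, but the per-component gr.Markdown.update(...) / gr.CheckboxGroup.update(...) / gr.Button.update(...) classmethods used in _startup were removed in Gradio 4 in favor of the generic gr.update(...). If the Space pins Gradio 4, the startup callback would need that form — a sketch with the same three outputs in the same order:

def _startup():
    bootstrap_model()
    if INIT_ERROR:
        # positional updates map onto outputs=[status, lora_names, btn]
        return (
            gr.update(value=f"❌ Init failed: {INIT_ERROR}"),
            gr.update(choices=[]),
            gr.update(interactive=False),
        )
    msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
    return (
        gr.update(value=msg),
        gr.update(choices=list(LORA_MANIFEST.keys())),
        gr.update(interactive=True),
    )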