DB2169 committed on
Commit
054efa5
·
verified ·
1 Parent(s): 3f5ac2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -39
app.py CHANGED
@@ -3,23 +3,8 @@ from typing import List, Dict, Any, Optional
3
  from PIL import Image
4
  import torch
5
  import gradio as gr
6
- import spaces # ZeroGPU decorator
7
  from huggingface_hub import snapshot_download
8
-
9
- # ----------------- Config (set in Space Secrets if private) -----------------
10
- # Your private repo that contains the base .safetensors and loras.json
11
- MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip()
12
- # Exact filename of the base checkpoint inside the repo (case-sensitive)
13
- CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip()
14
- # Personal access token with read scope (required for private repos)
15
- HF_TOKEN = os.getenv("HF_TOKEN", None)
16
- # Toggle first-boot warmup (GPU-allocating on ZeroGPU)
17
- DO_WARMUP = os.getenv("WARMUP", "1") == "1"
18
-
19
- # Where snapshot_download will cache the repo
20
- REPO_DIR = "/home/user/model"
21
-
22
- # Supported schedulers
23
  from diffusers import (
24
  StableDiffusionXLPipeline,
25
  StableDiffusionPipeline,
@@ -30,6 +15,19 @@ from diffusers import (
30
  LMSDiscreteScheduler,
31
  PNDMScheduler,
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  SCHEDULERS = {
34
  "default": None,
35
  "euler_a": EulerAncestralDiscreteScheduler,
@@ -46,11 +44,60 @@ IS_SDXL = True
46
  LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
47
  INIT_ERROR: Optional[str] = None
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # ----------------- Bootstrap (download + load on CPU) -----------------
50
  def bootstrap_model():
51
  """
52
- Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint.
53
- Keeps pipeline on CPU; ZeroGPU attaches GPU inside the @spaces.GPU function.
54
  """
55
  global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
56
  INIT_ERROR = None
@@ -79,7 +126,7 @@ def bootstrap_model():
79
  return
80
 
81
  try:
82
- # Try SDXL first
83
  _pipe = StableDiffusionXLPipeline.from_single_file(
84
  ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
85
  )
@@ -91,11 +138,10 @@ def bootstrap_model():
91
  )
92
  sdxl = False
93
  except Exception as e:
94
- INIT_ERROR = f"Failed to load pipeline from {CHECKPOINT_FILENAME}: {e}"
95
  print(f"[ERROR] {INIT_ERROR}")
96
  return
97
 
98
- # Light memory/perf tweaks
99
  if hasattr(_pipe, "enable_attention_slicing"):
100
  _pipe.enable_attention_slicing("max")
101
  if hasattr(_pipe, "enable_vae_slicing"):
@@ -103,18 +149,10 @@ def bootstrap_model():
103
  if hasattr(_pipe, "set_progress_bar_config"):
104
  _pipe.set_progress_bar_config(disable=True)
105
 
106
- # Load LoRA manifest if present
107
- man_path = os.path.join(local_dir, "loras.json")
108
- manifest = {}
109
- if os.path.exists(man_path):
110
- try:
111
- with open(man_path, "r", encoding="utf-8") as f:
112
- manifest = json.load(f)
113
- except Exception as e:
114
- print(f"[WARN] Failed to parse loras.json: {e}")
115
 
116
- # Publish globals
117
- global pipe, IS_SDXL, LORA_MANIFEST
118
  pipe = _pipe
119
  IS_SDXL = sdxl
120
  LORA_MANIFEST = manifest
@@ -125,20 +163,23 @@ def apply_loras(selected: List[str], scale: float, repo_dir: str):
125
  for name in selected:
126
  meta = LORA_MANIFEST.get(name)
127
  if not meta:
 
128
  continue
129
  try:
130
  if "path" in meta:
131
  pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
132
  else:
133
  pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
 
134
  except Exception as e:
135
  print(f"[WARN] LoRA load failed for {name}: {e}")
136
  try:
137
  pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
 
138
  except Exception as e:
139
  print(f"[WARN] set_adapters failed: {e}")
140
 
141
- # ----------------- Generation (GPU-attached under ZeroGPU) -----------------
142
  @spaces.GPU
143
  def txt2img(
144
  prompt: str,
@@ -160,14 +201,12 @@ def txt2img(
160
  local_device = "cuda" if torch.cuda.is_available() else "cpu"
161
  pipe.to(local_device)
162
 
163
- # Optional scheduler switch
164
  if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
165
  try:
166
  pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
167
  except Exception as e:
168
  print(f"[WARN] Scheduler switch failed: {e}")
169
 
170
- # Apply LoRAs
171
  apply_loras(loras, lora_scale, REPO_DIR)
172
  if fuse_lora and loras:
173
  try:
@@ -199,7 +238,7 @@ def warmup():
199
 
200
  # ----------------- UI -----------------
201
  with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
202
- status = gr.Markdown("") # shows init result or errors
203
 
204
  with gr.Row():
205
  prompt = gr.Textbox(label="Prompt", lines=3)
@@ -228,9 +267,10 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
228
  def _startup():
229
  bootstrap_model()
230
  if INIT_ERROR:
231
- return gr.Markdown.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.CheckboxGroup.update(choices=[]), gr.Button.update(interactive=False)
232
  msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
233
- return gr.Markdown.update(value=msg), gr.CheckboxGroup.update(choices=list(LORA_MANIFEST.keys())), gr.Button.update(interactive=True)
 
234
 
235
  demo.load(_startup, outputs=[status, lora_names, btn])
236
 
@@ -246,5 +286,4 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
246
  concurrency_id="gpu_queue",
247
  )
248
 
249
- # Gradio 4.x queue config (no deprecated args)
250
  demo.queue(max_size=32, default_concurrency_limit=1).launch()
 
3
  from PIL import Image
4
  import torch
5
  import gradio as gr
6
+ import spaces
7
  from huggingface_hub import snapshot_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from diffusers import (
9
  StableDiffusionXLPipeline,
10
  StableDiffusionPipeline,
 
15
  LMSDiscreteScheduler,
16
  PNDMScheduler,
17
  )
18
+
19
+ # ----------------- Config (set in Space Secrets if private) -----------------
20
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip()
21
+ CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip()
22
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
23
+ DO_WARMUP = os.getenv("WARMUP", "1") == "1" # set WARMUP=0 to skip the first warmup call
24
+
25
+ # Optional override: JSON string for LoRA manifest (same shape as loras.json)
26
+ LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
27
+
28
+ # Where snapshot_download caches the repo in the container
29
+ REPO_DIR = "/home/user/model"
30
+
31
  SCHEDULERS = {
32
  "default": None,
33
  "euler_a": EulerAncestralDiscreteScheduler,
 
44
  LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
45
  INIT_ERROR: Optional[str] = None
46
 
47
+ # ----------------- Helpers -----------------
48
def _validated_manifest(parsed, label: str) -> Optional[Dict[str, Dict[str, str]]]:
    """Return *parsed* if it is a JSON object (dict); otherwise warn and return None."""
    if isinstance(parsed, dict):
        return parsed
    # Valid JSON of the wrong shape (list, string, number, ...) cannot be used
    # as a name -> metadata mapping; say so instead of skipping it silently.
    print(f"[WARN] Ignoring {label}: expected a JSON object, got {type(parsed).__name__}")
    return None


def _read_manifest_file(path: str, label: str) -> Optional[Dict[str, Dict[str, str]]]:
    """Return the manifest stored at *path*, or None if it is absent or unusable."""
    if not os.path.exists(path):
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            parsed = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[WARN] Failed to parse {label}: {e}")
        return None
    return _validated_manifest(parsed, label)


def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
    """
    Return the LoRA manifest, a mapping of LoRA name -> metadata dict.

    Sources are tried in order; the first one that yields a JSON object wins
    (an empty object counts and is returned as-is):
      1) Environment-provided LORAS_JSON string (if non-empty)
      2) loras.json inside ``repo_dir`` (the downloaded model repo)
      3) loras.json in the current working directory (``os.getcwd()``)
      4) Built-in fallback containing the MoriiMee_Gothic LoRA

    A source that is missing, unreadable, not valid JSON, or not a JSON
    object is skipped; every case except "missing" prints a [WARN] line.
    """
    # 1) From env JSON
    if LORAS_JSON:
        try:
            parsed = json.loads(LORAS_JSON)
        except ValueError as e:
            print(f"[WARN] Failed to parse LORAS_JSON: {e}")
        else:
            manifest = _validated_manifest(parsed, "LORAS_JSON")
            if manifest is not None:
                return manifest

    # 2) From the downloaded repo, then 3) from the working directory
    candidates = (
        (os.path.join(repo_dir, "loras.json"), "repo loras.json"),
        (os.path.join(os.getcwd(), "loras.json"), "local loras.json"),
    )
    for path, label in candidates:
        manifest = _read_manifest_file(path, label)
        if manifest is not None:
            return manifest

    # 4) Built-in fallback: the MoriiMee Gothic LoRA
    print("[INFO] Using built-in LoRA fallback manifest.")
    return {
        "MoriiMee_Gothic": {
            "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
            "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors",
        }
    }
95
+
96
  # ----------------- Bootstrap (download + load on CPU) -----------------
97
  def bootstrap_model():
98
  """
99
+ Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint,
100
+ keeping weights on CPU; ZeroGPU attaches GPU only inside @spaces.GPU calls.
101
  """
102
  global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
103
  INIT_ERROR = None
 
126
  return
127
 
128
  try:
129
+ # Attempt SDXL first (text_encoder_2 present)
130
  _pipe = StableDiffusionXLPipeline.from_single_file(
131
  ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
132
  )
 
138
  )
139
  sdxl = False
140
  except Exception as e:
141
+ INIT_ERROR = f"Failed to load pipeline: {e}"
142
  print(f"[ERROR] {INIT_ERROR}")
143
  return
144
 
 
145
  if hasattr(_pipe, "enable_attention_slicing"):
146
  _pipe.enable_attention_slicing("max")
147
  if hasattr(_pipe, "enable_vae_slicing"):
 
149
  if hasattr(_pipe, "set_progress_bar_config"):
150
  _pipe.set_progress_bar_config(disable=True)
151
 
152
+ manifest = load_lora_manifest(local_dir)
153
+ print(f"[INFO] LoRAs available: {list(manifest.keys())}")
 
 
 
 
 
 
 
154
 
155
+ # Publish
 
156
  pipe = _pipe
157
  IS_SDXL = sdxl
158
  LORA_MANIFEST = manifest
 
163
  for name in selected:
164
  meta = LORA_MANIFEST.get(name)
165
  if not meta:
166
+ print(f"[WARN] Requested LoRA '{name}' not in manifest.")
167
  continue
168
  try:
169
  if "path" in meta:
170
  pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
171
  else:
172
  pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
173
+ print(f"[INFO] Loaded LoRA: {name}")
174
  except Exception as e:
175
  print(f"[WARN] LoRA load failed for {name}: {e}")
176
  try:
177
  pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
178
+ print(f"[INFO] Activated LoRAs: {selected} at scale {scale}")
179
  except Exception as e:
180
  print(f"[WARN] set_adapters failed: {e}")
181
 
182
+ # ----------------- Generation (ZeroGPU) -----------------
183
  @spaces.GPU
184
  def txt2img(
185
  prompt: str,
 
201
  local_device = "cuda" if torch.cuda.is_available() else "cpu"
202
  pipe.to(local_device)
203
 
 
204
  if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
205
  try:
206
  pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
207
  except Exception as e:
208
  print(f"[WARN] Scheduler switch failed: {e}")
209
 
 
210
  apply_loras(loras, lora_scale, REPO_DIR)
211
  if fuse_lora and loras:
212
  try:
 
238
 
239
  # ----------------- UI -----------------
240
  with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
241
+ status = gr.Markdown("")
242
 
243
  with gr.Row():
244
  prompt = gr.Textbox(label="Prompt", lines=3)
 
267
  def _startup():
268
  bootstrap_model()
269
  if INIT_ERROR:
270
+ return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False)
271
  msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
272
+ # Populate LoRA choices (manifest could come from repo, Space file, or built-in fallback)
273
+ return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True)
274
 
275
  demo.load(_startup, outputs=[status, lora_names, btn])
276
 
 
286
  concurrency_id="gpu_queue",
287
  )
288
 
 
289
  demo.queue(max_size=32, default_concurrency_limit=1).launch()