DB2169 commited on
Commit
2ccc007
·
verified ·
1 Parent(s): e2271b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -53
app.py CHANGED
@@ -6,8 +6,8 @@ import gradio as gr
6
  import spaces
7
  from huggingface_hub import snapshot_download
8
  from diffusers import (
9
- StableDiffusionXLPipeline,
10
- StableDiffusionPipeline,
11
  DPMSolverMultistepScheduler,
12
  EulerAncestralDiscreteScheduler,
13
  EulerDiscreteScheduler,
@@ -16,16 +16,13 @@ from diffusers import (
16
  PNDMScheduler,
17
  )
18
 
19
- # ----------------- Config (set in Space Secrets if private) -----------------
20
- MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/Mixy").strip()
21
  CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "unknown.safetensors").strip()
22
  HF_TOKEN = os.getenv("HF_TOKEN", None)
23
- DO_WARMUP = os.getenv("WARMUP", "1") == "1" # set WARMUP=0 to skip the first warmup call
24
 
25
- # Optional override: JSON string for LoRA manifest (same shape as loras.json)
26
  LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
27
-
28
- # Where snapshot_download caches the repo in the container
29
  REPO_DIR = "/home/user/model"
30
 
31
  SCHEDULERS = {
@@ -38,22 +35,14 @@ SCHEDULERS = {
38
  "dpmpp_2m": DPMSolverMultistepScheduler,
39
  }
40
 
41
- # Globals populated at startup
42
  pipe = None
43
- IS_SDXL = True
44
  LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
45
  INIT_ERROR: Optional[str] = None
46
 
47
- # ----------------- Helpers -----------------
48
  def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
49
- """
50
- Manifest load order:
51
- 1) Environment variable LORAS_JSON (if provided)
52
- 2) loras.json inside the downloaded model repo
53
- 3) loras.json at the Space root (next to app.py)
54
- 4) Built-in fallback with MoriiMee_Gothic you provided
55
- """
56
- # 1) From env JSON
57
  if LORAS_JSON:
58
  try:
59
  parsed = json.loads(LORAS_JSON)
@@ -62,7 +51,6 @@ def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
62
  except Exception as e:
63
  print(f"[WARN] Failed to parse LORAS_JSON: {e}")
64
 
65
- # 2) From repo
66
  repo_manifest = os.path.join(repo_dir, "loras.json")
67
  if os.path.exists(repo_manifest):
68
  try:
@@ -73,7 +61,6 @@ def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
73
  except Exception as e:
74
  print(f"[WARN] Failed to parse repo loras.json: {e}")
75
 
76
- # 3) From Space root
77
  local_manifest = os.path.join(os.getcwd(), "loras.json")
78
  if os.path.exists(local_manifest):
79
  try:
@@ -84,7 +71,6 @@ def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
84
  except Exception as e:
85
  print(f"[WARN] Failed to parse local loras.json: {e}")
86
 
87
- # 4) Built-in fallback: your MoriiMee Gothic LoRA
88
  print("[INFO] Using built-in LoRA fallback manifest.")
89
  return {
90
  "MoriiMee_Gothic": {
@@ -93,11 +79,11 @@ def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
93
  }
94
  }
95
 
96
- # ----------------- Bootstrap (download + load on CPU) -----------------
97
  def bootstrap_model():
98
  """
99
- Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint,
100
- keeping weights on CPU; ZeroGPU attaches GPU only inside @spaces.GPU calls.
101
  """
102
  global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
103
  INIT_ERROR = None
@@ -125,20 +111,24 @@ def bootstrap_model():
125
  print(f"[ERROR] {INIT_ERROR}")
126
  return
127
 
 
 
 
 
128
  try:
129
- # Attempt SDXL first (text_encoder_2 present)
130
- _pipe = StableDiffusionXLPipeline.from_single_file(
131
- ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
132
  )
133
- sdxl = True
134
- except Exception:
 
135
  try:
136
- _pipe = StableDiffusionPipeline.from_single_file(
137
- ckpt_path, torch_dtype=torch.float16, use_safetensors=True
138
  )
139
- sdxl = False
140
- except Exception as e:
141
- INIT_ERROR = f"Failed to load pipeline: {e}"
142
  print(f"[ERROR] {INIT_ERROR}")
143
  return
144
 
@@ -152,9 +142,8 @@ def bootstrap_model():
152
  manifest = load_lora_manifest(local_dir)
153
  print(f"[INFO] LoRAs available: {list(manifest.keys())}")
154
 
155
- # Publish
156
  pipe = _pipe
157
- IS_SDXL = sdxl
158
  LORA_MANIFEST = manifest
159
 
160
  def apply_loras(selected: List[str], scale: float, repo_dir: str):
@@ -179,7 +168,7 @@ def apply_loras(selected: List[str], scale: float, repo_dir: str):
179
  except Exception as e:
180
  print(f"[WARN] set_adapters failed: {e}")
181
 
182
- # ----------------- Generation (ZeroGPU) -----------------
183
  @spaces.GPU
184
  def txt2img(
185
  prompt: str,
@@ -230,14 +219,8 @@ def txt2img(
230
  out = pipe(**kwargs)
231
  return out.images
232
 
233
- def warmup():
234
- try:
235
- _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
236
- except Exception as e:
237
- print(f"[WARN] Warmup failed: {e}")
238
-
239
- # ----------------- UI -----------------
240
- with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
241
  status = gr.Markdown("")
242
 
243
  with gr.Row():
@@ -267,15 +250,30 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
267
  def _startup():
268
  bootstrap_model()
269
  if INIT_ERROR:
270
- return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False)
 
 
 
 
 
 
 
271
  msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
272
- # Populate LoRA choices (manifest could come from repo, Space file, or built-in fallback)
273
- return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True)
274
-
275
- demo.load(_startup, outputs=[status, lora_names, btn])
 
 
 
 
 
 
 
 
 
276
 
277
- if DO_WARMUP:
278
- demo.load(lambda: warmup(), inputs=None, outputs=None)
279
 
280
  btn.click(
281
  txt2img,
 
6
  import spaces
7
  from huggingface_hub import snapshot_download
8
  from diffusers import (
9
+ StableDiffusionPipeline, # SD 1.x/2.x single-file loader
10
+ StableDiffusionXLPipeline, # SDXL single-file loader
11
  DPMSolverMultistepScheduler,
12
  EulerAncestralDiscreteScheduler,
13
  EulerDiscreteScheduler,
 
16
  PNDMScheduler,
17
  )
18
 
19
+ # -------- Config --------
20
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/mixy").strip()
21
  CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "unknown.safetensors").strip()
22
  HF_TOKEN = os.getenv("HF_TOKEN", None)
23
+ DO_WARMUP = os.getenv("WARMUP", "1") == "1"
24
 
 
25
  LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
 
 
26
  REPO_DIR = "/home/user/model"
27
 
28
  SCHEDULERS = {
 
35
  "dpmpp_2m": DPMSolverMultistepScheduler,
36
  }
37
 
38
+ # -------- Globals --------
39
  pipe = None
40
+ IS_SDXL = False
41
  LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
42
  INIT_ERROR: Optional[str] = None
43
 
44
+ # -------- Helpers --------
45
  def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
 
 
 
 
 
 
 
 
46
  if LORAS_JSON:
47
  try:
48
  parsed = json.loads(LORAS_JSON)
 
51
  except Exception as e:
52
  print(f"[WARN] Failed to parse LORAS_JSON: {e}")
53
 
 
54
  repo_manifest = os.path.join(repo_dir, "loras.json")
55
  if os.path.exists(repo_manifest):
56
  try:
 
61
  except Exception as e:
62
  print(f"[WARN] Failed to parse repo loras.json: {e}")
63
 
 
64
  local_manifest = os.path.join(os.getcwd(), "loras.json")
65
  if os.path.exists(local_manifest):
66
  try:
 
71
  except Exception as e:
72
  print(f"[WARN] Failed to parse local loras.json: {e}")
73
 
 
74
  print("[INFO] Using built-in LoRA fallback manifest.")
75
  return {
76
  "MoriiMee_Gothic": {
 
79
  }
80
  }
81
 
82
+ # -------- Bootstrap (CPU) --------
83
  def bootstrap_model():
84
  """
85
+ Try SD (1.x/2.x) single-file first, then SDXL single-file, to maximize compatibility
86
+ with older diffusers that don’t expose DiffusionPipeline.from_single_file.
87
  """
88
  global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
89
  INIT_ERROR = None
 
111
  print(f"[ERROR] {INIT_ERROR}")
112
  return
113
 
114
+ _pipe = None
115
+ _is_sdxl = False
116
+
117
+ # 1) SD 1.x/2.x first (most single-file merges are SD), then SDXL
118
  try:
119
+ _pipe = StableDiffusionPipeline.from_single_file(
120
+ ckpt_path, torch_dtype=torch.float16, use_safetensors=True
 
121
  )
122
+ _is_sdxl = False
123
+ except Exception as e_sd:
124
+ print(f"[INFO] SD load failed or not SD: {e_sd}")
125
  try:
126
+ _pipe = StableDiffusionXLPipeline.from_single_file(
127
+ ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
128
  )
129
+ _is_sdxl = True
130
+ except Exception as e_sdxl:
131
+ INIT_ERROR = f"Failed to load pipeline (SD and SDXL): SD={e_sd} | SDXL={e_sdxl}"
132
  print(f"[ERROR] {INIT_ERROR}")
133
  return
134
 
 
142
  manifest = load_lora_manifest(local_dir)
143
  print(f"[INFO] LoRAs available: {list(manifest.keys())}")
144
 
 
145
  pipe = _pipe
146
+ IS_SDXL = _is_sdxl
147
  LORA_MANIFEST = manifest
148
 
149
  def apply_loras(selected: List[str], scale: float, repo_dir: str):
 
168
  except Exception as e:
169
  print(f"[WARN] set_adapters failed: {e}")
170
 
171
+ # -------- Generation (ZeroGPU) --------
172
  @spaces.GPU
173
  def txt2img(
174
  prompt: str,
 
219
  out = pipe(**kwargs)
220
  return out.images
221
 
222
+ # -------- UI --------
223
+ with gr.Blocks(title="SDXL/SD single-file (ZeroGPU, LoRA-ready)") as demo:
 
 
 
 
 
 
224
  status = gr.Markdown("")
225
 
226
  with gr.Row():
 
250
  def _startup():
251
  bootstrap_model()
252
  if INIT_ERROR:
253
+ return (
254
+ gr.update(value=f"❌ Init failed: {INIT_ERROR}"),
255
+ gr.update(choices=[]),
256
+ gr.update(value=1024, minimum=256, maximum=1536, step=64),
257
+ gr.update(value=1024, minimum=256, maximum=1536, step=64),
258
+ gr.update(interactive=False),
259
+ )
260
+ default_wh = 1024 if IS_SDXL else 512
261
  msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
262
+ # Warm up only after model is ready (avoids race)
263
+ if DO_WARMUP:
264
+ try:
265
+ _ = txt2img("warmup", "", default_wh, default_wh, 4, 4.0, 1, 1234, "default", [], 0.0, False)
266
+ except Exception as e:
267
+ print(f"[WARN] Warmup failed: {e}")
268
+ return (
269
+ gr.update(value=msg),
270
+ gr.update(choices=list(LORA_MANIFEST.keys())),
271
+ gr.update(value=default_wh, minimum=256, maximum=1536, step=64),
272
+ gr.update(value=default_wh, minimum=256, maximum=1536, step=64),
273
+ gr.update(interactive=True),
274
+ )
275
 
276
+ demo.load(_startup, outputs=[status, lora_names, width, height, btn])
 
277
 
278
  btn.click(
279
  txt2img,