DB2169 committed
Commit 974b6ea · verified · 1 Parent(s): 043665d

Create app.py

Files changed (1)
  1. app.py +289 -0
app.py ADDED
@@ -0,0 +1,289 @@
+ import os, json
+ from typing import List, Dict, Any, Optional
+ from PIL import Image
+ import torch
+ import gradio as gr
+ import spaces
+ from huggingface_hub import snapshot_download
+ from diffusers import (
+     StableDiffusionXLPipeline,
+     StableDiffusionPipeline,
+     DPMSolverMultistepScheduler,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     DDIMScheduler,
+     LMSDiscreteScheduler,
+     PNDMScheduler,
+ )
+
+ # ----------------- Config (set in Space Secrets if private) -----------------
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/Mixy").strip()
+ CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "lustifySDXLNSFW_ggwpV7.safetensors").strip()
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
+ DO_WARMUP = os.getenv("WARMUP", "1") == "1"  # set WARMUP=0 to skip the first warmup call
+
+ # Optional override: JSON string for LoRA manifest (same shape as loras.json)
+ LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
+
+ # Where snapshot_download caches the repo in the container
+ REPO_DIR = "/home/user/model"
+
+ SCHEDULERS = {
+     "default": None,
+     "euler_a": EulerAncestralDiscreteScheduler,
+     "euler": EulerDiscreteScheduler,
+     "ddim": DDIMScheduler,
+     "lms": LMSDiscreteScheduler,
+     "pndm": PNDMScheduler,
+     "dpmpp_2m": DPMSolverMultistepScheduler,
+ }
+
+ # Globals populated at startup
+ pipe = None
+ IS_SDXL = True
+ LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
+ INIT_ERROR: Optional[str] = None
+
+ # ----------------- Helpers -----------------
+ def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
+     """
+     Manifest load order:
+     1) Environment variable LORAS_JSON (if provided)
+     2) loras.json inside the downloaded model repo
+     3) loras.json at the Space root (next to app.py)
+     4) Built-in fallback with the MoriiMee_Gothic LoRA
+     """
+     # 1) From env JSON
+     if LORAS_JSON:
+         try:
+             parsed = json.loads(LORAS_JSON)
+             if isinstance(parsed, dict):
+                 return parsed
+         except Exception as e:
+             print(f"[WARN] Failed to parse LORAS_JSON: {e}")
+
+     # 2) From repo
+     repo_manifest = os.path.join(repo_dir, "loras.json")
+     if os.path.exists(repo_manifest):
+         try:
+             with open(repo_manifest, "r", encoding="utf-8") as f:
+                 parsed = json.load(f)
+             if isinstance(parsed, dict):
+                 return parsed
+         except Exception as e:
+             print(f"[WARN] Failed to parse repo loras.json: {e}")
+
+     # 3) From Space root
+     local_manifest = os.path.join(os.getcwd(), "loras.json")
+     if os.path.exists(local_manifest):
+         try:
+             with open(local_manifest, "r", encoding="utf-8") as f:
+                 parsed = json.load(f)
+             if isinstance(parsed, dict):
+                 return parsed
+         except Exception as e:
+             print(f"[WARN] Failed to parse local loras.json: {e}")
+
+     # 4) Built-in fallback: the MoriiMee Gothic LoRA
+     print("[INFO] Using built-in LoRA fallback manifest.")
+     return {
+         "MoriiMee_Gothic": {
+             "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
+             "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors"
+         }
+     }
+
+ # ----------------- Bootstrap (download + load on CPU) -----------------
+ def bootstrap_model():
+     """
+     Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint,
+     keeping weights on CPU; ZeroGPU attaches a GPU only inside @spaces.GPU calls.
+     """
+     global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
+     INIT_ERROR = None
+
+     if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
+         INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
+         print(f"[ERROR] {INIT_ERROR}")
+         return
+
+     try:
+         local_dir = snapshot_download(
+             repo_id=MODEL_REPO_ID,
+             token=HF_TOKEN,
+             local_dir=REPO_DIR,
+             ignore_patterns=["*.md"],
+         )
+     except Exception as e:
+         INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
+         print(f"[ERROR] {INIT_ERROR}")
+         return
+
+     ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
+     if not os.path.exists(ckpt_path):
+         INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
+         print(f"[ERROR] {INIT_ERROR}")
+         return
+
+     try:
+         # Attempt SDXL first (text_encoder_2 present)
+         _pipe = StableDiffusionXLPipeline.from_single_file(
+             ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
+         )
+         sdxl = True
+     except Exception:
+         try:
+             _pipe = StableDiffusionPipeline.from_single_file(
+                 ckpt_path, torch_dtype=torch.float16, use_safetensors=True
+             )
+             sdxl = False
+         except Exception as e:
+             INIT_ERROR = f"Failed to load pipeline: {e}"
+             print(f"[ERROR] {INIT_ERROR}")
+             return
+
+     if hasattr(_pipe, "enable_attention_slicing"):
+         _pipe.enable_attention_slicing("max")
+     if hasattr(_pipe, "enable_vae_slicing"):
+         _pipe.enable_vae_slicing()
+     if hasattr(_pipe, "set_progress_bar_config"):
+         _pipe.set_progress_bar_config(disable=True)
+
+     manifest = load_lora_manifest(local_dir)
+     print(f"[INFO] LoRAs available: {list(manifest.keys())}")
+
+     # Publish to module globals
+     pipe = _pipe
+     IS_SDXL = sdxl
+     LORA_MANIFEST = manifest
+
+ def apply_loras(selected: List[str], scale: float, repo_dir: str):
+     if not selected or scale <= 0:
+         return
+     for name in selected:
+         meta = LORA_MANIFEST.get(name)
+         if not meta:
+             print(f"[WARN] Requested LoRA '{name}' not in manifest.")
+             continue
+         try:
+             if "path" in meta:
+                 pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
+             else:
+                 pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
+             print(f"[INFO] Loaded LoRA: {name}")
+         except Exception as e:
+             print(f"[WARN] LoRA load failed for {name}: {e}")
+     try:
+         pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
+         print(f"[INFO] Activated LoRAs: {selected} at scale {scale}")
+     except Exception as e:
+         print(f"[WARN] set_adapters failed: {e}")
+
+ # ----------------- Generation (ZeroGPU) -----------------
+ @spaces.GPU
+ def txt2img(
+     prompt: str,
+     negative: str,
+     width: int,
+     height: int,
+     steps: int,
+     guidance: float,
+     images: int,
+     seed: Optional[int],
+     scheduler: str,
+     loras: List[str],
+     lora_scale: float,
+     fuse_lora: bool,
+ ):
+     if pipe is None:
+         raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")
+
+     local_device = "cuda" if torch.cuda.is_available() else "cpu"
+     pipe.to(local_device)
+
+     if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
+         try:
+             pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
+         except Exception as e:
+             print(f"[WARN] Scheduler switch failed: {e}")
+
+     apply_loras(loras, lora_scale, REPO_DIR)
+     if fuse_lora and loras:
+         try:
+             pipe.fuse_lora(lora_scale=float(lora_scale))
+         except Exception as e:
+             print(f"[WARN] fuse_lora failed: {e}")
+
+     generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None
+
+     kwargs: Dict[str, Any] = dict(
+         prompt=prompt or "",
+         negative_prompt=negative or None,
+         width=int(width),
+         height=int(height),
+         num_inference_steps=int(steps),
+         guidance_scale=float(guidance),
+         num_images_per_prompt=int(images),
+         generator=generator,
+     )
+     with torch.inference_mode():
+         out = pipe(**kwargs)
+     return out.images
+
+ def warmup():
+     try:
+         _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
+     except Exception as e:
+         print(f"[WARN] Warmup failed: {e}")
+
+ # ----------------- UI -----------------
+ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
+     status = gr.Markdown("")
+
+     with gr.Row():
+         prompt = gr.Textbox(label="Prompt", lines=3)
+         negative = gr.Textbox(label="Negative Prompt", lines=3)
+
+     with gr.Row():
+         width = gr.Slider(256, 1536, 1024, step=64, label="Width")
+         height = gr.Slider(256, 1536, 1024, step=64, label="Height")
+
+     with gr.Row():
+         steps = gr.Slider(5, 80, 30, step=1, label="Steps")
+         guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
+         images = gr.Slider(1, 4, 1, step=1, label="Images")
+
+     with gr.Row():
+         seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
+         scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
+
+     lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
+     lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
+     fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
+
+     btn = gr.Button("Generate", variant="primary", interactive=False)
+     gallery = gr.Gallery(columns=4, height=420)
+
+     def _startup():
+         bootstrap_model()
+         if INIT_ERROR:
+             return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False)
+         msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
+         # Populate LoRA choices (manifest could come from repo, Space file, or built-in fallback)
+         return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True)
+
+     demo.load(_startup, outputs=[status, lora_names, btn])
+
+     if DO_WARMUP:
+         demo.load(lambda: warmup(), inputs=None, outputs=None)
+
+     btn.click(
+         txt2img,
+         inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
+         outputs=[gallery],
+         api_name="txt2img",
+         concurrency_limit=1,
+         concurrency_id="gpu_queue",
+     )
+
+ demo.queue(max_size=32, default_concurrency_limit=1).launch()
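
Note on the LoRA manifest: apply_loras accepts two entry shapes per named LoRA — "path" (a file inside the downloaded model repo, resolved against REPO_DIR) or "repo" plus "weight_name" (pulled from another Hub repo). A hypothetical loras.json showing both forms; the first entry is the built-in fallback from the code, the second ("my_local_style") is illustrative only:

{
  "MoriiMee_Gothic": {
    "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
    "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors"
  },
  "my_local_style": {
    "path": "loras/my_local_style.safetensors"
  }
}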
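
Because btn.click registers api_name="txt2img", the endpoint can also be called programmatically once the Space is running. A minimal sketch using gradio_client — the Space id below is a placeholder, and the positional arguments mirror the inputs list in the code:

from gradio_client import Client

client = Client("DB2169/your-space-id")  # hypothetical Space id
result = client.predict(
    "a gothic portrait, intricate linework",  # prompt
    "lowres, blurry",                         # negative
    1024, 1024,                               # width, height
    30,                                       # steps
    6.5,                                      # guidance
    1,                                        # images
    1234,                                     # seed
    "dpmpp_2m",                               # scheduler
    ["MoriiMee_Gothic"],                      # loras (names from the manifest)
    0.7,                                      # lora_scale
    False,                                    # fuse_lora
    api_name="/txt2img",
)
print(result)  # gallery payload (generated image files)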