import os
import json
from typing import List, Dict, Any, Optional

from PIL import Image
import torch
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip()
CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip()
HF_TOKEN = os.getenv("HF_TOKEN", None)
DO_WARMUP = os.getenv("WARMUP", "1") == "1"
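
# Optional LoRA manifest supplied via the LORAS_JSON environment variable.
# Illustrative value (same keys load_lora_manifest() and apply_loras() below
# understand; the names and filenames here are placeholders, not real repos):
#   {"MyLora": {"repo": "user/lora-repo", "weight_name": "my_lora.safetensors"},
#    "LocalLora": {"path": "loras/my_lora.safetensors"}}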
LORAS_JSON = os.getenv("LORAS_JSON", "").strip()

REPO_DIR = "/home/user/model"

SCHEDULERS = {
    "default": None,
    "euler_a": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "ddim": DDIMScheduler,
    "lms": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
    "dpmpp_2m": DPMSolverMultistepScheduler,
}
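# "default" (None) keeps whatever scheduler from_single_file() configured;
# any other entry is swapped in per request via .from_config() in txt2img().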
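# Runtime state, populated once by bootstrap_model() at startup.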
pipe = None
IS_SDXL = True
LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
INIT_ERROR: Optional[str] = None


def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
    """
    Manifest load order:
      1) Environment variable LORAS_JSON (if provided)
      2) loras.json inside the downloaded model repo
      3) loras.json at the Space root (next to app.py)
      4) Built-in fallback containing the MoriiMee_Gothic LoRA
    """
    if LORAS_JSON:
        try:
            parsed = json.loads(LORAS_JSON)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse LORAS_JSON: {e}")

    repo_manifest = os.path.join(repo_dir, "loras.json")
    if os.path.exists(repo_manifest):
        try:
            with open(repo_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse repo loras.json: {e}")

    local_manifest = os.path.join(os.getcwd(), "loras.json")
    if os.path.exists(local_manifest):
        try:
            with open(local_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse local loras.json: {e}")

    print("[INFO] Using built-in LoRA fallback manifest.")
    return {
        "MoriiMee_Gothic": {
            "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
            "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors",
        }
    }


def bootstrap_model():
    """
    Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint,
    keeping weights on CPU; ZeroGPU attaches the GPU only inside @spaces.GPU calls.
    """
    global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
    INIT_ERROR = None

    if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return

    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=REPO_DIR,
            ignore_patterns=["*.md"],
        )
    except Exception as e:
        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
        print(f"[ERROR] {INIT_ERROR}")
        return

    ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
    if not os.path.exists(ckpt_path):
        INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return

    try:
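        # Try SDXL first; if the checkpoint is not SDXL-shaped this raises and
        # we fall back to the classic SD pipeline below.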
        _pipe = StableDiffusionXLPipeline.from_single_file(
            ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
        )
        sdxl = True
    except Exception:
        try:
            _pipe = StableDiffusionPipeline.from_single_file(
                ckpt_path, torch_dtype=torch.float16, use_safetensors=True
            )
            sdxl = False
        except Exception as e:
            INIT_ERROR = f"Failed to load pipeline: {e}"
            print(f"[ERROR] {INIT_ERROR}")
            return

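    # Guarded memory/logging toggles: not every diffusers pipeline version
    # exposes all of these, hence the hasattr() checks.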
    if hasattr(_pipe, "enable_attention_slicing"):
        _pipe.enable_attention_slicing("max")
    if hasattr(_pipe, "enable_vae_slicing"):
        _pipe.enable_vae_slicing()
    if hasattr(_pipe, "set_progress_bar_config"):
        _pipe.set_progress_bar_config(disable=True)

    manifest = load_lora_manifest(local_dir)
    print(f"[INFO] LoRAs available: {list(manifest.keys())}")

    pipe = _pipe
    IS_SDXL = sdxl
    LORA_MANIFEST = manifest


def apply_loras(selected: List[str], scale: float, repo_dir: str):
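    """
    Load each selected LoRA (a local "path" entry relative to repo_dir, or a
    Hub "repo" plus "weight_name" pair) and activate them all at one shared scale.
    """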
    if not selected or scale <= 0:
        return
    for name in selected:
        meta = LORA_MANIFEST.get(name)
        if not meta:
            print(f"[WARN] Requested LoRA '{name}' not in manifest.")
            continue
        try:
            if "path" in meta:
                pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
            else:
                pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
            print(f"[INFO] Loaded LoRA: {name}")
        except Exception as e:
            print(f"[WARN] LoRA load failed for {name}: {e}")
    try:
        pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
        print(f"[INFO] Activated LoRAs: {selected} at scale {scale}")
    except Exception as e:
        print(f"[WARN] set_adapters failed: {e}")


@spaces.GPU
def txt2img(
    prompt: str,
    negative: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    images: int,
    seed: Optional[int],
    scheduler: str,
    loras: List[str],
    lora_scale: float,
    fuse_lora: bool,
):
    if pipe is None:
        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")

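    # Under ZeroGPU, CUDA is attached only for the duration of this @spaces.GPU
    # call, so the pipeline is moved to the GPU here rather than at load time.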
    local_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(local_device)

    if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
        try:
            pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        except Exception as e:
            print(f"[WARN] Scheduler switch failed: {e}")

    apply_loras(loras, lora_scale, REPO_DIR)
    if fuse_lora and loras:
        try:
            pipe.fuse_lora(lora_scale=float(lora_scale))
        except Exception as e:
            print(f"[WARN] fuse_lora failed: {e}")

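    # A fixed seed makes runs reproducible; a blank seed leaves generation random.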
    generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None

    kwargs: Dict[str, Any] = dict(
        prompt=prompt or "",
        negative_prompt=negative or None,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        num_images_per_prompt=int(images),
        generator=generator,
    )
    with torch.inference_mode():
        out = pipe(**kwargs)
    return out.images


def warmup():
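    # Best-effort warmup: a tiny 512x512, 4-step render to absorb one-time
    # initialization cost; failures are logged rather than raised.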
    try:
        _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
    except Exception as e:
        print(f"[WARN] Warmup failed: {e}")


with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
    status = gr.Markdown("")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative = gr.Textbox(label="Negative Prompt", lines=3)

    with gr.Row():
        width = gr.Slider(256, 1536, 1024, step=64, label="Width")
        height = gr.Slider(256, 1536, 1024, step=64, label="Height")

    with gr.Row():
        steps = gr.Slider(5, 80, 30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
        images = gr.Slider(1, 4, 1, step=1, label="Images")

    with gr.Row():
        seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
        scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")

    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
    lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
    fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")

    btn = gr.Button("Generate", variant="primary", interactive=False)
    gallery = gr.Gallery(columns=4, height=420)

    def _startup():
        bootstrap_model()
        if INIT_ERROR:
            return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False)
        msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
        return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True)

    demo.load(_startup, outputs=[status, lora_names, btn])

    if DO_WARMUP:
        demo.load(warmup, inputs=None, outputs=None)

    btn.click(
        txt2img,
        inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
        outputs=[gallery],
        api_name="txt2img",
        concurrency_limit=1,
        concurrency_id="gpu_queue",
    )

demo.queue(max_size=32, default_concurrency_limit=1).launch()
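
# Illustrative remote call via gradio_client (assumes the Space is deployed;
# "user/space-name" is a placeholder, and the argument order mirrors the
# txt2img inputs wired above):
#   from gradio_client import Client
#   client = Client("user/space-name")
#   result = client.predict(
#       "a neon cyberpunk alley", "", 1024, 1024, 30, 6.5, 1, 1234,
#       "dpmpp_2m", [], 0.7, False,
#       api_name="/txt2img",
#   )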