Spaces:
Running on Zero
Running on Zero
Upload folder using huggingface_hub
Browse files
- app.py +15 -13
- requirements.txt +1 -0
- synthcxr/constants.py +3 -3
app.py
CHANGED
|
@@ -4,6 +4,8 @@
|
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import os
|
|
|
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
import spaces
|
|
@@ -37,17 +39,18 @@ CONDITION_CHOICES = [
|
|
| 37 |
SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]
|
| 38 |
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
-
# Pipeline (
|
| 41 |
# ---------------------------------------------------------------------------
|
| 42 |
-
_pipe = None
|
| 43 |
|
| 44 |
|
| 45 |
-
def get_pipeline():
|
| 46 |
-
"""Load the
|
| 47 |
-
global _pipe
|
| 48 |
-
if _pipe is not None:
|
| 49 |
-
return _pipe
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
from synthcxr.pipeline import load_lora_weights, load_pipeline
|
| 52 |
|
| 53 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -58,26 +61,25 @@ def get_pipeline():
|
|
| 58 |
vram_limit = float(vram_limit_str) if vram_limit_str else None
|
| 59 |
|
| 60 |
print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
|
| 61 |
-
|
| 62 |
|
| 63 |
# LORA_EPOCH env var: which epoch checkpoint to load (default: 2)
|
| 64 |
lora_epoch = os.environ.get("LORA_EPOCH", "2")
|
| 65 |
lora = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
|
| 66 |
|
| 67 |
if not lora.exists():
|
| 68 |
-
# Try step-based checkpoints or any available .safetensors
|
| 69 |
candidates = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
|
| 70 |
if candidates:
|
| 71 |
lora = candidates[-1]
|
| 72 |
print(f"[WARN] epoch-{lora_epoch} not found, falling back to {lora.name}")
|
| 73 |
else:
|
| 74 |
print("[WARN] No LoRA checkpoint found – running base model only.")
|
| 75 |
-
return
|
| 76 |
|
| 77 |
print(f"[INFO] Loading LoRA from {lora}")
|
| 78 |
-
load_lora_weights(
|
| 79 |
print("[INFO] Pipeline ready.")
|
| 80 |
-
return
|
| 81 |
|
| 82 |
|
| 83 |
# ---------------------------------------------------------------------------
|
|
@@ -164,7 +166,7 @@ def generate_cxr(
|
|
| 164 |
if mask_image is None:
|
| 165 |
raise gr.Error("Please select or upload a mask first.")
|
| 166 |
|
| 167 |
-
pipe = get_pipeline()
|
| 168 |
if pipe is None:
|
| 169 |
raise gr.Error("Pipeline not loaded. GPU may be unavailable.")
|
| 170 |
|
|
|
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import os
|
| 7 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 8 |
+
|
| 9 |
from pathlib import Path
|
| 10 |
|
| 11 |
import spaces
|
|
|
|
| 39 |
SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]
|
| 40 |
|
| 41 |
# ---------------------------------------------------------------------------
|
| 42 |
+
# Pipeline loading (fresh on each @spaces.GPU call; model files cached on disk)
|
| 43 |
# ---------------------------------------------------------------------------
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
+
def load_fresh_pipeline():
|
| 47 |
+
"""Load the pipeline + LoRA onto the *currently allocated* GPU.
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
ZeroGPU deallocates GPU memory after each ``@spaces.GPU`` call, so we
|
| 50 |
+
cannot cache tensors between calls. However, diffsynth caches the
|
| 51 |
+
model files on disk (HF Hub cache), so only tensor loading happens
|
| 52 |
+
here — not a full download.
|
| 53 |
+
"""
|
| 54 |
from synthcxr.pipeline import load_lora_weights, load_pipeline
|
| 55 |
|
| 56 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 61 |
vram_limit = float(vram_limit_str) if vram_limit_str else None
|
| 62 |
|
| 63 |
print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
|
| 64 |
+
pipe = load_pipeline(device, dtype, vram_limit=vram_limit)
|
| 65 |
|
| 66 |
# LORA_EPOCH env var: which epoch checkpoint to load (default: 2)
|
| 67 |
lora_epoch = os.environ.get("LORA_EPOCH", "2")
|
| 68 |
lora = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
|
| 69 |
|
| 70 |
if not lora.exists():
|
|
|
|
| 71 |
candidates = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
|
| 72 |
if candidates:
|
| 73 |
lora = candidates[-1]
|
| 74 |
print(f"[WARN] epoch-{lora_epoch} not found, falling back to {lora.name}")
|
| 75 |
else:
|
| 76 |
print("[WARN] No LoRA checkpoint found – running base model only.")
|
| 77 |
+
return pipe
|
| 78 |
|
| 79 |
print(f"[INFO] Loading LoRA from {lora}")
|
| 80 |
+
load_lora_weights(pipe, lora)
|
| 81 |
print("[INFO] Pipeline ready.")
|
| 82 |
+
return pipe
|
| 83 |
|
| 84 |
|
| 85 |
# ---------------------------------------------------------------------------
|
|
|
|
| 166 |
if mask_image is None:
|
| 167 |
raise gr.Error("Please select or upload a mask first.")
|
| 168 |
|
| 169 |
+
pipe = load_fresh_pipeline()
|
| 170 |
if pipe is None:
|
| 171 |
raise gr.Error("Pipeline not loaded. GPU may be unavailable.")
|
| 172 |
|
requirements.txt
CHANGED
|
@@ -5,3 +5,4 @@ scipy
|
|
| 5 |
Pillow
|
| 6 |
numpy
|
| 7 |
torch
|
|
|
|
|
|
| 5 |
Pillow
|
| 6 |
numpy
|
| 7 |
torch
|
| 8 |
+
hf_transfer
|
synthcxr/constants.py
CHANGED
|
@@ -49,6 +49,6 @@ SEVERITY_MODIFIERS: dict[str, str] = {
|
|
| 49 |
"significant": "significant",
|
| 50 |
}
|
| 51 |
|
| 52 |
-
DEFAULT_MODEL_ID = "
|
| 53 |
-
TEXT_ENCODER_MODEL_ID = "
|
| 54 |
-
PROCESSOR_MODEL_ID = "
|
|
|
|
| 49 |
"significant": "significant",
|
| 50 |
}
|
| 51 |
|
| 52 |
+
DEFAULT_MODEL_ID = "gradientguild/SynthCXR-Qwen-Weights"
|
| 53 |
+
TEXT_ENCODER_MODEL_ID = "gradientguild/SynthCXR-Qwen-Weights"
|
| 54 |
+
PROCESSOR_MODEL_ID = "gradientguild/SynthCXR-Qwen-Weights"
|