gradientguild committed on
Commit
463e35b
·
verified ·
1 Parent(s): a4aa5c5

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +15 -13
  2. requirements.txt +1 -0
  3. synthcxr/constants.py +3 -3
app.py CHANGED
@@ -4,6 +4,8 @@
4
  from __future__ import annotations
5
 
6
  import os
 
 
7
  from pathlib import Path
8
 
9
  import spaces
@@ -37,17 +39,18 @@ CONDITION_CHOICES = [
37
  SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]
38
 
39
  # ---------------------------------------------------------------------------
40
- # Pipeline (lazy-loaded once)
41
  # ---------------------------------------------------------------------------
42
- _pipe = None
43
 
44
 
45
- def get_pipeline():
46
- """Load the diffusion pipeline + LoRA weights into GPU memory (once)."""
47
- global _pipe
48
- if _pipe is not None:
49
- return _pipe
50
 
 
 
 
 
 
51
  from synthcxr.pipeline import load_lora_weights, load_pipeline
52
 
53
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -58,26 +61,25 @@ def get_pipeline():
58
  vram_limit = float(vram_limit_str) if vram_limit_str else None
59
 
60
  print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
61
- _pipe = load_pipeline(device, dtype, vram_limit=vram_limit)
62
 
63
  # LORA_EPOCH env var: which epoch checkpoint to load (default: 2)
64
  lora_epoch = os.environ.get("LORA_EPOCH", "2")
65
  lora = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
66
 
67
  if not lora.exists():
68
- # Try step-based checkpoints or any available .safetensors
69
  candidates = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
70
  if candidates:
71
  lora = candidates[-1]
72
  print(f"[WARN] epoch-{lora_epoch} not found, falling back to {lora.name}")
73
  else:
74
  print("[WARN] No LoRA checkpoint found – running base model only.")
75
- return _pipe
76
 
77
  print(f"[INFO] Loading LoRA from {lora}")
78
- load_lora_weights(_pipe, lora)
79
  print("[INFO] Pipeline ready.")
80
- return _pipe
81
 
82
 
83
  # ---------------------------------------------------------------------------
@@ -164,7 +166,7 @@ def generate_cxr(
164
  if mask_image is None:
165
  raise gr.Error("Please select or upload a mask first.")
166
 
167
- pipe = get_pipeline()
168
  if pipe is None:
169
  raise gr.Error("Pipeline not loaded. GPU may be unavailable.")
170
 
 
4
  from __future__ import annotations
5
 
6
  import os
7
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
8
+
9
  from pathlib import Path
10
 
11
  import spaces
 
39
  SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]
40
 
41
  # ---------------------------------------------------------------------------
42
+ # Pipeline loading (fresh on each @spaces.GPU call; model files cached on disk)
43
  # ---------------------------------------------------------------------------
 
44
 
45
 
46
+ def load_fresh_pipeline():
47
+ """Load the pipeline + LoRA onto the *currently allocated* GPU.
 
 
 
48
 
49
+ ZeroGPU deallocates GPU memory after each ``@spaces.GPU`` call, so we
50
+ cannot cache tensors between calls. However, diffsynth caches the
51
+ model files on disk (HF Hub cache), so only tensor loading happens
52
+ here — not a full download.
53
+ """
54
  from synthcxr.pipeline import load_lora_weights, load_pipeline
55
 
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
61
  vram_limit = float(vram_limit_str) if vram_limit_str else None
62
 
63
  print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
64
+ pipe = load_pipeline(device, dtype, vram_limit=vram_limit)
65
 
66
  # LORA_EPOCH env var: which epoch checkpoint to load (default: 2)
67
  lora_epoch = os.environ.get("LORA_EPOCH", "2")
68
  lora = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
69
 
70
  if not lora.exists():
 
71
  candidates = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
72
  if candidates:
73
  lora = candidates[-1]
74
  print(f"[WARN] epoch-{lora_epoch} not found, falling back to {lora.name}")
75
  else:
76
  print("[WARN] No LoRA checkpoint found – running base model only.")
77
+ return pipe
78
 
79
  print(f"[INFO] Loading LoRA from {lora}")
80
+ load_lora_weights(pipe, lora)
81
  print("[INFO] Pipeline ready.")
82
+ return pipe
83
 
84
 
85
  # ---------------------------------------------------------------------------
 
166
  if mask_image is None:
167
  raise gr.Error("Please select or upload a mask first.")
168
 
169
+ pipe = load_fresh_pipeline()
170
  if pipe is None:
171
  raise gr.Error("Pipeline not loaded. GPU may be unavailable.")
172
 
requirements.txt CHANGED
@@ -5,3 +5,4 @@ scipy
5
  Pillow
6
  numpy
7
  torch
 
 
5
  Pillow
6
  numpy
7
  torch
8
+ hf_transfer
synthcxr/constants.py CHANGED
@@ -49,6 +49,6 @@ SEVERITY_MODIFIERS: dict[str, str] = {
49
  "significant": "significant",
50
  }
51
 
52
- DEFAULT_MODEL_ID = "Qwen/Qwen-Image-Edit-2511"
53
- TEXT_ENCODER_MODEL_ID = "Qwen/Qwen-Image"
54
- PROCESSOR_MODEL_ID = "Qwen/Qwen-Image-Edit"
 
49
  "significant": "significant",
50
  }
51
 
52
+ DEFAULT_MODEL_ID = "gradientguild/SynthCXR-Qwen-Weights"
53
+ TEXT_ENCODER_MODEL_ID = "gradientguild/SynthCXR-Qwen-Weights"
54
+ PROCESSOR_MODEL_ID = "gradientguild/SynthCXR-Qwen-Weights"