ihtesham0345 commited on
Commit
7f02891
Β·
1 Parent(s): 28905c9

Move InvSR download to runtime (avoids build OOM), add status endpoint

Browse files
Files changed (2) hide show
  1. Dockerfile +0 -14
  2. app.py +67 -75
Dockerfile CHANGED
@@ -9,25 +9,11 @@ USER user
9
  ENV PATH="/home/user/.local/bin:$PATH"
10
  WORKDIR /app
11
 
12
- # Clone InvSR source
13
  RUN git clone --depth 1 https://github.com/zsyOAOA/InvSR.git /app/InvSR
14
 
15
- # Install pip deps
16
  COPY --chown=user requirements.txt .
17
  RUN pip install --no-cache-dir -r requirements.txt
18
 
19
- # Pre-download SD-Turbo (~5GB) + noise predictor at BUILD time
20
- RUN python -c "\
21
- from huggingface_hub import snapshot_download;\
22
- snapshot_download('stabilityai/sd-turbo', cache_dir='/app/models/invsr');\
23
- print('SD-Turbo downloaded')\
24
- "
25
- RUN python -c "\
26
- from huggingface_hub import hf_hub_download;\
27
- hf_hub_download('OAOA/InvSR', 'noise_predictor_sd_turbo_v5.pth', cache_dir='/app/models/invsr');\
28
- print('Noise predictor downloaded')\
29
- "
30
-
31
  COPY --chown=user app.py .
32
 
33
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "600"]
 
9
  ENV PATH="/home/user/.local/bin:$PATH"
10
  WORKDIR /app
11
 
 
12
  RUN git clone --depth 1 https://github.com/zsyOAOA/InvSR.git /app/InvSR
13
 
 
14
  COPY --chown=user requirements.txt .
15
  RUN pip install --no-cache-dir -r requirements.txt
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  COPY --chown=user app.py .
18
 
19
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "600"]
app.py CHANGED
@@ -78,81 +78,58 @@ def upscale_mewzoom(image_bytes: bytes, scale: str) -> tuple[bytes, dict]:
78
  # ── InvSR Model (Diffusion 4X) ──────────────────────────────
79
  _INVSR_PATH = Path("/app/InvSR")
80
  _sampler_invsr = None
 
 
81
 
82
 
83
- def _patch_invsr():
84
- """Patch InvSR source for CPU/float32 support."""
85
- p = _INVSR_PATH / "sampler_invsr.py"
86
- code = p.read_text()
87
-
88
- # Remove basicsr import chain (not needed for inference)
89
- code = code.replace("from datapipe.datasets import create_dataset", "")
90
-
91
- # Add device param to BaseSampler
92
- old_init = """class BaseSampler:
93
- def __init__(self, configs):
94
- '''
95
- Input:
96
- configs: config, see the yaml file in folder ./configs/
97
- configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}
98
- seed: int, random seed
99
- '''
100
- self.configs = configs
101
-
102
- self.setup_seed()
103
-
104
- self.build_model()
105
-
106
- def setup_seed(self, seed=None):
107
- seed = self.configs.seed if seed is None else seed
108
- random.seed(seed)
109
- np.random.seed(seed)
110
- torch.manual_seed(seed)
111
- torch.cuda.manual_seed_all(seed)"""
112
-
113
- new_init = """class BaseSampler:
114
- def __init__(self, configs, device='auto'):
115
- self.configs = configs
116
- if device == 'auto':
117
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
118
- self.device = torch.device(device)
119
- self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
120
- self.setup_seed()
121
- self.build_model()
122
-
123
- def setup_seed(self, seed=None):
124
- seed = self.configs.seed if seed is None else seed
125
- random.seed(seed)
126
- np.random.seed(seed)
127
- torch.manual_seed(seed)
128
- if torch.cuda.is_available():
129
- torch.cuda.manual_seed_all(seed)"""
130
-
131
- code = code.replace(old_init, new_init)
132
-
133
- # Replace .cuda() and .type(torch.float16) with device-aware versions
134
- code = code.replace('sd_pipe.to(f"cuda")', "sd_pipe.to(self.device)")
135
- code = code.replace("model_start.cuda()", "model_start.to(self.device)")
136
- code = code.replace('map_location=f"cuda"', "map_location=self.device")
137
- code = code.replace("im_cond.type(torch.float16)", "im_cond.type(self.dtype)")
138
- code = code.replace(".type(torch.float16)", ".type(self.dtype)")
139
- code = code.replace("data['lq'].cuda()", "data['lq'].to(self.device)")
140
- code = code.replace("util_image.img2tensor(im_cond).cuda()", "util_image.img2tensor(im_cond).to(self.device)")
141
-
142
- # Lazy import create_dataset in inference method
143
- code = code.replace(
144
- "if in_path.is_dir():\n data_config",
145
- "if in_path.is_dir():\n from datapipe.datasets import create_dataset\n data_config",
146
- )
147
-
148
- p.write_text(code)
149
- logger.info("InvSR sampler patched for CPU/float32")
150
-
151
-
152
- def _load_invsr():
153
- global _sampler_invsr
154
- if _sampler_invsr is not None:
155
- return _sampler_invsr
156
 
157
  _patch_invsr()
158
  sys.path.insert(0, str(_INVSR_PATH))
@@ -190,7 +167,12 @@ def _load_invsr():
190
 
191
 
192
  def upscale_invsr(image_bytes: bytes, num_steps: int = 1) -> bytes:
193
- sampler = _load_invsr()
 
 
 
 
 
194
  sys.path.insert(0, str(_INVSR_PATH))
195
  from utils import util_image
196
 
@@ -271,6 +253,9 @@ async def lifespan(app: FastAPI):
271
  logger.info("Loading MewZoom models...")
272
  for s in MEWZOOM_MODELS:
273
  _load_mewzoom(s)
 
 
 
274
  yield
275
 
276
 
@@ -286,7 +271,14 @@ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], all
286
  @app.get("/")
287
  @app.get("/health")
288
  async def health():
289
- return JSONResponse({"status": "healthy", "device": _DEVICE, "models": ["2x", "4x", "invsr"], "gpu": torch.cuda.is_available()})
 
 
 
 
 
 
 
290
 
291
 
292
  @app.post("/upscale/2x")
 
78
  # ── InvSR Model (Diffusion 4X) ──────────────────────────────
79
  _INVSR_PATH = Path("/app/InvSR")
80
  _sampler_invsr = None
81
+ _invsr_status = "not_loaded"
82
+ _invsr_error = None
83
 
84
 
85
+ def _load_invsr_sync():
86
+ """Download + load InvSR (called in background during startup)"""
87
+ global _sampler_invsr, _invsr_status, _invsr_error
88
+ try:
89
+ _invsr_status = "downloading"
90
+ _patch_invsr()
91
+ sys.path.insert(0, str(_INVSR_PATH))
92
+ sys.path.insert(0, str(_INVSR_PATH / "src"))
93
+
94
+ from omegaconf import OmegaConf
95
+ from huggingface_hub import snapshot_download, hf_hub_download
96
+
97
+ invsr_cache = str(CACHE_DIR / "invsr")
98
+ CACHE_DIR.mkdir(exist_ok=True)
99
+
100
+ logger.info("Downloading SD-Turbo (~5GB, one-time)...")
101
+ snapshot_download("stabilityai/sd-turbo", cache_dir=invsr_cache, resume_download=True)
102
+ logger.info("SD-Turbo downloaded")
103
+
104
+ logger.info("Downloading noise predictor...")
105
+ hf_hub_download("OAOA/InvSR", "noise_predictor_sd_turbo_v5.pth", cache_dir=invsr_cache)
106
+ ckpt = None
107
+ for f in Path(invsr_cache).rglob("noise_predictor_sd_turbo_v5.pth"):
108
+ ckpt = str(f); break
109
+ if not ckpt:
110
+ raise FileNotFoundError("Noise predictor not found after download")
111
+
112
+ _invsr_status = "loading"
113
+ from sampler_invsr import InvSamplerSR
114
+ cfg = OmegaConf.load(str(_INVSR_PATH / "configs" / "sample-sd-turbo.yaml"))
115
+ cfg.sd_pipe.params.torch_dtype = "torch.float32" if _DEVICE == "cpu" else "torch.float16"
116
+ cfg.sd_pipe.params.cache_dir = invsr_cache
117
+ cfg.sd_pipe.params.local_files_only = True
118
+ cfg.model_start.ckpt_path = ckpt
119
+ cfg.timesteps = [200]; cfg.bs = 1; cfg.tiled_vae = True
120
+ cfg.color_fix = "wavelet"; cfg.basesr.chopping.pch_size = 128
121
+ cfg.basesr.chopping.extra_bs = 8
122
+
123
+ logger.info("Loading InvSR into memory...")
124
+ _sampler_invsr = InvSamplerSR(cfg, device="auto")
125
+ if _DEVICE == "cpu":
126
+ _sampler_invsr.sd_pipe = _sampler_invsr.sd_pipe.to(dtype=torch.float32)
127
+ _invsr_status = "ready"
128
+ logger.info("InvSR ready on %s", _DEVICE)
129
+ except Exception as e:
130
+ _invsr_status = "error"
131
+ _invsr_error = str(e)
132
+ logger.error("InvSR load failed: %s", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  _patch_invsr()
135
  sys.path.insert(0, str(_INVSR_PATH))
 
167
 
168
 
169
  def upscale_invsr(image_bytes: bytes, num_steps: int = 1) -> bytes:
170
+ global _sampler_invsr
171
+ if _invsr_status == "error":
172
+ raise HTTPException(500, f"InvSR failed to load: {_invsr_error}")
173
+ if _sampler_invsr is None:
174
+ raise HTTPException(503, f"InvSR is {_invsr_status}. Check /health for status.")
175
+ sampler = _sampler_invsr
176
  sys.path.insert(0, str(_INVSR_PATH))
177
  from utils import util_image
178
 
 
253
  logger.info("Loading MewZoom models...")
254
  for s in MEWZOOM_MODELS:
255
  _load_mewzoom(s)
256
+ # Start InvSR download+load in background thread
257
+ import threading
258
+ threading.Thread(target=_load_invsr_sync, daemon=True).start()
259
  yield
260
 
261
 
 
271
  @app.get("/")
272
  @app.get("/health")
273
  async def health():
274
+ return JSONResponse({
275
+ "status": "healthy",
276
+ "device": _DEVICE,
277
+ "models": list(MEWZOOM_MODELS.keys()) + ["invsr"],
278
+ "gpu": torch.cuda.is_available(),
279
+ "invsr_status": _invsr_status,
280
+ "invsr_error": _invsr_error,
281
+ })
282
 
283
 
284
  @app.post("/upscale/2x")