MogensR committed on
Commit
b761b10
·
1 Parent(s): 434630c
Files changed (1) hide show
  1. app.py +187 -132
app.py CHANGED
@@ -1,146 +1,138 @@
1
  #!/usr/bin/env python3
2
  """
3
  VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
4
- ================================================
5
  - Sets up Gradio UI and launches pipeline
6
- - Aligned with torch==2.3.1+cu121, MatAnyone v1.0.0, SAM2 commit 3c76f73c1a7e7b4a2e8a0a9a3e5b92f7e6e3f2f5
7
 
8
- Changes (2025-09-16):
9
- - Enhanced error handling and model verification
10
- - Added GPU memory management
11
- - Improved logging and diagnostics
12
- - Added model verification on startup
13
  """
14
 
15
  print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
 
 
 
 
16
  import sys
17
  import os
18
  import gc
19
- import torch
20
  import logging
21
  import threading
22
  import time
23
  import warnings
24
  import traceback
 
25
  from pathlib import Path
26
  from loguru import logger
27
 
28
- # Configure logging
29
  logger.remove()
30
  logger.add(
31
  sys.stderr,
32
- format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
 
33
  )
34
 
35
- # Set up warnings
36
  warnings.filterwarnings("ignore", category=UserWarning)
37
  warnings.filterwarnings("ignore", category=FutureWarning)
38
  warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
39
 
40
- # Environment setup
41
- os.environ["OMP_NUM_THREADS"] = "1"
42
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
43
 
44
- # --- Path Configuration ---
45
  BASE_DIR = Path(__file__).parent.absolute()
46
  THIRD_PARTY_DIR = BASE_DIR / "third_party"
47
  SAM2_DIR = THIRD_PARTY_DIR / "sam2"
48
  CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
49
 
50
- # Add to Python path
51
- for p in [str(THIRD_PARTY_DIR), str(SAM2_DIR)]:
52
  if p not in sys.path:
53
  sys.path.insert(0, p)
54
 
55
  logger.info(f"Base directory: {BASE_DIR}")
56
- logger.info(f"Python path: {sys.path}")
 
 
 
 
 
 
 
 
 
57
 
58
- # --- GPU Configuration ---
59
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
60
  if DEVICE == "cuda":
61
  os.environ["SAM2_DEVICE"] = "cuda"
62
  os.environ["MATANY_DEVICE"] = "cuda"
63
- os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
64
- logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
 
 
 
65
  else:
66
  os.environ["SAM2_DEVICE"] = "cpu"
67
  os.environ["MATANY_DEVICE"] = "cpu"
68
  logger.warning("CUDA not available, falling back to CPU")
69
 
70
- # --- Model Verification ---
71
  def verify_models():
72
- """Verify that all required models are available and loadable."""
73
  results = {"status": "success", "details": {}}
74
-
75
- # Check SAM2 model
76
  try:
77
  sam2_model_path = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
78
  if not os.path.exists(sam2_model_path):
79
  raise FileNotFoundError(f"SAM2 model not found at {sam2_model_path}")
80
-
81
- # Try to load a small part of the model to verify it's not corrupted
82
- state_dict = torch.load(sam2_model_path, map_location=DEVICE)
83
- if not isinstance(state_dict, dict):
84
- raise ValueError("Invalid SAM2 model format")
85
-
86
  results["details"]["sam2"] = {
87
  "status": "success",
88
  "path": sam2_model_path,
89
- "size_mb": os.path.getsize(sam2_model_path) / (1024 * 1024)
90
  }
91
  except Exception as e:
92
  results["status"] = "error"
93
  results["details"]["sam2"] = {
94
  "status": "error",
95
  "error": str(e),
96
- "traceback": traceback.format_exc()
97
  }
98
-
99
  return results
100
 
101
- # --- Startup Diagnostics ---
102
  def run_startup_diagnostics():
103
- """Run comprehensive system and model diagnostics."""
104
  diag = {
105
  "system": {
106
  "python": sys.version,
107
- "pytorch": torch.__version__,
108
- "cuda_available": torch.cuda.is_available(),
109
- "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
110
- "cuda_version": torch.version.cuda if hasattr(torch.version, 'cuda') else None,
111
  },
112
  "paths": {
113
  "base_dir": str(BASE_DIR),
114
  "checkpoints_dir": str(CHECKPOINTS_DIR),
115
  "sam2_dir": str(SAM2_DIR),
116
- "python_path": sys.path
117
  },
118
- "environment": dict(os.environ)
119
  }
120
-
121
- # Run model verification
122
  diag["model_verification"] = verify_models()
123
-
124
  return diag
125
 
126
- # Run diagnostics on startup
127
  startup_diag = run_startup_diagnostics()
128
  logger.info("Startup diagnostics completed")
129
 
130
- # Import Gradio after environment setup
131
- import gradio as gr
132
-
133
- # -----------------------------------------------------------------------------
134
- # Logging early
135
- # -----------------------------------------------------------------------------
136
- logger = logging.getLogger("backgroundfx_pro")
137
- if not logger.handlers:
138
- h = logging.StreamHandler()
139
- h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
140
- logger.addHandler(h)
141
- logger.setLevel(logging.INFO)
142
-
143
- # Heartbeat so logs never go silent during startup/imports
144
  def _heartbeat():
145
  i = 0
146
  while True:
@@ -150,56 +142,17 @@ def _heartbeat():
150
 
151
  threading.Thread(target=_heartbeat, daemon=True).start()
152
 
153
- # -----------------------------------------------------------------------------
154
- # Safe, minimal startup diagnostics (no long CUDA probes)
155
- # -----------------------------------------------------------------------------
156
- def _safe_startup_diag():
157
- # Torch version
158
- try:
159
- import torch
160
- import importlib
161
- t = importlib.import_module("torch")
162
- logger.info(
163
- "torch imported: %s | torch.version.cuda=%s",
164
- getattr(t, "__version__", "?"),
165
- getattr(getattr(t, "version", None), "cuda", None),
166
- )
167
- except Exception as e:
168
- logger.warning("Torch not available at startup: %s", e)
169
-
170
- # MatAnyone version
171
- try:
172
- import importlib.metadata
173
- version = importlib.metadata.version("matanyone")
174
- logger.info(f"[MATANY] MatAnyone version: {version}")
175
- except Exception:
176
- logger.info("[MATANY] MatAnyone version unknown")
177
-
178
- # nvidia-smi with short timeout (avoid indefinite block)
179
- try:
180
- out = subprocess.run(
181
- ["nvidia-smi", "-L"], capture_output=True, text=True, timeout=2
182
- )
183
- if out.returncode == 0:
184
- logger.info("nvidia-smi -L:\n%s", out.stdout.strip())
185
- else:
186
- logger.warning("nvidia-smi -L failed or unavailable (rc=%s).", out.returncode)
187
- except subprocess.TimeoutExpired:
188
- logger.warning("nvidia-smi -L timed out (skipping).")
189
- except Exception as e:
190
- logger.warning("nvidia-smi not runnable: %s", e)
191
-
192
- # Optional perf tuning; never block startup
193
  try:
194
- import perf_tuning
195
  logger.info("perf_tuning imported successfully.")
196
  except Exception as e:
197
  logger.info("perf_tuning not available: %s", e)
198
 
199
- # MatAnyone API detection probe (non-instantiating)
200
  try:
201
  import inspect
202
- from matanyone.inference import inference_core as ic
203
  sigs = {}
204
  for name in ("InferenceCore",):
205
  obj = getattr(ic, name, None)
@@ -209,47 +162,149 @@ def _safe_startup_diag():
209
  except Exception as e:
210
  logger.info(f"[MATANY] probe skipped: {e}")
211
 
212
- # Continue with app startup
213
- _safe_startup_diag()
 
 
214
 
215
- # -----------------------------------------------------------------------------
216
- # Post-launch CUDA diag in background (so it never blocks binding the port)
217
- # -----------------------------------------------------------------------------
218
- def _post_launch_diag():
 
 
 
 
 
 
219
  try:
220
- import torch
221
  try:
222
- avail = torch.cuda.is_available()
223
- except Exception as e:
224
- logger.warning("torch.cuda.is_available() failed: %s", e)
225
- avail = False
226
- logger.info("CUDA available: %s", avail)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  if avail:
228
- try:
229
- idx = torch.cuda.current_device()
230
- name = torch.cuda.get_device_name(idx)
231
- cap = torch.cuda.get_device_capability(idx)
232
- logger.info("CUDA device %d: %s (cc %d.%d)", idx, name, cap[0], cap[1])
233
- except Exception as e:
234
- logger.warning("CUDA device query failed: %s", e)
235
  except Exception as e:
236
- logger.warning("Post-launch torch diag failed: %s", e)
237
 
238
- # -----------------------------------------------------------------------------
239
- # Build UI (in separate module) and launch
240
- # -----------------------------------------------------------------------------
241
  def build_ui() -> gr.Blocks:
242
- # Import here so any heavy imports inside ui.py (it shouldn’t) would show up after logs are configured
243
- from ui import create_interface
244
  return create_interface()
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  if __name__ == "__main__":
247
  host = os.environ.get("HOST", "0.0.0.0")
248
  port = int(os.environ.get("PORT", "7860"))
249
- logger.info("Launching Gradio on %s:%s …", host, port)
 
 
 
250
 
251
  demo = build_ui()
252
- demo.queue(max_size=16, api_open=False) # Disable public API for security
 
253
 
254
  threading.Thread(target=_post_launch_diag, daemon=True).start()
255
- demo.launch(server_name=host, server_port=port, show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
  VideoBackgroundReplacer2 - SAM2 + MatAnyone Integration
4
+ =======================================================
5
  - Sets up Gradio UI and launches pipeline
6
+ - Aligned with torch/cu121 stack; runs on HF Spaces (Docker)
7
 
8
+ Changes (2025-09-18):
9
+ - Added precise web-stack probes (FastAPI/Starlette/Pydantic/etc. versions + file paths)
10
+ - Added toggleable "mount mode": run Gradio inside our own FastAPI app
11
+ and provide a safe /config route shim (uses demo.get_config_file()).
12
+ - Kept your startup diagnostics, GPU logging, and heartbeats
13
  """
14
 
15
  print("=== APP STARTUP: Initializing VideoBackgroundReplacer2 ===")
16
+
17
+ # ---------------------------------------------------------------------
18
+ # Imports & basic setup
19
+ # ---------------------------------------------------------------------
20
  import sys
21
  import os
22
  import gc
23
+ import json
24
  import logging
25
  import threading
26
  import time
27
  import warnings
28
  import traceback
29
+ import subprocess
30
  from pathlib import Path
31
  from loguru import logger
32
 
33
+ # Logging (loguru to stderr)
34
  logger.remove()
35
  logger.add(
36
  sys.stderr,
37
+ format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> "
38
+ "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
39
  )
40
 
41
+ # Warnings
42
  warnings.filterwarnings("ignore", category=UserWarning)
43
  warnings.filterwarnings("ignore", category=FutureWarning)
44
  warnings.filterwarnings("ignore", module="torchvision.io._video_deprecation_warning")
45
 
46
+ # Environment (lightweight & safe in Spaces)
47
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
48
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
49
 
50
+ # Paths
51
  BASE_DIR = Path(__file__).parent.absolute()
52
  THIRD_PARTY_DIR = BASE_DIR / "third_party"
53
  SAM2_DIR = THIRD_PARTY_DIR / "sam2"
54
  CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
55
 
56
+ # Python path extends
57
+ for p in (str(THIRD_PARTY_DIR), str(SAM2_DIR)):
58
  if p not in sys.path:
59
  sys.path.insert(0, p)
60
 
61
  logger.info(f"Base directory: {BASE_DIR}")
62
+ logger.info(f"Python path[0:5]: {sys.path[:5]}")
63
+
64
+ # ---------------------------------------------------------------------
65
+ # GPU / Torch diagnostics (non-blocking)
66
+ # ---------------------------------------------------------------------
67
+ try:
68
+ import torch
69
+ except Exception as e:
70
+ logger.warning("Torch import failed at startup: %s", e)
71
+ torch = None
72
 
73
+ DEVICE = "cuda" if (torch and torch.cuda.is_available()) else "cpu"
 
74
  if DEVICE == "cuda":
75
  os.environ["SAM2_DEVICE"] = "cuda"
76
  os.environ["MATANY_DEVICE"] = "cuda"
77
+ os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
78
+ try:
79
+ logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
80
+ except Exception:
81
+ logger.info("CUDA device name not available at startup.")
82
  else:
83
  os.environ["SAM2_DEVICE"] = "cpu"
84
  os.environ["MATANY_DEVICE"] = "cpu"
85
  logger.warning("CUDA not available, falling back to CPU")
86
 
 
87
  def verify_models():
88
+ """Verify critical model files exist and are loadable (cheap checks)."""
89
  results = {"status": "success", "details": {}}
 
 
90
  try:
91
  sam2_model_path = os.getenv("SAM2_MODEL_PATH", str(CHECKPOINTS_DIR / "sam2_hiera_large.pt"))
92
  if not os.path.exists(sam2_model_path):
93
  raise FileNotFoundError(f"SAM2 model not found at {sam2_model_path}")
94
+ # Cheap load test (map to CPU to avoid VRAM use during boot)
95
+ if torch:
96
+ sd = torch.load(sam2_model_path, map_location="cpu")
97
+ if not isinstance(sd, dict):
98
+ raise ValueError("Invalid SAM2 checkpoint format")
 
99
  results["details"]["sam2"] = {
100
  "status": "success",
101
  "path": sam2_model_path,
102
+ "size_mb": round(os.path.getsize(sam2_model_path) / (1024 * 1024), 2),
103
  }
104
  except Exception as e:
105
  results["status"] = "error"
106
  results["details"]["sam2"] = {
107
  "status": "error",
108
  "error": str(e),
109
+ "traceback": traceback.format_exc(),
110
  }
 
111
  return results
112
 
 
113
  def run_startup_diagnostics():
 
114
  diag = {
115
  "system": {
116
  "python": sys.version,
117
+ "pytorch": getattr(torch, "__version__", None) if torch else None,
118
+ "cuda_available": bool(torch and torch.cuda.is_available()),
119
+ "device_count": (torch.cuda.device_count() if torch and torch.cuda.is_available() else 0),
120
+ "cuda_version": getattr(getattr(torch, "version", None), "cuda", None) if torch else None,
121
  },
122
  "paths": {
123
  "base_dir": str(BASE_DIR),
124
  "checkpoints_dir": str(CHECKPOINTS_DIR),
125
  "sam2_dir": str(SAM2_DIR),
 
126
  },
127
+ "env_subset": {k: v for k, v in os.environ.items() if k in ("HOST", "PORT", "SPACE_ID", "SPACE_AUTHOR_NAME")},
128
  }
 
 
129
  diag["model_verification"] = verify_models()
 
130
  return diag
131
 
 
132
  startup_diag = run_startup_diagnostics()
133
  logger.info("Startup diagnostics completed")
134
 
135
+ # Noisy heartbeat so logs show life during import time
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def _heartbeat():
137
  i = 0
138
  while True:
 
142
 
143
  threading.Thread(target=_heartbeat, daemon=True).start()
144
 
145
+ # Optional perf tuning import (non-fatal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  try:
147
+ import perf_tuning # noqa: F401
148
  logger.info("perf_tuning imported successfully.")
149
  except Exception as e:
150
  logger.info("perf_tuning not available: %s", e)
151
 
152
+ # MatAnyone non-instantiating probe
153
  try:
154
  import inspect
155
+ from matanyone.inference import inference_core as ic # type: ignore
156
  sigs = {}
157
  for name in ("InferenceCore",):
158
  obj = getattr(ic, name, None)
 
162
  except Exception as e:
163
  logger.info(f"[MATANY] probe skipped: {e}")
164
 
165
+ # ---------------------------------------------------------------------
166
+ # Gradio import and web-stack probes
167
+ # ---------------------------------------------------------------------
168
+ import gradio as gr
169
 
170
+ # Standard logger for some libs that use stdlib logging
171
+ py_logger = logging.getLogger("backgroundfx_pro")
172
+ if not py_logger.handlers:
173
+ h = logging.StreamHandler()
174
+ h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
175
+ py_logger.addHandler(h)
176
+ py_logger.setLevel(logging.INFO)
177
+
178
+ def _log_web_stack_versions_and_paths():
179
+ import inspect
180
  try:
181
+ import fastapi, starlette, pydantic, httpx, anyio
182
  try:
183
+ import pydantic_core
184
+ pc_ver = pydantic_core.__version__
185
+ except Exception:
186
+ pc_ver = "unknown"
187
+ logger.info(
188
+ "[WEB-STACK] fastapi=%s | starlette=%s | pydantic=%s | pydantic-core=%s | httpx=%s | anyio=%s",
189
+ getattr(fastapi, "__version__", "?"),
190
+ getattr(starlette, "__version__", "?"),
191
+ getattr(pydantic, "__version__", "?"),
192
+ pc_ver,
193
+ getattr(httpx, "__version__", "?"),
194
+ getattr(anyio, "__version__", "?"),
195
+ )
196
+ except Exception as e:
197
+ logger.warning("[WEB-STACK] version probe failed: %s", e)
198
+
199
+ try:
200
+ import gradio
201
+ import gradio.routes as gr_routes
202
+ import gradio.queueing as gr_queueing
203
+ logger.info("[PATH] gradio.__file__ = %s", getattr(gradio, "__file__", "?"))
204
+ logger.info("[PATH] gradio.routes = %s", inspect.getfile(gr_routes))
205
+ logger.info("[PATH] gradio.queueing = %s", inspect.getfile(gr_queueing))
206
+ import starlette.exceptions as st_exc
207
+ logger.info("[PATH] starlette.exceptions= %s", inspect.getfile(st_exc))
208
+ except Exception as e:
209
+ logger.warning("[PATH] probe failed: %s", e)
210
+
211
+ def _post_launch_diag():
212
+ try:
213
+ if not torch:
214
+ return
215
+ avail = torch.cuda.is_available()
216
+ logger.info("CUDA available (post-launch): %s", avail)
217
  if avail:
218
+ idx = torch.cuda.current_device()
219
+ name = torch.cuda.get_device_name(idx)
220
+ cap = torch.cuda.get_device_capability(idx)
221
+ logger.info("CUDA device %d: %s (cc %d.%d)", idx, name, cap[0], cap[1])
 
 
 
222
  except Exception as e:
223
+ logger.warning("Post-launch CUDA diag failed: %s", e)
224
 
225
+ # ---------------------------------------------------------------------
226
+ # UI factory (your existing UI builder)
227
+ # ---------------------------------------------------------------------
228
  def build_ui() -> gr.Blocks:
229
+ from ui import create_interface # your module
 
230
  return create_interface()
231
 
232
+ # ---------------------------------------------------------------------
233
+ # Optional: custom FastAPI mount mode
234
+ # Why: if Gradio’s internal FastAPI route schema generation glitches
235
+ # under certain dependency combos, we can run our own FastAPI app,
236
+ # mount Gradio, and provide a known-good /config shim.
237
+ # ---------------------------------------------------------------------
238
+ def build_fastapi_with_gradio(demo: gr.Blocks):
239
+ """
240
+ Returns a FastAPI app with Gradio mounted at root.
241
+ Also exposes JSON health and a config shim using demo.get_config_file().
242
+ """
243
+ from fastapi import FastAPI
244
+ from fastapi.responses import JSONResponse
245
+
246
+ app = FastAPI(title="VideoBackgroundReplacer2")
247
+
248
+ # Simple health
249
+ @app.get("/healthz")
250
+ def _healthz():
251
+ return {"ok": True, "ts": time.time()}
252
+
253
+ # Config shim — this bypasses Gradio's internal /config route
254
+ # and returns the same structure the frontend expects.
255
+ @app.get("/config")
256
+ def _config():
257
+ try:
258
+ cfg = demo.get_config_file() # Gradio builds the JSON-able config dict
259
+ return JSONResponse(content=cfg)
260
+ except Exception as e:
261
+ # If something fails, return explicit JSON (so the frontend won't choke on HTML)
262
+ return JSONResponse(
263
+ status_code=500,
264
+ content={"error": "config_generation_failed", "detail": str(e)},
265
+ )
266
+
267
+ # Mount Gradio UI at root; static assets & index served by Gradio
268
+ app = gr.mount_gradio_app(app, demo, path="/")
269
+ return app
270
+
271
+ # ---------------------------------------------------------------------
272
+ # Entrypoint
273
+ # ---------------------------------------------------------------------
274
  if __name__ == "__main__":
275
  host = os.environ.get("HOST", "0.0.0.0")
276
  port = int(os.environ.get("PORT", "7860"))
277
+ mount_mode = os.environ.get("GRADIO_MOUNT_MODE", "0") == "1"
278
+
279
+ logger.info("Launching on %s:%s (mount_mode=%s)…", host, port, mount_mode)
280
+ _log_web_stack_versions_and_paths()
281
 
282
  demo = build_ui()
283
+ # Good defaults for Spaces
284
+ demo.queue(max_size=16, api_open=False)
285
 
286
  threading.Thread(target=_post_launch_diag, daemon=True).start()
287
+
288
+ if mount_mode:
289
+ # Our own FastAPI + /config shim
290
+ try:
291
+ from uvicorn import run as uvicorn_run
292
+ except Exception:
293
+ logger.error("uvicorn is not installed; mount mode cannot start.")
294
+ raise
295
+
296
+ app = build_fastapi_with_gradio(demo)
297
+ # NOTE: In Docker Spaces, this process is PID1; we call uvicorn.run programmatically.
298
+ uvicorn_run(app=app, host=host, port=port, log_level="info")
299
+ else:
300
+ # Standard Gradio server (uses internal FastAPI app & routes)
301
+ demo.launch(
302
+ server_name=host,
303
+ server_port=port,
304
+ share=False,
305
+ show_api=False, # keep off in Spaces
306
+ show_error=True,
307
+ quiet=False,
308
+ debug=True,
309
+ max_threads=1,
310
+ )