MogensR commited on
Commit
d465483
·
verified ·
1 Parent(s): 8631135

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +28 -7
streamlit_app.py CHANGED
@@ -15,12 +15,9 @@
15
  import uuid
16
  from datetime import datetime
17
  from tempfile import NamedTemporaryFile
18
- import subprocess
19
  import streamlit as st
20
- import cv2
21
- import numpy as np
22
- from PIL import Image
23
  import torch
 
24
  # Import UI components
25
  from ui import render_ui
26
  # Import pipeline functions
@@ -30,11 +27,15 @@
30
  setup_t4_environment,
31
  check_gpu
32
  )
 
 
 
33
  # --- Constants ---
34
  APP_NAME = "Advanced Video Background Replacer"
35
  LOG_FILE = "/tmp/app.log"
36
  LOG_MAX_BYTES = 5 * 1024 * 1024
37
  LOG_BACKUPS = 5
 
38
  # --- Logging Setup ---
39
  def setup_logging(level: int = logging.INFO) -> logging.Logger:
40
  logger = logging.getLogger(APP_NAME)
@@ -56,11 +57,18 @@ def setup_logging(level: int = logging.INFO) -> logging.Logger:
56
  logger.addHandler(ch)
57
  logger.addHandler(fh)
58
  return logger
 
59
  logger = setup_logging()
 
60
  # --- Global Exception Hook ---
61
  def custom_excepthook(type, value, tb):
62
  logger.error(f"Unhandled: {type.__name__}: {value}\n{''.join(traceback.format_tb(tb))}", exc_info=True)
63
  sys.excepthook = custom_excepthook
 
 
 
 
 
64
  # --- Session State Initialization ---
65
  def initialize_session_state():
66
  defaults = {
@@ -87,6 +95,7 @@ def initialize_session_state():
87
  st.session_state[k] = v
88
  if st.session_state.gpu_available is None:
89
  st.session_state.gpu_available = check_gpu(logger)
 
90
  # --- Set Log Level ---
91
  def set_log_level(name: str):
92
  name = (name or "INFO").upper()
@@ -97,6 +106,7 @@ def set_log_level(name: str):
97
  logger.setLevel(lvl)
98
  st.session_state.log_level_name = name
99
  logger.info(f"Log level set to {name}")
 
100
  # --- Main Processing Function ---
101
  def process_video(uploaded_video, bg_image, bg_color, bg_type):
102
  run_id = uuid.uuid4().hex[:8]
@@ -118,13 +128,22 @@ def process_video(uploaded_video, bg_image, bg_color, bg_type):
118
  tmp_vid_path = tmp_vid.name
119
  logger.info(f"[RUN {run_id}] Temporary video path: {tmp_vid_path}")
120
  # Stage 1: Create transparent video and extract audio
121
- transparent_path, audio_path = stage1_create_transparent_video(tmp_vid_path)
 
 
 
 
122
  if not transparent_path or not os.path.exists(transparent_path):
123
  raise RuntimeError("Stage 1 failed: Transparent video not created")
124
  logger.info(f"[RUN {run_id}] Stage 1 completed: Transparent path={transparent_path}, Audio path={audio_path}")
125
  # Stage 2: Composite with background and restore audio
126
  background = bg_image.convert("RGB") if bg_type == "Image" else bg_color
127
- final_path = stage2_composite_background(transparent_path, audio_path, background, bg_type.lower())
 
 
 
 
 
128
  if not final_path or not os.path.exists(final_path):
129
  raise RuntimeError("Stage 2 failed: Final video not created")
130
  logger.info(f"[RUN {run_id}] Stage 2 completed: Final path={final_path}")
@@ -144,6 +163,7 @@ def process_video(uploaded_video, bg_image, bg_color, bg_type):
144
  finally:
145
  st.session_state.processing = False
146
  logger.info(f"[RUN {run_id}] Processing finished")
 
147
  # --- Main App Entry Point ---
148
  def main():
149
  try:
@@ -158,6 +178,7 @@ def main():
158
  except Exception as e:
159
  logger.error(f"Main app error: {e}", exc_info=True)
160
  st.error(f"App startup failed: {str(e)}. Check logs for details.")
 
161
  if __name__ == "__main__":
162
  setup_t4_environment()
163
- main()
 
15
  import uuid
16
  from datetime import datetime
17
  from tempfile import NamedTemporaryFile
 
18
  import streamlit as st
 
 
 
19
  import torch
20
+
21
  # Import UI components
22
  from ui import render_ui
23
  # Import pipeline functions
 
27
  setup_t4_environment,
28
  check_gpu
29
  )
30
+ # Import model loaders (the new robust way)
31
+ from models.model_loaders import load_sam2, load_matanyone
32
+
33
  # --- Constants ---
34
  APP_NAME = "Advanced Video Background Replacer"
35
  LOG_FILE = "/tmp/app.log"
36
  LOG_MAX_BYTES = 5 * 1024 * 1024
37
  LOG_BACKUPS = 5
38
+
39
  # --- Logging Setup ---
40
  def setup_logging(level: int = logging.INFO) -> logging.Logger:
41
  logger = logging.getLogger(APP_NAME)
 
57
  logger.addHandler(ch)
58
  logger.addHandler(fh)
59
  return logger
60
+
61
  logger = setup_logging()
62
+
63
  # --- Global Exception Hook ---
64
  def custom_excepthook(type, value, tb):
65
  logger.error(f"Unhandled: {type.__name__}: {value}\n{''.join(traceback.format_tb(tb))}", exc_info=True)
66
  sys.excepthook = custom_excepthook
67
+
68
+ # --- Load models ONCE globally (never per session) ---
69
+ sam2_predictor = load_sam2()
70
+ matanyone_processor = load_matanyone()
71
+
72
  # --- Session State Initialization ---
73
  def initialize_session_state():
74
  defaults = {
 
95
  st.session_state[k] = v
96
  if st.session_state.gpu_available is None:
97
  st.session_state.gpu_available = check_gpu(logger)
98
+
99
  # --- Set Log Level ---
100
  def set_log_level(name: str):
101
  name = (name or "INFO").upper()
 
106
  logger.setLevel(lvl)
107
  st.session_state.log_level_name = name
108
  logger.info(f"Log level set to {name}")
109
+
110
  # --- Main Processing Function ---
111
  def process_video(uploaded_video, bg_image, bg_color, bg_type):
112
  run_id = uuid.uuid4().hex[:8]
 
128
  tmp_vid_path = tmp_vid.name
129
  logger.info(f"[RUN {run_id}] Temporary video path: {tmp_vid_path}")
130
  # Stage 1: Create transparent video and extract audio
131
+ transparent_path, audio_path = stage1_create_transparent_video(
132
+ tmp_vid_path,
133
+ sam2_predictor=sam2_predictor,
134
+ matanyone_processor=matanyone_processor
135
+ )
136
  if not transparent_path or not os.path.exists(transparent_path):
137
  raise RuntimeError("Stage 1 failed: Transparent video not created")
138
  logger.info(f"[RUN {run_id}] Stage 1 completed: Transparent path={transparent_path}, Audio path={audio_path}")
139
  # Stage 2: Composite with background and restore audio
140
  background = bg_image.convert("RGB") if bg_type == "Image" else bg_color
141
+ final_path = stage2_composite_background(
142
+ transparent_path,
143
+ audio_path,
144
+ background,
145
+ bg_type.lower()
146
+ )
147
  if not final_path or not os.path.exists(final_path):
148
  raise RuntimeError("Stage 2 failed: Final video not created")
149
  logger.info(f"[RUN {run_id}] Stage 2 completed: Final path={final_path}")
 
163
  finally:
164
  st.session_state.processing = False
165
  logger.info(f"[RUN {run_id}] Processing finished")
166
+
167
  # --- Main App Entry Point ---
168
  def main():
169
  try:
 
178
  except Exception as e:
179
  logger.error(f"Main app error: {e}", exc_info=True)
180
  st.error(f"App startup failed: {str(e)}. Check logs for details.")
181
+
182
  if __name__ == "__main__":
183
  setup_t4_environment()
184
+ main()