MogensR commited on
Commit
77308ac
·
verified ·
1 Parent(s): 863a27d

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +9 -41
streamlit_app.py CHANGED
@@ -1,54 +1,41 @@
1
  #!/usr/bin/env python3
2
  """
3
- Core Application Logic for Video Background Replacer
4
- - Handles video processing pipeline (SAM2 + MatAnyone + FFmpeg)
5
- - Integrates with UI (imported from ui.py)
6
- - Manages session state and logging
7
  """
8
  import os
9
  import sys
10
  import time
11
  from pathlib import Path
12
  import logging
13
- import logging.handlers
14
  import traceback
15
  import uuid
16
- from datetime import datetime
17
  from tempfile import NamedTemporaryFile
18
  import streamlit as st
19
- import torch
20
 
21
- # Import UI components
22
  from ui import render_ui
23
- # Import pipeline functions
24
  from pipeline.video_pipeline import (
25
  stage1_create_transparent_video,
26
  stage2_composite_background,
27
  setup_t4_environment,
28
  check_gpu
29
  )
30
- # Import model loaders (the new robust way)
31
  from models.model_loaders import load_sam2, load_matanyone
32
 
33
- # --- Constants ---
34
  APP_NAME = "Advanced Video Background Replacer"
35
  LOG_FILE = "/tmp/app.log"
36
  LOG_MAX_BYTES = 5 * 1024 * 1024
37
  LOG_BACKUPS = 5
38
 
39
- # --- Logging Setup ---
40
  def setup_logging(level: int = logging.INFO) -> logging.Logger:
41
  logger = logging.getLogger(APP_NAME)
42
  logger.setLevel(level)
43
  logger.propagate = False
44
- # Clear previous handlers on rerun
45
  for h in list(logger.handlers):
46
  logger.removeHandler(h)
47
- # Console handler
48
  ch = logging.StreamHandler(sys.stdout)
49
  ch.setLevel(level)
50
  ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
51
- # Rotating file handler
52
  fh = logging.handlers.RotatingFileHandler(
53
  LOG_FILE, maxBytes=LOG_MAX_BYTES, backupCount=LOG_BACKUPS, encoding="utf-8"
54
  )
@@ -60,16 +47,14 @@ def setup_logging(level: int = logging.INFO) -> logging.Logger:
60
 
61
  logger = setup_logging()
62
 
63
- # --- Global Exception Hook ---
64
  def custom_excepthook(type, value, tb):
65
  logger.error(f"Unhandled: {type.__name__}: {value}\n{''.join(traceback.format_tb(tb))}", exc_info=True)
66
  sys.excepthook = custom_excepthook
67
 
68
- # --- Load models ONCE globally (never per session) ---
69
  sam2_predictor = load_sam2()
70
  matanyone_processor = load_matanyone()
71
 
72
- # --- Session State Initialization ---
73
  def initialize_session_state():
74
  defaults = {
75
  'uploaded_video': None,
@@ -88,7 +73,8 @@ def initialize_session_state():
88
  'last_error': None,
89
  'log_level_name': 'INFO',
90
  'auto_refresh_logs': False,
91
- 'log_tail_lines': 400
 
92
  }
93
  for k, v in defaults.items():
94
  if k not in st.session_state:
@@ -96,38 +82,23 @@ def initialize_session_state():
96
  if st.session_state.gpu_available is None:
97
  st.session_state.gpu_available = check_gpu(logger)
98
 
99
- # --- Set Log Level ---
100
- def set_log_level(name: str):
101
- name = (name or "INFO").upper()
102
- lvl = getattr(logging, name, logging.INFO)
103
- setup_logging(lvl)
104
- global logger
105
- logger = logging.getLogger(APP_NAME)
106
- logger.setLevel(lvl)
107
- st.session_state.log_level_name = name
108
- logger.info(f"Log level set to {name}")
109
-
110
- # --- Main Processing Function ---
111
- def process_video(uploaded_video, bg_image, bg_color, bg_type):
112
  run_id = uuid.uuid4().hex[:8]
113
  logger.info("=" * 80)
114
- logger.info(f"[RUN {run_id}] VIDEO PROCESSING STARTED at {datetime.utcnow().isoformat()}Z")
115
  logger.info(f"[RUN {run_id}] Video size={len(uploaded_video.read()) / 1e6:.2f}MB, BG type={bg_type}")
116
- uploaded_video.seek(0) # Reset for later read
117
- logger.info("=" * 80)
118
  st.session_state.processing = True
119
  st.session_state.processed_video_bytes = None
120
  st.session_state.last_error = None
121
  t0 = time.time()
122
  try:
123
- # Materialize uploaded video to temp file
124
  suffix = Path(uploaded_video.name).suffix or ".mp4"
125
  with NamedTemporaryFile(delete=False, suffix=suffix) as tmp_vid:
126
  uploaded_video.seek(0)
127
  tmp_vid.write(uploaded_video.read())
128
  tmp_vid_path = tmp_vid.name
129
  logger.info(f"[RUN {run_id}] Temporary video path: {tmp_vid_path}")
130
- # Stage 1: Create transparent video and extract audio
131
  transparent_path, audio_path = stage1_create_transparent_video(
132
  tmp_vid_path,
133
  sam2_predictor=sam2_predictor,
@@ -136,8 +107,7 @@ def process_video(uploaded_video, bg_image, bg_color, bg_type):
136
  if not transparent_path or not os.path.exists(transparent_path):
137
  raise RuntimeError("Stage 1 failed: Transparent video not created")
138
  logger.info(f"[RUN {run_id}] Stage 1 completed: Transparent path={transparent_path}, Audio path={audio_path}")
139
- # Stage 2: Composite with background and restore audio
140
- background = bg_image.convert("RGB") if bg_type == "Image" else bg_color
141
  final_path = stage2_composite_background(
142
  transparent_path,
143
  audio_path,
@@ -147,7 +117,6 @@ def process_video(uploaded_video, bg_image, bg_color, bg_type):
147
  if not final_path or not os.path.exists(final_path):
148
  raise RuntimeError("Stage 2 failed: Final video not created")
149
  logger.info(f"[RUN {run_id}] Stage 2 completed: Final path={final_path}")
150
- # Load final video into memory for download
151
  with open(final_path, 'rb') as f:
152
  st.session_state.processed_video_bytes = f.read()
153
  total = time.time() - t0
@@ -164,7 +133,6 @@ def process_video(uploaded_video, bg_image, bg_color, bg_type):
164
  st.session_state.processing = False
165
  logger.info(f"[RUN {run_id}] Processing finished")
166
 
167
- # --- Main App Entry Point ---
168
  def main():
169
  try:
170
  st.set_page_config(
 
1
  #!/usr/bin/env python3
2
  """
3
+ Advanced Video Background Replacer - Streamlit Entrypoint
 
 
 
4
  """
5
  import os
6
  import sys
7
  import time
8
  from pathlib import Path
9
  import logging
 
10
  import traceback
11
  import uuid
 
12
  from tempfile import NamedTemporaryFile
13
  import streamlit as st
 
14
 
 
15
  from ui import render_ui
 
16
  from pipeline.video_pipeline import (
17
  stage1_create_transparent_video,
18
  stage2_composite_background,
19
  setup_t4_environment,
20
  check_gpu
21
  )
 
22
  from models.model_loaders import load_sam2, load_matanyone
23
 
 
24
  APP_NAME = "Advanced Video Background Replacer"
25
  LOG_FILE = "/tmp/app.log"
26
  LOG_MAX_BYTES = 5 * 1024 * 1024
27
  LOG_BACKUPS = 5
28
 
 
29
  def setup_logging(level: int = logging.INFO) -> logging.Logger:
30
  logger = logging.getLogger(APP_NAME)
31
  logger.setLevel(level)
32
  logger.propagate = False
33
+ # Remove previous handlers (Streamlit reruns)
34
  for h in list(logger.handlers):
35
  logger.removeHandler(h)
 
36
  ch = logging.StreamHandler(sys.stdout)
37
  ch.setLevel(level)
38
  ch.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
 
39
  fh = logging.handlers.RotatingFileHandler(
40
  LOG_FILE, maxBytes=LOG_MAX_BYTES, backupCount=LOG_BACKUPS, encoding="utf-8"
41
  )
 
47
 
48
  logger = setup_logging()
49
 
 
50
  def custom_excepthook(type, value, tb):
51
  logger.error(f"Unhandled: {type.__name__}: {value}\n{''.join(traceback.format_tb(tb))}", exc_info=True)
52
  sys.excepthook = custom_excepthook
53
 
54
+ # Only load once
55
  sam2_predictor = load_sam2()
56
  matanyone_processor = load_matanyone()
57
 
 
58
  def initialize_session_state():
59
  defaults = {
60
  'uploaded_video': None,
 
73
  'last_error': None,
74
  'log_level_name': 'INFO',
75
  'auto_refresh_logs': False,
76
+ 'log_tail_lines': 400,
77
+ 'generated_bg': None,
78
  }
79
  for k, v in defaults.items():
80
  if k not in st.session_state:
 
82
  if st.session_state.gpu_available is None:
83
  st.session_state.gpu_available = check_gpu(logger)
84
 
85
+ def process_video(uploaded_video, background, bg_type):
 
 
 
 
 
 
 
 
 
 
 
 
86
  run_id = uuid.uuid4().hex[:8]
87
  logger.info("=" * 80)
88
+ logger.info(f"[RUN {run_id}] VIDEO PROCESSING STARTED at {time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())}")
89
  logger.info(f"[RUN {run_id}] Video size={len(uploaded_video.read()) / 1e6:.2f}MB, BG type={bg_type}")
90
+ uploaded_video.seek(0)
 
91
  st.session_state.processing = True
92
  st.session_state.processed_video_bytes = None
93
  st.session_state.last_error = None
94
  t0 = time.time()
95
  try:
 
96
  suffix = Path(uploaded_video.name).suffix or ".mp4"
97
  with NamedTemporaryFile(delete=False, suffix=suffix) as tmp_vid:
98
  uploaded_video.seek(0)
99
  tmp_vid.write(uploaded_video.read())
100
  tmp_vid_path = tmp_vid.name
101
  logger.info(f"[RUN {run_id}] Temporary video path: {tmp_vid_path}")
 
102
  transparent_path, audio_path = stage1_create_transparent_video(
103
  tmp_vid_path,
104
  sam2_predictor=sam2_predictor,
 
107
  if not transparent_path or not os.path.exists(transparent_path):
108
  raise RuntimeError("Stage 1 failed: Transparent video not created")
109
  logger.info(f"[RUN {run_id}] Stage 1 completed: Transparent path={transparent_path}, Audio path={audio_path}")
110
+ # For "color", background is "#RRGGBB", for others it's PIL.Image
 
111
  final_path = stage2_composite_background(
112
  transparent_path,
113
  audio_path,
 
117
  if not final_path or not os.path.exists(final_path):
118
  raise RuntimeError("Stage 2 failed: Final video not created")
119
  logger.info(f"[RUN {run_id}] Stage 2 completed: Final path={final_path}")
 
120
  with open(final_path, 'rb') as f:
121
  st.session_state.processed_video_bytes = f.read()
122
  total = time.time() - t0
 
133
  st.session_state.processing = False
134
  logger.info(f"[RUN {run_id}] Processing finished")
135
 
 
136
  def main():
137
  try:
138
  st.set_page_config(