MogensR committed on
Commit
77bba4f
Β·
verified Β·
1 Parent(s): c933bf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -534
app.py CHANGED
@@ -1,28 +1,22 @@
 
 
 
 
 
 
1
  import streamlit as st
2
- import os
3
  import sys
4
- import tempfile
5
- import time
6
- import shutil
7
- import gc
8
  from pathlib import Path
9
- import cv2
10
- import numpy as np
11
  from PIL import Image
12
- import logging
13
- import base64
14
- from io import BytesIO
15
- import torch
16
- from contextlib import contextmanager
17
 
18
  # Add project root to path
19
  sys.path.append(str(Path(__file__).parent.absolute()))
20
 
21
- # Configure logging
22
- logging.basicConfig(level=logging.INFO)
23
- logger = logging.getLogger(__name__)
24
 
25
- # Persistent temp dir (survives beyond TemporaryDirectory scope)
26
  TMP_DIR = Path("tmp")
27
  TMP_DIR.mkdir(parents=True, exist_ok=True)
28
 
@@ -34,403 +28,7 @@
34
  initial_sidebar_state="expanded"
35
  )
36
 
37
- # Memory management utilities
38
class _TorchMemoryGuard:
    """Context-manager object that reclaims CUDA cache and runs GC on exit."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Cleanup runs whether or not the guarded code raised.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return False  # never suppress exceptions


def torch_memory_manager():
    """Context manager for CUDA memory cleanup."""
    return _TorchMemoryGuard()
48
def clear_model_cache():
    """Clear all cached models and free memory."""
    # Older Streamlit builds may lack cache_resource; probe before clearing.
    cache = getattr(st, 'cache_resource', None)
    if cache is not None:
        cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("Model cache cleared")
56
-
57
def get_memory_usage():
    """Get current memory usage statistics.

    Returns:
        dict: values in GB. Contains 'gpu_allocated', 'gpu_reserved' and
        'gpu_free' only when CUDA is available, and 'ram_used' /
        'ram_available' only when psutil is importable. Callers should
        read keys with .get() and a default.
    """
    memory_info = {}
    if torch.cuda.is_available():
        # Read the allocator once so allocated/free are a consistent pair.
        allocated = torch.cuda.memory_allocated()
        memory_info['gpu_allocated'] = allocated / 1e9
        memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
        # Free = total memory on device 0 minus this process's allocation.
        total = torch.cuda.get_device_properties(0).total_memory
        memory_info['gpu_free'] = (total - allocated) / 1e9
    try:
        import psutil
    except ImportError:
        # RAM stats are best-effort; without psutil just omit them instead
        # of crashing every caller of this helper.
        return memory_info
    # Single snapshot keeps 'used' and 'available' mutually consistent.
    vm = psutil.virtual_memory()
    memory_info['ram_used'] = vm.used / 1e9
    memory_info['ram_available'] = vm.available / 1e9
    return memory_info
69
-
70
- # Lazy model loading
71
@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Lazy load SAM2 image predictor only when needed.

    Prefers local checkpoints; falls back to the Hugging Face hub, and to
    progressively smaller hub models when free GPU memory is below 4 GB.

    Returns:
        A SAM2ImagePredictor instance, or None on failure (the error is
        also surfaced in the Streamlit UI).
    """
    try:
        logger.info("Loading SAM2 image predictor...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"

        if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
            logger.warning("Local checkpoints not found, using Hugging Face...")
            predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
        else:
            memory_info = get_memory_usage()
            if memory_info.get('gpu_free', 0) < 4.0:
                logger.warning("Limited GPU memory, using smaller SAM2 model...")
                try:
                    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
                except Exception:
                    # Narrowed from a bare except: only swallow real load
                    # failures, not KeyboardInterrupt/SystemExit.
                    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
            else:
                predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint_path))

        logger.info("βœ… SAM2 image predictor loaded successfully!")
        return predictor
    except Exception as e:
        logger.error(f"Failed to load SAM2 predictor: {e}")
        st.error(f"❌ Failed to load SAM2: {e}")
        return None
102
-
103
@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Lazy load MatAnyone processor only when needed."""
    try:
        logger.info("Loading MatAnyone processor...")
        # Deferred import: matanyone is only pulled in on first use.
        from matanyone import InferenceCore
        processor = InferenceCore("PeiqingYang/MatAnyone")
        logger.info("βœ… MatAnyone processor loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load MatAnyone: {e}")
        st.error(f"❌ Failed to load MatAnyone: {e}")
        return None
    return processor
116
-
117
def generate_mask_from_video_first_frame(video_path, sam2_predictor):
    """Generate mask for the first frame of video using SAM2.

    Args:
        video_path: path to a video readable by OpenCV.
        sam2_predictor: a loaded SAM2 image predictor.

    Returns:
        uint8 mask array with values 0/255, or None on failure
        (errors are surfaced via st.error).
    """
    try:
        with torch_memory_manager():
            # Only the first frame is needed; its mask seeds the matting stage.
            cap = cv2.VideoCapture(video_path)
            ret, frame = cap.read()
            cap.release()

            if not ret:
                st.error("Failed to read video frame")
                return None

            # Resize frame if too large to save memory
            h, w = frame.shape[:2]
            max_size = 1080
            if max(h, w) > max_size:
                scale = max_size / max(h, w)
                new_w, new_h = int(w * scale), int(h * scale)
                frame = cv2.resize(frame, (new_w, new_h))
                logger.info(f"Resized frame from {w}x{h} to {new_w}x{new_h}")

            # OpenCV decodes BGR; SAM2's set_image is fed RGB here.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # bfloat16 autocast reduces activation memory during inference.
            # NOTE(review): autocast("cuda") assumes a CUDA device is present
            # — confirm the CPU-only fallback path.
            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
                sam2_predictor.set_image(frame_rgb)

                # Get center point as default prompt; label 1 marks foreground.
                h, w = frame_rgb.shape[:2]
                center_point = np.array([[w//2, h//2]], dtype=np.float32)
                center_label = np.array([1], dtype=np.int32)

                masks, scores, logits = sam2_predictor.predict(
                    point_coords=center_point,
                    point_labels=center_label,
                    multimask_output=True
                )

                # Keep only the highest-scoring proposal; scale 0/1 -> 0/255.
                best_mask = masks[np.argmax(scores)]
                return best_mask.astype(np.uint8) * 255

    except Exception as e:
        st.error(f"Failed to generate mask: {e}")
        return None
160
-
161
def stage1_create_transparent_video(input_file):
    """STAGE 1: Create transparent video using SAM2 + MatAnyone.

    Args:
        input_file: uploaded file object; its bytes are written to a
            temporary .mp4 before processing.

    Returns:
        str path to the persisted transparent .mov under TMP_DIR, or None
        on any failure (errors are shown via st.error).
    """

    logger.info("Starting Stage 1: Create transparent video")

    # Proactively free cached models when GPU headroom is under 2 GB.
    memory_info = get_memory_usage()
    if memory_info.get('gpu_free', 0) < 2.0:
        st.warning("⚠️ Low GPU memory detected. Processing may be slower.")
        clear_model_cache()

    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(progress, message):
            # Clamp to [0, 1]: st.progress rejects out-of-range values.
            progress = max(0, min(1, progress))
            progress_bar.progress(progress)
            status_text.text(f"Stage 1: {message} | GPU: {get_memory_usage().get('gpu_allocated', 0):.1f}GB")
            logger.info(f"Stage 1 Progress: {progress:.2f} - {message}")

        # Load models (each loader returns None on failure and reports to UI).
        update_progress(0.05, "Loading SAM2 model...")
        logger.info("Attempting to load SAM2 predictor...")
        sam2_predictor = load_sam2_predictor()
        if sam2_predictor is None:
            logger.error("SAM2 predictor failed to load")
            st.error("❌ Failed to load SAM2 model")
            return None
        logger.info("SAM2 predictor loaded successfully")

        update_progress(0.1, "Loading MatAnyone model...")
        logger.info("Attempting to load MatAnyone processor...")
        matanyone_processor = load_matanyone_processor()
        if matanyone_processor is None:
            logger.error("MatAnyone processor failed to load")
            st.error("❌ Failed to load MatAnyone model")
            return None
        logger.info("MatAnyone processor loaded successfully")

        # Process video to create transparent version; intermediates live in
        # a throwaway temp dir, only the final .mov is copied out.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)
            input_path = str(temp_dir / "input.mp4")

            # Save input video
            with open(input_path, "wb") as f:
                f.write(input_file.getvalue())

            update_progress(0.2, "Generating segmentation mask...")

            # Generate mask using SAM2 on the first frame only.
            with torch_memory_manager():
                mask = generate_mask_from_video_first_frame(input_path, sam2_predictor)

            if mask is None:
                return None

            mask_path = str(temp_dir / "mask.png")
            cv2.imwrite(mask_path, mask)

            update_progress(0.4, "Creating transparent video with MatAnyone...")

            # Process with MatAnyone to get foreground and alpha streams.
            try:
                with torch_memory_manager():
                    foreground_path, alpha_path = matanyone_processor.process_video(
                        input_path=input_path,
                        mask_path=mask_path,
                        output_path=str(temp_dir),
                        max_size=720  # Limit resolution for memory efficiency
                    )

                update_progress(0.8, "Creating transparent .mov file...")

                # Merge foreground + alpha into a .mov with an alpha channel.
                transparent_path = create_transparent_mov(foreground_path, alpha_path, temp_dir)

                if transparent_path and os.path.exists(transparent_path):
                    # Copy to persistent location so it outlives the temp dir.
                    persist_path = TMP_DIR / "transparent_video.mov"
                    shutil.copyfile(transparent_path, persist_path)

                    update_progress(1.0, "Transparent video created!")
                    time.sleep(0.5)  # let the user see the completed bar
                    return str(persist_path)
                else:
                    st.error("Failed to create transparent video")
                    return None

            except Exception as e:
                st.error(f"MatAnyone processing failed: {e}")
                return None

    except Exception as e:
        logger.error(f"Error in Stage 1 processing: {str(e)}", exc_info=True)
        st.error(f"❌ Stage 1 failed: {str(e)}")

        # Show additional debug info; best-effort, never mask the real error.
        try:
            memory_info = get_memory_usage()
            st.info(f"Memory at failure - GPU: {memory_info.get('gpu_allocated', 0):.1f}GB, RAM: {memory_info.get('ram_used', 0):.1f}GB")
        except:
            pass

        return None
    finally:
        # Always release GPU cache and collect garbage, success or failure.
        logger.info("Stage 1 cleanup starting...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("Stage 1 cleanup completed")
272
-
273
def create_transparent_mov(foreground_path, alpha_path, temp_dir):
    """Create a .mov file with alpha channel from foreground and alpha videos.

    Args:
        foreground_path: path to the foreground (color) video.
        alpha_path: path to the matching alpha-matte video.
        temp_dir: Path-like directory for the output file.

    Returns:
        Path string of the written .mov, or None on failure.
    """
    try:
        output_path = str(temp_dir / "transparent.mov")

        # Read both streams in lockstep, frame by frame.
        fg_cap = cv2.VideoCapture(foreground_path)
        alpha_cap = cv2.VideoCapture(alpha_path)

        # Get video properties (fall back to 30 fps if metadata is missing).
        fps = int(fg_cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(fg_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(fg_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Use PNG codec for alpha channel support.
        # NOTE(review): 'png ' fourcc support in VideoWriter is
        # build-dependent — confirm the output actually keeps alpha.
        fourcc = cv2.VideoWriter_fourcc(*'png ')  # PNG codec supports alpha
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), True)

        frame_count = 0
        while True:
            ret_fg, fg_frame = fg_cap.read()
            ret_alpha, alpha_frame = alpha_cap.read()

            # Stop at the end of the shorter stream.
            if not ret_fg or not ret_alpha:
                break

            # Convert alpha to single channel if needed
            if len(alpha_frame.shape) == 3:
                alpha_frame = cv2.cvtColor(alpha_frame, cv2.COLOR_BGR2GRAY)

            # Assemble a 4-channel frame: color planes + alpha plane.
            # NOTE(review): fg_frame comes from cv2.read() and is BGR, yet
            # it is treated as RGB here and swapped by RGBA2BGRA below —
            # verify channel order in the written file.
            rgba_frame = np.zeros((height, width, 4), dtype=np.uint8)
            rgba_frame[:, :, :3] = fg_frame  # color channels (BGR from cv2)
            rgba_frame[:, :, 3] = alpha_frame  # Alpha channel

            # Convert RGBA to BGRA for OpenCV
            bgra_frame = cv2.cvtColor(rgba_frame, cv2.COLOR_RGBA2BGRA)
            out.write(bgra_frame)

            frame_count += 1
            # Periodic GC keeps peak memory down on long clips.
            if frame_count % 10 == 0:
                gc.collect()

        fg_cap.release()
        alpha_cap.release()
        out.release()

        return output_path if os.path.exists(output_path) else None

    except Exception as e:
        logger.error(f"Failed to create transparent MOV: {e}")
        return None
325
-
326
def stage2_composite_background(transparent_video_path, background, bg_type):
    """STAGE 2: Composite transparent video with new background.

    Args:
        transparent_video_path: path to the Stage-1 transparent video.
        background: PIL image when bg_type == "image"; otherwise unused
            here (color is read from st.session_state.bg_color).
        bg_type: "image" or "color"; anything else yields a green screen.

    Returns:
        str path to the persisted final .mp4 under TMP_DIR, or None on
        failure (errors are shown via st.error).
    """

    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(progress, message):
            # Clamp to [0, 1] before handing to st.progress.
            progress = max(0, min(1, progress))
            progress_bar.progress(progress)
            status_text.text(f"Stage 2: {message}")

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)

            update_progress(0.2, "Loading transparent video...")

            # Read transparent video (fall back to 30 fps if metadata missing).
            cap = cv2.VideoCapture(transparent_video_path)
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Prepare background once; reused for every frame.
            update_progress(0.4, "Preparing background...")

            if bg_type == "image" and background is not None:
                bg_array = np.array(background)
                # PIL gives RGB(A); convert to OpenCV's BGR ordering.
                if len(bg_array.shape) == 3 and bg_array.shape[2] == 3:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_RGB2BGR)
                elif len(bg_array.shape) == 3 and bg_array.shape[2] == 4:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_RGBA2BGR)
                bg_resized = cv2.resize(bg_array, (width, height))
            elif bg_type == "color":
                # Parse "#RRGGBB" into a solid BGR frame.
                color_hex = st.session_state.bg_color.lstrip('#')
                r = int(color_hex[0:2], 16)
                g = int(color_hex[2:4], 16)
                b = int(color_hex[4:6], 16)
                bg_resized = np.full((height, width, 3), (b, g, r), dtype=np.uint8)
            else:
                # Unknown type: default to a green screen.
                bg_resized = np.full((height, width, 3), (0, 255, 0), dtype=np.uint8)

            # Create output video
            output_path = str(temp_dir / "final_output.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            update_progress(0.6, "Compositing frames...")

            frame_count = 0
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Extract alpha channel if present (BGRA format)
                if frame.shape[2] == 4:
                    bgr_frame = frame[:, :, :3]
                    alpha_channel = frame[:, :, 3]
                else:
                    # Fallback: assume full opacity
                    bgr_frame = frame
                    alpha_channel = np.full((height, width), 255, dtype=np.uint8)

                # Normalize alpha to 0-1 and broadcast over the color channels.
                alpha_norm = alpha_channel.astype(np.float32) / 255.0
                alpha_norm = np.expand_dims(alpha_norm, axis=2)

                # Standard alpha blend: result = fg * alpha + bg * (1 - alpha)
                fg_float = bgr_frame.astype(np.float32)
                bg_float = bg_resized.astype(np.float32)

                result = fg_float * alpha_norm + bg_float * (1 - alpha_norm)
                result = result.astype(np.uint8)

                out.write(result)
                frame_count += 1

                # Update progress every 5 frames to avoid UI churn.
                if total_frames > 0 and frame_count % 5 == 0:
                    progress = 0.6 + 0.3 * (frame_count / total_frames)
                    update_progress(progress, f"Compositing frame {frame_count}/{total_frames}")

                # Periodic GC keeps peak memory down on long clips.
                if frame_count % 10 == 0:
                    gc.collect()

            cap.release()
            out.release()

            if os.path.exists(output_path):
                # Copy to persistent location so it outlives the temp dir.
                persist_path = TMP_DIR / "final_video.mp4"
                shutil.copyfile(output_path, persist_path)

                update_progress(1.0, "Compositing complete!")
                time.sleep(0.5)  # let the user see the completed bar
                return str(persist_path)
            else:
                return None

    except Exception as e:
        logger.error(f"Error in Stage 2 compositing: {str(e)}", exc_info=True)
        st.error(f"Stage 2 failed: {str(e)}")
        return None
432
-
433
- # UI Functions (simplified for two-stage approach)
434
  def add_logo():
435
  st.markdown(
436
  """
@@ -446,6 +44,7 @@ def add_logo():
446
  )
447
 
448
  def show_memory_info():
 
449
  memory_info = get_memory_usage()
450
  with st.sidebar:
451
  st.markdown("### 🧠 Memory Usage")
@@ -455,21 +54,14 @@ def show_memory_info():
455
  st.metric("RAM Usage", f"{memory_info['ram_used']:.1f}GB",
456
  f"Available: {memory_info['ram_available']:.1f}GB")
457
 
458
- # Test model loading
459
  if st.button("πŸ§ͺ Test Models", help="Test if SAM2 and MatAnyone can load"):
460
  with st.spinner("Testing model loading..."):
461
  try:
462
  sam2_test = load_sam2_predictor()
463
- if sam2_test:
464
- st.success("βœ… SAM2 loads successfully")
465
- else:
466
- st.error("❌ SAM2 failed to load")
467
 
468
  matanyone_test = load_matanyone_processor()
469
- if matanyone_test:
470
- st.success("βœ… MatAnyone loads successfully")
471
- else:
472
- st.error("❌ MatAnyone failed to load")
473
  except Exception as e:
474
  st.error(f"Model test failed: {e}")
475
 
@@ -479,26 +71,24 @@ def show_memory_info():
479
  st.experimental_rerun()
480
 
481
  def initialize_session_state():
482
- if 'uploaded_video' not in st.session_state:
483
- st.session_state.uploaded_video = None
484
- if 'bg_image' not in st.session_state:
485
- st.session_state.bg_image = None
486
- if 'bg_image_info' not in st.session_state:
487
- st.session_state.bg_image_info = None
488
- if 'bg_color' not in st.session_state:
489
- st.session_state.bg_color = "#00FF00"
490
- if 'bg_type' not in st.session_state:
491
- st.session_state.bg_type = "image"
492
- if 'transparent_video_path' not in st.session_state:
493
- st.session_state.transparent_video_path = None
494
- if 'final_video_path' not in st.session_state:
495
- st.session_state.final_video_path = None
496
- if 'processing_stage1' not in st.session_state:
497
- st.session_state.processing_stage1 = False
498
- if 'processing_stage2' not in st.session_state:
499
- st.session_state.processing_stage2 = False
500
 
501
  def handle_video_upload():
 
502
  uploaded = st.file_uploader(
503
  "πŸ“Ή Upload Video",
504
  type=["mp4", "mov", "avi", "mkv"],
@@ -510,11 +100,11 @@ def handle_video_upload():
510
  if file_size_mb > 100:
511
  st.warning(f"⚠️ Large file detected ({file_size_mb:.1f}MB). Processing may take longer.")
512
  st.session_state.uploaded_video = uploaded
513
- # Reset processed videos when new video is uploaded
514
  st.session_state.transparent_video_path = None
515
  st.session_state.final_video_path = None
516
 
517
  def show_video_preview():
 
518
  st.markdown("### Video Preview")
519
  if st.session_state.uploaded_video is not None:
520
  video_bytes = st.session_state.uploaded_video.getvalue()
@@ -522,48 +112,40 @@ def show_video_preview():
522
  st.session_state.uploaded_video.seek(0)
523
 
524
  def handle_background_selection():
 
525
  st.markdown("### Background Options")
526
- bg_type = st.radio(
527
- "Select Background Type:",
528
- ["Image", "Color"],
529
- horizontal=True,
530
- key="bg_type_radio"
531
- )
532
  st.session_state.bg_type = bg_type.lower()
 
533
  if bg_type == "Image":
534
  handle_image_background()
535
  elif bg_type == "Color":
536
  handle_color_background()
537
 
538
  def handle_image_background():
539
- bg_image = st.file_uploader(
540
- "πŸ–ΌοΈ Upload Background Image",
541
- type=["jpg", "png", "jpeg"],
542
- key="bg_image_uploader",
543
- help="Recommended: Images under 5MB for better performance"
544
- )
545
 
546
  if bg_image is not None:
547
  image_size_mb = bg_image.size / (1024 * 1024)
548
  if image_size_mb > 10:
549
- st.warning(f"⚠️ Large image ({image_size_mb:.1f}MB). Consider resizing for better performance.")
550
 
551
  current_file_info = f"{bg_image.name}_{bg_image.size}"
552
  if st.session_state.bg_image_info != current_file_info:
553
  st.session_state.bg_image = Image.open(bg_image)
554
  st.session_state.bg_image_info = current_file_info
555
- # Reset final video when background changes
556
  st.session_state.final_video_path = None
557
 
558
  if st.session_state.bg_image is not None:
559
  st.image(st.session_state.bg_image, caption="Selected Background", use_container_width=True)
560
  else:
561
- if 'bg_image' in st.session_state:
562
- st.session_state.bg_image = None
563
- if 'bg_image_info' in st.session_state:
564
- st.session_state.bg_image_info = None
565
 
566
  def handle_color_background():
 
567
  st.markdown("#### Select a Color")
568
  old_color = st.session_state.get('bg_color', "#00FF00")
569
 
@@ -583,51 +165,43 @@ def handle_color_background():
583
  new_color = st.color_picker("Custom Color", old_color, key="custom_color_picker")
584
  if new_color != old_color:
585
  st.session_state.bg_color = new_color
586
- st.session_state.final_video_path = None # Reset final video
587
  else:
588
  if st.button(name, key=f"color_{name}", use_container_width=True):
589
  st.session_state.bg_color = color
590
- st.session_state.final_video_path = None # Reset final video
591
- st.markdown(
592
- f'<div style="background-color:{color}; height:30px; border-radius:4px; margin-top:-10px;"></div>',
593
- unsafe_allow_html=True
594
- )
595
 
596
  def main():
 
597
  add_logo()
598
-
599
- st.markdown(
600
- """
601
  <div style="text-align: center; margin-bottom: 30px;">
602
  <h1>πŸŽ₯ Video Background Replacer</h1>
603
  <p>Two-Stage Processing: SAM2 + MatAnyone β†’ Transparent β†’ Composite</p>
604
  </div>
605
- """,
606
- unsafe_allow_html=True
607
- )
608
  st.markdown("---")
609
-
610
  initialize_session_state()
611
  show_memory_info()
612
-
613
  col1, col2 = st.columns([1, 1], gap="large")
614
-
 
615
  with col1:
616
  st.header("1. Upload Video")
617
  handle_video_upload()
618
  show_video_preview()
619
 
620
- # STAGE 1: Create Transparent Video
621
  st.markdown('<div class="stage-indicator">STAGE 1: Create Transparent Video</div>', unsafe_allow_html=True)
622
 
623
  stage1_disabled = not st.session_state.uploaded_video or st.session_state.processing_stage1
624
 
625
- if st.button("🎭 Create Transparent Video",
626
- type="primary",
627
- disabled=stage1_disabled,
628
- use_container_width=True,
629
- help="Remove background using SAM2 + MatAnyone AI"):
630
-
631
  with st.spinner("Stage 1: Creating transparent video..."):
632
  st.session_state.processing_stage1 = True
633
  try:
@@ -642,7 +216,7 @@ def main():
642
  st.error(f"❌ Stage 1 Error: {str(e)}")
643
  finally:
644
  st.session_state.processing_stage1 = False
645
-
646
  # Show transparent video result
647
  if st.session_state.get('transparent_video_path'):
648
  st.markdown("#### Transparent Video Result")
@@ -650,54 +224,35 @@ def main():
650
  with open(st.session_state.transparent_video_path, 'rb') as f:
651
  transparent_bytes = f.read()
652
  st.video(transparent_bytes)
653
- st.download_button(
654
- label="πŸ’Ύ Download Transparent Video (.mov)",
655
- data=transparent_bytes,
656
- file_name="transparent_video.mov",
657
- mime="video/quicktime",
658
- use_container_width=True,
659
- help="Download for use in other video editors"
660
- )
661
- file_size_mb = len(transparent_bytes) / (1024 * 1024)
662
- st.caption(f"Transparent video size: {file_size_mb:.1f}MB")
663
  except Exception as e:
664
  st.error(f"Error displaying transparent video: {str(e)}")
665
-
 
666
  with col2:
667
  st.header("2. Background Settings")
668
  handle_background_selection()
669
 
670
- # STAGE 2: Composite with Background
671
  st.markdown('<div class="stage-indicator">STAGE 2: Composite with Background</div>', unsafe_allow_html=True)
672
 
673
  stage2_disabled = (not st.session_state.get('transparent_video_path') or
674
  st.session_state.processing_stage2 or
675
  (st.session_state.bg_type == "image" and not st.session_state.get('bg_image')))
676
 
677
- if st.button("🎬 Composite Final Video",
678
- type="primary",
679
- disabled=stage2_disabled,
680
- use_container_width=True,
681
- help="Combine transparent video with selected background"):
682
-
683
  if st.session_state.bg_type == "image" and not st.session_state.get('bg_image'):
684
  st.error("Please upload a background image first.")
685
  else:
686
  with st.spinner("Stage 2: Compositing with background..."):
687
  st.session_state.processing_stage2 = True
688
  try:
689
- background = None
690
- if st.session_state.bg_type == "image":
691
- background = st.session_state.bg_image
692
- elif st.session_state.bg_type == "color":
693
- background = st.session_state.bg_color
694
-
695
- final_path = stage2_composite_background(
696
- st.session_state.transparent_video_path,
697
- background,
698
- st.session_state.bg_type
699
- )
700
-
701
  if final_path:
702
  st.session_state.final_video_path = final_path
703
  st.success("βœ… Stage 2 Complete: Final video ready!")
@@ -708,7 +263,7 @@ def main():
708
  st.error(f"❌ Stage 2 Error: {str(e)}")
709
  finally:
710
  st.session_state.processing_stage2 = False
711
-
712
  # Show final video result
713
  if st.session_state.get('final_video_path'):
714
  st.markdown("#### Final Video Result")
@@ -716,39 +271,25 @@ def main():
716
  with open(st.session_state.final_video_path, 'rb') as f:
717
  final_bytes = f.read()
718
  st.video(final_bytes)
719
- st.download_button(
720
- label="πŸ’Ύ Download Final Video (.mp4)",
721
- data=final_bytes,
722
- file_name="final_video.mp4",
723
- mime="video/mp4",
724
- use_container_width=True
725
- )
726
- file_size_mb = len(final_bytes) / (1024 * 1024)
727
- st.caption(f"Final video size: {file_size_mb:.1f}MB")
728
  except Exception as e:
729
  st.error(f"Error displaying final video: {str(e)}")
730
-
731
  # Processing tips
732
  with st.expander("πŸ’‘ Two-Stage Processing Tips"):
733
  st.markdown("""
734
  **Stage 1 - Create Transparent Video:**
735
  - Uses SAM2 + MatAnyone AI to remove background
736
- - Creates a .mov file with alpha channel (transparency)
737
  - Only needs to be done once per video
738
- - Download transparent video for use in other editors
739
 
740
  **Stage 2 - Composite Background:**
741
  - Fast compositing with your chosen background
742
- - Can try multiple backgrounds without re-processing
743
- - Change background and re-composite instantly
744
  - Much faster than Stage 1
745
-
746
- **Benefits:**
747
- - **Flexible**: Try different backgrounds easily
748
- - **Efficient**: Reuse transparent video multiple times
749
- - **Professional**: Industry-standard workflow
750
- - **Cacheable**: Save transparent video for future use
751
  """)
752
 
753
  if __name__ == "__main__":
754
- main()
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MyAvatar Video Background Replacer - Streamlit UI
4
+ Main interface for two-stage video processing pipeline
5
+ """
6
+
7
  import streamlit as st
 
8
  import sys
 
 
 
 
9
  from pathlib import Path
 
 
10
  from PIL import Image
 
 
 
 
 
11
 
12
  # Add project root to path
13
  sys.path.append(str(Path(__file__).parent.absolute()))
14
 
15
+ # Import processing modules
16
+ from models import load_sam2_predictor, load_matanyone_processor, clear_model_cache, get_memory_usage
17
+ from video_pipeline import stage1_create_transparent_video, stage2_composite_background
18
 
19
+ # Persistent temp dir
20
  TMP_DIR = Path("tmp")
21
  TMP_DIR.mkdir(parents=True, exist_ok=True)
22
 
 
28
  initial_sidebar_state="expanded"
29
  )
30
 
31
+ # Styling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def add_logo():
33
  st.markdown(
34
  """
 
44
  )
45
 
46
  def show_memory_info():
47
+ """Display memory usage in sidebar with model testing."""
48
  memory_info = get_memory_usage()
49
  with st.sidebar:
50
  st.markdown("### 🧠 Memory Usage")
 
54
  st.metric("RAM Usage", f"{memory_info['ram_used']:.1f}GB",
55
  f"Available: {memory_info['ram_available']:.1f}GB")
56
 
 
57
  if st.button("πŸ§ͺ Test Models", help="Test if SAM2 and MatAnyone can load"):
58
  with st.spinner("Testing model loading..."):
59
  try:
60
  sam2_test = load_sam2_predictor()
61
+ st.success("βœ… SAM2 loads successfully") if sam2_test else st.error("❌ SAM2 failed to load")
 
 
 
62
 
63
  matanyone_test = load_matanyone_processor()
64
+ st.success("βœ… MatAnyone loads successfully") if matanyone_test else st.error("❌ MatAnyone failed to load")
 
 
 
65
  except Exception as e:
66
  st.error(f"Model test failed: {e}")
67
 
 
71
  st.experimental_rerun()
72
 
73
  def initialize_session_state():
74
+ """Initialize all session state variables."""
75
+ defaults = {
76
+ 'uploaded_video': None,
77
+ 'bg_image': None,
78
+ 'bg_image_info': None,
79
+ 'bg_color': "#00FF00",
80
+ 'bg_type': "image",
81
+ 'transparent_video_path': None,
82
+ 'final_video_path': None,
83
+ 'processing_stage1': False,
84
+ 'processing_stage2': False
85
+ }
86
+ for key, value in defaults.items():
87
+ if key not in st.session_state:
88
+ st.session_state[key] = value
 
 
 
89
 
90
  def handle_video_upload():
91
+ """Handle video file upload."""
92
  uploaded = st.file_uploader(
93
  "πŸ“Ή Upload Video",
94
  type=["mp4", "mov", "avi", "mkv"],
 
100
  if file_size_mb > 100:
101
  st.warning(f"⚠️ Large file detected ({file_size_mb:.1f}MB). Processing may take longer.")
102
  st.session_state.uploaded_video = uploaded
 
103
  st.session_state.transparent_video_path = None
104
  st.session_state.final_video_path = None
105
 
106
  def show_video_preview():
107
+ """Display uploaded video preview."""
108
  st.markdown("### Video Preview")
109
  if st.session_state.uploaded_video is not None:
110
  video_bytes = st.session_state.uploaded_video.getvalue()
 
112
  st.session_state.uploaded_video.seek(0)
113
 
114
  def handle_background_selection():
115
+ """Handle background type selection."""
116
  st.markdown("### Background Options")
117
+ bg_type = st.radio("Select Background Type:", ["Image", "Color"], horizontal=True, key="bg_type_radio")
 
 
 
 
 
118
  st.session_state.bg_type = bg_type.lower()
119
+
120
  if bg_type == "Image":
121
  handle_image_background()
122
  elif bg_type == "Color":
123
  handle_color_background()
124
 
125
  def handle_image_background():
126
+ """Handle image background upload and preview."""
127
+ bg_image = st.file_uploader("πŸ–ΌοΈ Upload Background Image", type=["jpg", "png", "jpeg"],
128
+ key="bg_image_uploader", help="Recommended: Images under 5MB")
 
 
 
129
 
130
  if bg_image is not None:
131
  image_size_mb = bg_image.size / (1024 * 1024)
132
  if image_size_mb > 10:
133
+ st.warning(f"⚠️ Large image ({image_size_mb:.1f}MB). Consider resizing.")
134
 
135
  current_file_info = f"{bg_image.name}_{bg_image.size}"
136
  if st.session_state.bg_image_info != current_file_info:
137
  st.session_state.bg_image = Image.open(bg_image)
138
  st.session_state.bg_image_info = current_file_info
 
139
  st.session_state.final_video_path = None
140
 
141
  if st.session_state.bg_image is not None:
142
  st.image(st.session_state.bg_image, caption="Selected Background", use_container_width=True)
143
  else:
144
+ st.session_state.bg_image = None
145
+ st.session_state.bg_image_info = None
 
 
146
 
147
  def handle_color_background():
148
+ """Handle solid color background selection."""
149
  st.markdown("#### Select a Color")
150
  old_color = st.session_state.get('bg_color', "#00FF00")
151
 
 
165
  new_color = st.color_picker("Custom Color", old_color, key="custom_color_picker")
166
  if new_color != old_color:
167
  st.session_state.bg_color = new_color
168
+ st.session_state.final_video_path = None
169
  else:
170
  if st.button(name, key=f"color_{name}", use_container_width=True):
171
  st.session_state.bg_color = color
172
+ st.session_state.final_video_path = None
173
+ st.markdown(f'<div style="background-color:{color}; height:30px; border-radius:4px; margin-top:-10px;"></div>',
174
+ unsafe_allow_html=True)
 
 
175
 
176
  def main():
177
+ """Main application entry point."""
178
  add_logo()
179
+
180
+ st.markdown("""
 
181
  <div style="text-align: center; margin-bottom: 30px;">
182
  <h1>πŸŽ₯ Video Background Replacer</h1>
183
  <p>Two-Stage Processing: SAM2 + MatAnyone β†’ Transparent β†’ Composite</p>
184
  </div>
185
+ """, unsafe_allow_html=True)
 
 
186
  st.markdown("---")
187
+
188
  initialize_session_state()
189
  show_memory_info()
190
+
191
  col1, col2 = st.columns([1, 1], gap="large")
192
+
193
+ # LEFT COLUMN: Video Upload & Stage 1
194
  with col1:
195
  st.header("1. Upload Video")
196
  handle_video_upload()
197
  show_video_preview()
198
 
 
199
  st.markdown('<div class="stage-indicator">STAGE 1: Create Transparent Video</div>', unsafe_allow_html=True)
200
 
201
  stage1_disabled = not st.session_state.uploaded_video or st.session_state.processing_stage1
202
 
203
+ if st.button("🎭 Create Transparent Video", type="primary", disabled=stage1_disabled,
204
+ use_container_width=True, help="Remove background using SAM2 + MatAnyone AI"):
 
 
 
 
205
  with st.spinner("Stage 1: Creating transparent video..."):
206
  st.session_state.processing_stage1 = True
207
  try:
 
216
  st.error(f"❌ Stage 1 Error: {str(e)}")
217
  finally:
218
  st.session_state.processing_stage1 = False
219
+
220
  # Show transparent video result
221
  if st.session_state.get('transparent_video_path'):
222
  st.markdown("#### Transparent Video Result")
 
224
  with open(st.session_state.transparent_video_path, 'rb') as f:
225
  transparent_bytes = f.read()
226
  st.video(transparent_bytes)
227
+ st.download_button("πŸ’Ύ Download Transparent Video (.mov)", data=transparent_bytes,
228
+ file_name="transparent_video.mov", mime="video/quicktime",
229
+ use_container_width=True)
230
+ st.caption(f"Size: {len(transparent_bytes) / (1024**2):.1f}MB")
 
 
 
 
 
 
231
  except Exception as e:
232
  st.error(f"Error displaying transparent video: {str(e)}")
233
+
234
+ # RIGHT COLUMN: Background Selection & Stage 2
235
  with col2:
236
  st.header("2. Background Settings")
237
  handle_background_selection()
238
 
 
239
  st.markdown('<div class="stage-indicator">STAGE 2: Composite with Background</div>', unsafe_allow_html=True)
240
 
241
  stage2_disabled = (not st.session_state.get('transparent_video_path') or
242
  st.session_state.processing_stage2 or
243
  (st.session_state.bg_type == "image" and not st.session_state.get('bg_image')))
244
 
245
+ if st.button("🎬 Composite Final Video", type="primary", disabled=stage2_disabled,
246
+ use_container_width=True, help="Combine transparent video with selected background"):
 
 
 
 
247
  if st.session_state.bg_type == "image" and not st.session_state.get('bg_image'):
248
  st.error("Please upload a background image first.")
249
  else:
250
  with st.spinner("Stage 2: Compositing with background..."):
251
  st.session_state.processing_stage2 = True
252
  try:
253
+ background = st.session_state.bg_image if st.session_state.bg_type == "image" else st.session_state.bg_color
254
+ final_path = stage2_composite_background(st.session_state.transparent_video_path,
255
+ background, st.session_state.bg_type)
 
 
 
 
 
 
 
 
 
256
  if final_path:
257
  st.session_state.final_video_path = final_path
258
  st.success("βœ… Stage 2 Complete: Final video ready!")
 
263
  st.error(f"❌ Stage 2 Error: {str(e)}")
264
  finally:
265
  st.session_state.processing_stage2 = False
266
+
267
  # Show final video result
268
  if st.session_state.get('final_video_path'):
269
  st.markdown("#### Final Video Result")
 
271
  with open(st.session_state.final_video_path, 'rb') as f:
272
  final_bytes = f.read()
273
  st.video(final_bytes)
274
+ st.download_button("πŸ’Ύ Download Final Video (.mp4)", data=final_bytes,
275
+ file_name="final_video.mp4", mime="video/mp4", use_container_width=True)
276
+ st.caption(f"Size: {len(final_bytes) / (1024**2):.1f}MB")
 
 
 
 
 
 
277
  except Exception as e:
278
  st.error(f"Error displaying final video: {str(e)}")
279
+
280
  # Processing tips
281
  with st.expander("πŸ’‘ Two-Stage Processing Tips"):
282
  st.markdown("""
283
  **Stage 1 - Create Transparent Video:**
284
  - Uses SAM2 + MatAnyone AI to remove background
285
+ - Creates a .mov file with alpha channel
286
  - Only needs to be done once per video
 
287
 
288
  **Stage 2 - Composite Background:**
289
  - Fast compositing with your chosen background
290
+ - Try multiple backgrounds without re-processing
 
291
  - Much faster than Stage 1
 
 
 
 
 
 
292
  """)
293
 
294
  if __name__ == "__main__":
295
+ main()