"""Streamlit front-end for VisionExtract: batch subject isolation with
optional virtual-background replacement."""

import os
import sys
import tempfile
import time

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

# Not referenced by name below; presumably imported for its import-time
# side effects -- TODO confirm before removing.
import download_model  # noqa: F401

# Ensure src directory is in path for local imports
sys.path.append(os.path.dirname(__file__))
from inference import VisionExtractPipeline  # noqa: E402

# Page configuration
st.set_page_config(
    page_title="VisionExtract - Subject Isolation",
    page_icon="🎯",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for premium look
# NOTE(review): this payload, and the other strings passed with
# unsafe_allow_html=True, contain no markup in the reviewed copy of this
# file. They are kept exactly as seen -- confirm no HTML/CSS was lost.
st.markdown(""" """, unsafe_allow_html=True)


def _checkpoint_sort_key(name):
    """Return the sort key used to order checkpoint file names.

    ``best_model.pth`` sorts first; names containing ``epoch`` follow in
    descending order of the integer found between the last ``_`` and the
    first ``.`` after it; everything else sorts with epoch 0.

    A name that contains ``epoch`` but has no parsable integer there is
    treated as epoch 0 instead of raising ``ValueError``.
    """
    epoch = 0
    if "epoch" in name:
        token = name.split("_")[-1].split(".")[0]
        try:
            epoch = int(token)
        except ValueError:
            epoch = 0
    return (name != "best_model.pth", -epoch)


def _apply_background(img_np, mask_np, bg_source):
    """Composite the masked subject over the requested background.

    Args:
        img_np: RGB image array of shape (H, W, 3).
        mask_np: Soft mask indexed as ``mask_np[:, :, None]``.
            Assumes a 2-D float array in [0, 1] with the same height and
            width as ``img_np`` -- TODO confirm against
            ``VisionExtractPipeline.full_pipeline``.
        bg_source: ``"black"``, ``"blur"``, or a path to a background image.

    Returns:
        A ``uint8`` array: the subject on black when ``bg_source`` is
        ``"black"`` or the background image is missing or unreadable,
        otherwise the subject alpha-blended over the background.
    """
    h, w = img_np.shape[:2]
    mask_3d = mask_np[:, :, None]

    background = None
    if bg_source == "blur":
        background = cv2.GaussianBlur(img_np, (21, 21), 0)
    elif bg_source != "black" and os.path.exists(bg_source):
        # cv2.imread returns None (it does not raise) for unreadable files.
        loaded = cv2.imread(bg_source)
        if loaded is not None:
            loaded = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
            background = cv2.resize(loaded, (w, h))

    if background is None:
        return (img_np * mask_3d).astype(np.uint8)

    # Alpha blending with the soft mask for smooth matting
    return (img_np * mask_3d + background * (1 - mask_3d)).astype(np.uint8)


def main():
    """Render the sidebar, the extraction tab and the technical dashboard."""
    # Sidebar
    st.sidebar.title("Configuration")

    checkpoint_dir = "checkpoints"
    available_checkpoints = []
    if os.path.isdir(checkpoint_dir):
        # best_model.pth first, then epoch-descending
        available_checkpoints = sorted(
            (f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")),
            key=_checkpoint_sort_key,
        )

    if available_checkpoints:
        selected_checkpoint = st.sidebar.selectbox(
            "Select Model Checkpoint", available_checkpoints
        )
        model_path = os.path.join(checkpoint_dir, selected_checkpoint)
    else:
        st.sidebar.warning("No checkpoints found in 'checkpoints/' directory.")
        model_path = None

    device = st.sidebar.radio(
        "Device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
    )

    st.sidebar.markdown("---")
    st.sidebar.markdown("### 🖼️ Background Style")
    # Display label -> background source ("black", "blur", or an image path)
    bg_options = {
        "Deep Black": "black",
        "Modern Office": "docs/images/backgrounds/office.png",
        "Lush Nature": "docs/images/backgrounds/nature.png",
        "Photo Studio": "docs/images/backgrounds/studio.png",
        "Soft Blur": "blur",
    }
    selected_bg = st.sidebar.selectbox("Virtual Background", list(bg_options.keys()))

    st.sidebar.markdown("---")
    st.sidebar.markdown("### 🔬 Architecture: ResNet-UNet")
    st.sidebar.caption(
        "High-performance segmentation with pre-trained ResNet34 backbone for precise subject isolation."
    )

    # --- Header ---
    st.markdown("\n\nVisionExtract AI\n\n", unsafe_allow_html=True)
    st.markdown(
        "\n\nIntelligent Subject Isolation & Background Extraction\n\n",
        unsafe_allow_html=True,
    )

    # --- Tabs ---
    tab_extract, tab_tech = st.tabs(["✨ Extraction Engine", "📊 Technical Dashboard"])

    with tab_extract:
        # --- Upload Logic ---
        st.markdown("\n", unsafe_allow_html=True)
        uploaded_files = st.file_uploader(
            "Drop images here (Multiple supported)",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
        )
        st.markdown("\n", unsafe_allow_html=True)

        if uploaded_files:
            st.write(f"📂 **{len(uploaded_files)}** files queued for isolation.")

            # Action Bar
            col_btn, _ = st.columns([1, 4])
            process_all = col_btn.button("✨ START EXTRACTION")

            if process_all:
                # Initialize Pipeline Once (Standard 256 mode)
                pipeline = VisionExtractPipeline(
                    model_path=model_path, device=device, image_size=256
                )

                # Resolve the sidebar label to its source value; comparing
                # the label itself against "black"/"blur" never matches.
                bg_source = bg_options[selected_bg]

                # Progress handling
                progress_bar = st.progress(0)
                status_text = st.empty()

                # Grid Display
                results_container = st.container()

                for i, uploaded_file in enumerate(uploaded_files):
                    start_time = time.time()
                    status_text.text(f"Processing: {uploaded_file.name}...")

                    temp_path = None
                    try:
                        # Normalise to RGB so RGBA / palette / grayscale /
                        # CMYK uploads blend and encode as 3-channel images.
                        image = Image.open(uploaded_file).convert("RGB")

                        # A unique temp file avoids collisions between
                        # concurrent sessions; the pipeline takes a path.
                        with tempfile.NamedTemporaryFile(
                            suffix=".png", delete=False
                        ) as tmp:
                            temp_path = tmp.name
                        image.save(temp_path)

                        # Standard Pipeline (No aggressive thinning)
                        _, soft_mask = pipeline.full_pipeline(
                            temp_path,
                            save=False,
                            display=False,
                        )

                        # Apply selected background
                        final_output = _apply_background(
                            np.array(image), soft_mask, bg_source
                        )

                        inf_time = time.time() - start_time

                        # Display Result Card
                        with results_container:
                            st.markdown("\n", unsafe_allow_html=True)
                            st.markdown(f"#### 🏷️ Output: {uploaded_file.name}")

                            c1, c2, c3 = st.columns([1, 1, 0.5])
                            with c1:
                                st.image(
                                    image,
                                    caption="Original",
                                    use_container_width=True,
                                )
                            with c2:
                                st.image(
                                    final_output,
                                    caption=f"Result ({selected_bg})",
                                    use_container_width=True,
                                )
                            with c3:
                                st.markdown(
                                    f"\n⏱️ {inf_time:.2f}s Inference\n",
                                    unsafe_allow_html=True,
                                )
                                st.markdown("\n", unsafe_allow_html=True)

                                # Download
                                buf = cv2.imencode(
                                    ".png",
                                    cv2.cvtColor(final_output, cv2.COLOR_RGB2BGR),
                                )[1].tobytes()
                                # The payload is always PNG, so name it .png
                                # whatever the upload's extension was.
                                stem = os.path.splitext(uploaded_file.name)[0]
                                st.download_button(
                                    label="Download PNG",
                                    data=buf,
                                    file_name=f"visionextract_{stem}.png",
                                    mime="image/png",
                                    key=f"dl_{i}",
                                    use_container_width=True,
                                )

                            st.markdown("\n", unsafe_allow_html=True)
                            st.markdown("\n", unsafe_allow_html=True)

                    except Exception as e:
                        # Per-file boundary: report and continue the batch.
                        st.error(f"Error on {uploaded_file.name}: {e}")
                    finally:
                        if temp_path and os.path.exists(temp_path):
                            os.remove(temp_path)

                    # Update progress
                    progress_bar.progress((i + 1) / len(uploaded_files))

                status_text.success("🎉 Batch Processing Complete!")
                st.balloons()

    # --- Technical Dashboard ---
    with tab_tech:
        st.markdown("\n", unsafe_allow_html=True)
        st.markdown("### 📊 Model Performance Metrics")
        m1, m2, m3, m4 = st.columns(4)
        m1.metric("Avg. IoU", "0.621", "+0.02")
        m2.metric("Dice Score", "0.756", "+0.01")
        m3.metric("Pixel Accuracy", "90.2%", "+0.5%")
        m4.metric("Inf. Speed", "0.15s", "-0.05s")
        st.markdown("\n", unsafe_allow_html=True)

        st.markdown("\n", unsafe_allow_html=True)
        st.markdown("### 🏗️ Architecture Overview")
        st.info(
            "**Encoder:** ResNet34 (ImageNet Pre-trained)\n\n"
            "**Decoder:** Symmetric UNet with skip-connections and Bilinear Upsampling.\n\n"
            "**Pipeline:** Standardized Aspect-Ratio Aware Inference (256px Base)."
        )
        st.markdown("\n", unsafe_allow_html=True)

        st.markdown("\n", unsafe_allow_html=True)
        st.markdown("### 🚀 Showcase Readiness")
        st.success(
            "- [x] Robust Multi-image Batch Processing\n"
            "- [x] Standard Linear Up-scaling Matting\n"
            "- [x] Dynamic Virtual Background Replacement\n"
            "- [x] Optimized Performance for Final Demo"
        )
        st.markdown("\n", unsafe_allow_html=True)


if __name__ == "__main__":
    main()