Spaces:
Sleeping
Sleeping
import os
import re
import sys
import tempfile
import time

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

import download_model  # noqa: F401  # not referenced below; presumably imported for its side effects — TODO confirm

# Ensure src directory is in path for local imports
sys.path.append(os.path.dirname(__file__))
from inference import VisionExtractPipeline  # noqa: E402
# Page configuration: browser-tab title and icon, wide layout, sidebar open by default.
# NOTE(review): page_icon looks like a mis-encoded emoji (UTF-8 bytes decoded with a
# single-byte codepage, probably 🎯) — confirm the file is saved/read as UTF-8.
_PAGE_CONFIG = dict(
    layout="wide",
    page_title="VisionExtract - Subject Isolation",
    initial_sidebar_state="expanded",
    page_icon="π―",
)
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS for premium look.
# `.gradient-text` is the class put on the page's <h1> header; it shares the
# gradient-title rule with `.title-text` so the styled heading actually renders.
# NOTE(review): `.sub-text`, `.glass-card`, `.metric-box`, `.metric-value` and
# `.metric-label` are also used in the page markup but have no rules here.
st.markdown(
    """
<style>
.main {
    background-color: #0e1117;
}
.stButton>button {
    width: 100%;
    border-radius: 5px;
    height: 3em;
    background-color: #ff4b4b;
    color: white;
    font-weight: bold;
    border: none;
}
.stButton>button:hover {
    background-color: #ff3333;
    border: none;
}
.upload-text {
    color: #ccd6f6;
    font-size: 1.2rem;
    text-align: center;
    margin-bottom: 2rem;
}
.title-text,
.gradient-text {
    background: linear-gradient(90deg, #ff4b4b, #ff8a8a);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-weight: 800;
    font-size: 3rem;
    margin-bottom: 0px;
}
</style>
""",
    unsafe_allow_html=True,
)
# NOTE(review): emoji in the UI strings below look mis-encoded (UTF-8 decoded with a
# single-byte codepage, e.g. "β¨" for ✨). They are kept byte-for-byte; confirm the
# file's encoding before fixing them.

# Directory scanned for *.pth model checkpoints.
_CHECKPOINT_DIR = "checkpoints"

# Sidebar label -> background spec: "black", "blur", or a path to a background image.
_BG_OPTIONS = {
    "Deep Black": "black",
    "Modern Office": "docs/images/backgrounds/office.png",
    "Lush Nature": "docs/images/backgrounds/nature.png",
    "Photo Studio": "docs/images/backgrounds/studio.png",
    "Soft Blur": "blur",
}


def _checkpoint_sort_key(filename):
    """Sort key: ``best_model.pth`` first, then epoch checkpoints newest-first, then the rest.

    The epoch number is the first run of digits after "epoch" (e.g. ``model_epoch_12.pth``,
    ``epoch5.pth``). Names containing no such number sort as epoch 0 instead of raising.
    The filename is the final tie-breaker so the order does not depend on ``os.listdir``.
    """
    match = re.search(r"epoch\D*(\d+)", filename)
    epoch = int(match.group(1)) if match else 0
    return (filename != "best_model.pth", -epoch, filename)


def _list_checkpoints(checkpoint_dir):
    """Return the ``.pth`` files in *checkpoint_dir*, ordered by ``_checkpoint_sort_key``.

    Returns an empty list when *checkpoint_dir* is not a directory.
    """
    if not os.path.isdir(checkpoint_dir):
        return []
    return sorted(
        (name for name in os.listdir(checkpoint_dir) if name.endswith(".pth")),
        key=_checkpoint_sort_key,
    )


def _render_sidebar():
    """Draw the sidebar controls.

    Returns:
        (model_path, device, selected_bg): the checkpoint path (``None`` if no checkpoint
        was found), the chosen device string, and the chosen background *label*
        (a key of ``_BG_OPTIONS``).
    """
    st.sidebar.title("Configuration")

    available_checkpoints = _list_checkpoints(_CHECKPOINT_DIR)
    if available_checkpoints:
        selected_checkpoint = st.sidebar.selectbox("Select Model Checkpoint", available_checkpoints)
        model_path = os.path.join(_CHECKPOINT_DIR, selected_checkpoint)
    else:
        st.sidebar.warning("No checkpoints found in 'checkpoints/' directory.")
        model_path = None

    device = st.sidebar.radio("Device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"])

    st.sidebar.markdown("---")
    st.sidebar.markdown("### πΌοΈ Background Style")
    selected_bg = st.sidebar.selectbox("Virtual Background", list(_BG_OPTIONS))

    st.sidebar.markdown("---")
    st.sidebar.markdown("### π¬ Architecture: ResNet-UNet")
    st.sidebar.caption("High-performance segmentation with pre-trained ResNet34 backbone for precise subject isolation.")
    return model_path, device, selected_bg


def _apply_background(img_np, mask_np, bg_value):
    """Composite the subject over the background described by *bg_value*.

    Args:
        img_np: uint8 RGB image of shape (H, W, 3) (the caller converts uploads to RGB).
        mask_np: soft subject mask indexed as ``mask_np[:, :, None]``; presumably float in
            [0, 1] with shape (H, W) matching the image — TODO confirm against
            ``VisionExtractPipeline.full_pipeline``.
        bg_value: a *value* of ``_BG_OPTIONS``: "black", "blur", or an image path.

    Returns:
        uint8 array: ``img * mask + background * (1 - mask)``. A missing or unreadable
        background image falls back to a black background.
    """
    mask_3d = mask_np[:, :, None]
    if bg_value == "black":
        return (img_np * mask_3d).astype(np.uint8)

    if bg_value == "blur":
        background = cv2.GaussianBlur(img_np, (21, 21), 0)
    else:
        # cv2.imread returns None (rather than raising) for unreadable files.
        background = cv2.imread(bg_value) if os.path.exists(bg_value) else None
        if background is None:
            return (img_np * mask_3d).astype(np.uint8)
        height, width = img_np.shape[:2]
        background = cv2.resize(cv2.cvtColor(background, cv2.COLOR_BGR2RGB), (width, height))

    # Alpha blending with the soft mask for smooth matting.
    return (img_np * mask_3d + background * (1 - mask_3d)).astype(np.uint8)


def _render_result_card(index, file_name, image, final_output, bg_label, inf_time):
    """Draw one result card: original, composited result, timing and a PNG download.

    Raises:
        RuntimeError: if OpenCV fails to encode *final_output* as PNG.
    """
    # Encode first so an encoding failure does not leave a half-drawn card.
    ok, encoded = cv2.imencode(".png", cv2.cvtColor(final_output, cv2.COLOR_RGB2BGR))
    if not ok:
        raise RuntimeError("PNG encoding failed")
    png_bytes = encoded.tobytes()
    # The payload is always PNG, so the download name must end in .png whatever was uploaded.
    download_name = f"visionextract_{os.path.splitext(file_name)[0]}.png"

    # NOTE(review): Streamlit renders each st.markdown call as its own element, so this
    # opening <div> does not actually wrap the widgets that follow it.
    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown(f"#### π·οΈ Output: {file_name}")
    col_orig, col_result, col_meta = st.columns([1, 1, 0.5])
    with col_orig:
        st.image(image, caption="Original", use_container_width=True)
    with col_result:
        st.image(final_output, caption=f"Result ({bg_label})", use_container_width=True)
    with col_meta:
        # Single-line HTML: indented lines would be rendered by markdown as a code block.
        st.markdown(
            '<div class="metric-box">'
            f'<span class="metric-value">β±οΈ {inf_time:.2f}s</span>'
            '<span class="metric-label">Inference</span>'
            "</div>",
            unsafe_allow_html=True,
        )
        st.markdown("<br>", unsafe_allow_html=True)
        # NOTE(review): by default a download click reruns the script; START EXTRACTION
        # then reads False and these results disappear. Persisting results in
        # st.session_state would avoid that.
        st.download_button(
            label="Download PNG",
            data=png_bytes,
            file_name=download_name,
            mime="image/png",
            key=f"dl_{index}",
            use_container_width=True,
        )
    st.markdown("</div>", unsafe_allow_html=True)
    st.markdown("<br>", unsafe_allow_html=True)


def _run_batch(uploaded_files, model_path, device, selected_bg):
    """Run the isolation pipeline on every uploaded file and render one card per result.

    Per-file failures are reported with ``st.error`` and do not stop the batch.
    """
    # NOTE(review): model_path is None when no checkpoint was found; confirm that
    # VisionExtractPipeline accepts that.
    pipeline = VisionExtractPipeline(model_path=model_path, device=device, image_size=256)
    bg_value = _BG_OPTIONS[selected_bg]  # resolve the sidebar label to "black"/"blur"/path

    progress_bar = st.progress(0)
    status_text = st.empty()
    results_container = st.container()
    total = len(uploaded_files)
    failures = 0

    for i, uploaded_file in enumerate(uploaded_files):
        start_time = time.perf_counter()
        status_text.text(f"Processing: {uploaded_file.name}...")
        temp_path = None
        try:
            # Normalise to RGB: grayscale/palette/RGBA uploads would otherwise break the blend.
            image = Image.open(uploaded_file).convert("RGB")
            # Unique temp file per call, so concurrent sessions cannot overwrite each other.
            fd, temp_path = tempfile.mkstemp(prefix="visionextract_", suffix=".png")
            os.close(fd)
            image.save(temp_path)

            _isolated_black, soft_mask = pipeline.full_pipeline(temp_path, save=False, display=False)
            final_output = _apply_background(np.array(image), soft_mask, bg_value)
            inf_time = time.perf_counter() - start_time

            with results_container:
                _render_result_card(i, uploaded_file.name, image, final_output, selected_bg, inf_time)
        except Exception as e:  # UI boundary: report the failure and continue with the batch
            failures += 1
            st.error(f"Error on {uploaded_file.name}: {e}")
        finally:
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)
        progress_bar.progress((i + 1) / total)

    if failures:
        status_text.warning(f"Batch finished with {failures} of {total} file(s) failing.")
    else:
        status_text.success("π Batch Processing Complete!")
        st.balloons()


def _render_extraction_tab(model_path, device, selected_bg):
    """Render the uploader and start button; run the batch when the button is pressed."""
    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    uploaded_files = st.file_uploader(
        "Drop images here (Multiple supported)",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    st.markdown("</div>", unsafe_allow_html=True)
    if not uploaded_files:
        return

    st.write(f"π **{len(uploaded_files)}** files queued for isolation.")
    col_btn, _col_spacer = st.columns([1, 4])
    if col_btn.button("β¨ START EXTRACTION"):
        _run_batch(uploaded_files, model_path, device, selected_bg)


def _render_tech_dashboard():
    """Render the technical dashboard tab (all figures are hard-coded literals)."""
    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown("### π Model Performance Metrics")
    m1, m2, m3, m4 = st.columns(4)
    m1.metric("Avg. IoU", "0.621", "+0.02")
    m2.metric("Dice Score", "0.756", "+0.01")
    m3.metric("Pixel Accuracy", "90.2%", "+0.5%")
    m4.metric("Inf. Speed", "0.15s", "-0.05s")
    st.markdown("</div>", unsafe_allow_html=True)

    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown("### ποΈ Architecture Overview")
    st.info("**Encoder:** ResNet34 (ImageNet Pre-trained)\n\n**Decoder:** Symmetric UNet with skip-connections and Bilinear Upsampling.\n\n**Pipeline:** Standardized Aspect-Ratio Aware Inference (256px Base).")
    st.markdown("</div>", unsafe_allow_html=True)

    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown("### π Showcase Readiness")
    st.success("- [x] Robust Multi-image Batch Processing\n- [x] Standard Linear Up-scaling Matting\n- [x] Dynamic Virtual Background Replacement\n- [x] Optimized Performance for Final Demo")
    st.markdown("</div>", unsafe_allow_html=True)


def main():
    """Render the VisionExtract app: sidebar, header, extraction tab and dashboard tab."""
    model_path, device, selected_bg = _render_sidebar()

    # --- Header ---
    st.markdown('<h1 class="gradient-text">VisionExtract AI</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-text">Intelligent Subject Isolation & Background Extraction</p>', unsafe_allow_html=True)

    # --- Tabs ---
    tab_extract, tab_tech = st.tabs(["β¨ Extraction Engine", "π Technical Dashboard"])
    with tab_extract:
        _render_extraction_tab(model_path, device, selected_bg)
    with tab_tech:
        _render_tech_dashboard()


if __name__ == "__main__":
    main()