# Spaces: Sleeping
# File size: 10,371 Bytes
# d0a935f
import streamlit as st
import os
import torch
import numpy as np
from PIL import Image
import cv2
import sys
import time
import download_model
# Ensure src directory is in path for local imports
sys.path.append(os.path.dirname(__file__))
from inference import VisionExtractPipeline
# Browser-tab metadata and overall page layout.
# Streamlit expects this to be the first Streamlit command the script issues.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="VisionExtract - Subject Isolation",
    page_icon="π―",
)
# Custom CSS for premium look.
# The gradient heading rule is exposed under both `.title-text` and
# `.gradient-text`: the page header is rendered as <h1 class="gradient-text">,
# so a rule that only targets `.title-text` never styles it.
st.markdown("""
<style>
.main {
background-color: #0e1117;
}
.stButton>button {
width: 100%;
border-radius: 5px;
height: 3em;
background-color: #ff4b4b;
color: white;
font-weight: bold;
border: none;
}
.stButton>button:hover {
background-color: #ff3333;
border: none;
}
.upload-text {
color: #ccd6f6;
font-size: 1.2rem;
text-align: center;
margin-bottom: 2rem;
}
.title-text, .gradient-text {
background: linear-gradient(90deg, #ff4b4b, #ff8a8a);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 800;
font-size: 3rem;
margin-bottom: 0px;
}
</style>
""", unsafe_allow_html=True)
def _checkpoint_sort_key(filename: str) -> tuple:
    """Return a sort key that lists ``best_model.pth`` first, then epochs descending.

    Files whose name contains ``epoch`` are ordered by the integer found in the
    last ``_``-separated segment (before the first ``.``). Names without a
    parseable epoch number are treated as epoch 0 instead of raising.
    """
    epoch = 0
    if "epoch" in filename:
        tail = filename.split("_")[-1].split(".")[0]
        try:
            epoch = int(tail)
        except ValueError:
            # e.g. "epoch_final.pth": keep the file selectable, just unranked.
            epoch = 0
    return (filename != "best_model.pth", -epoch)


def _list_checkpoints(checkpoint_dir: str) -> list:
    """Return the ``.pth`` file names in *checkpoint_dir*, preferred ones first.

    Returns an empty list when the directory does not exist.
    """
    if not os.path.isdir(checkpoint_dir):
        return []
    names = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
    return sorted(names, key=_checkpoint_sort_key)


def _apply_background(img_np: np.ndarray, mask_np: np.ndarray, background: str) -> np.ndarray:
    """Composite the masked subject over the requested background.

    Args:
        img_np: RGB image array of shape (H, W, 3).
        mask_np: 2-D mask aligned with *img_np*; used as a per-pixel blend
            weight (assumed to be in [0, 1] - TODO confirm against the
            pipeline's output).
        background: ``"black"``, ``"blur"``, or a path to a background image.

    Returns:
        The blended image as ``uint8``. An unreadable or missing background
        image falls back to the black background.
    """
    mask_3d = mask_np[:, :, None]

    if background == "blur":
        backdrop = cv2.GaussianBlur(img_np, (21, 21), 0)
    else:
        backdrop = None
        if background != "black" and os.path.exists(background):
            # cv2.imread returns None (no exception) when decoding fails.
            backdrop = cv2.imread(background)
        if backdrop is None:
            return (img_np * mask_3d).astype(np.uint8)
        height, width = img_np.shape[:2]
        backdrop = cv2.cvtColor(backdrop, cv2.COLOR_BGR2RGB)
        backdrop = cv2.resize(backdrop, (width, height))

    # Alpha blending with the soft mask for smooth matting.
    return (img_np * mask_3d + backdrop * (1 - mask_3d)).astype(np.uint8)


def _run_extraction(uploaded_files, pipeline, background, background_label):
    """Run the pipeline over every uploaded file and render one result card each.

    Args:
        uploaded_files: Non-empty list of Streamlit uploaded-file objects.
        pipeline: Object exposing ``full_pipeline(path, save=..., display=...)``
            that returns ``(isolated_image, soft_mask)``.
        background: Resolved background spec (``"black"``, ``"blur"`` or a path).
        background_label: Human-readable background name shown in captions.
    """
    import tempfile  # local import keeps this block self-contained

    progress_bar = st.progress(0)
    status_text = st.empty()
    results_container = st.container()
    total = len(uploaded_files)

    for i, uploaded_file in enumerate(uploaded_files):
        start_time = time.time()
        status_text.text(f"Processing: {uploaded_file.name}...")

        # A unique temp file avoids collisions between concurrent sessions,
        # which a fixed "temp_<i>.png" in the working directory would cause.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        try:
            # Normalising to RGB keeps the compositing maths on (H, W, 3)
            # arrays even for grayscale, palette or RGBA uploads. Opening
            # inside the try lets one corrupt file fail without aborting
            # the rest of the batch.
            image = Image.open(uploaded_file).convert("RGB")
            image.save(temp_path)

            _, soft_mask = pipeline.full_pipeline(
                temp_path,
                save=False,
                display=False,
            )
            final_output = _apply_background(np.array(image), soft_mask, background)
            inf_time = time.time() - start_time

            with results_container:
                st.markdown('<div class="glass-card">', unsafe_allow_html=True)
                st.markdown(f"#### π·οΈ Output: {uploaded_file.name}")
                c1, c2, c3 = st.columns([1, 1, 0.5])
                with c1:
                    st.image(image, caption="Original", use_container_width=True)
                with c2:
                    st.image(
                        final_output,
                        caption=f"Result ({background_label})",
                        use_container_width=True,
                    )
                with c3:
                    st.markdown(
                        "\n"
                        '<div class="metric-box">\n'
                        f'<span class="metric-value">β±οΈ {inf_time:.2f}s</span>\n'
                        '<span class="metric-label">Inference</span>\n'
                        "</div>\n",
                        unsafe_allow_html=True,
                    )
                    st.markdown("<br>", unsafe_allow_html=True)

                    # The payload is always PNG, so name the download
                    # accordingly regardless of the upload's extension.
                    png_bytes = cv2.imencode(
                        ".png", cv2.cvtColor(final_output, cv2.COLOR_RGB2BGR)
                    )[1].tobytes()
                    stem = os.path.splitext(uploaded_file.name)[0]
                    st.download_button(
                        label="Download PNG",
                        data=png_bytes,
                        file_name=f"visionextract_{stem}.png",
                        mime="image/png",
                        key=f"dl_{i}",
                        use_container_width=True,
                    )
                st.markdown('</div>', unsafe_allow_html=True)
                st.markdown("<br>", unsafe_allow_html=True)
        except Exception as e:
            # UI boundary: report the failure and continue with the next file.
            st.error(f"Error on {uploaded_file.name}: {e}")
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

        progress_bar.progress((i + 1) / total)

    status_text.success("π Batch Processing Complete!")
    st.balloons()


def _render_technical_dashboard():
    """Render the static metrics / architecture / readiness cards."""
    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown("### π Model Performance Metrics")
    m1, m2, m3, m4 = st.columns(4)
    m1.metric("Avg. IoU", "0.621", "+0.02")
    m2.metric("Dice Score", "0.756", "+0.01")
    m3.metric("Pixel Accuracy", "90.2%", "+0.5%")
    m4.metric("Inf. Speed", "0.15s", "-0.05s")
    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown("### ποΈ Architecture Overview")
    st.info("**Encoder:** ResNet34 (ImageNet Pre-trained)\n\n**Decoder:** Symmetric UNet with skip-connections and Bilinear Upsampling.\n\n**Pipeline:** Standardized Aspect-Ratio Aware Inference (256px Base).")
    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="glass-card">', unsafe_allow_html=True)
    st.markdown("### π Showcase Readiness")
    st.success("- [x] Robust Multi-image Batch Processing\n- [x] Standard Linear Up-scaling Matting\n- [x] Dynamic Virtual Background Replacement\n- [x] Optimized Performance for Final Demo")
    st.markdown('</div>', unsafe_allow_html=True)


def main():
    """Build the Streamlit page: sidebar options, extraction tab and dashboard tab."""
    # --- Sidebar ---
    st.sidebar.title("Configuration")
    checkpoint_dir = "checkpoints"
    available_checkpoints = _list_checkpoints(checkpoint_dir)
    if available_checkpoints:
        selected_checkpoint = st.sidebar.selectbox("Select Model Checkpoint", available_checkpoints)
        model_path = os.path.join(checkpoint_dir, selected_checkpoint)
    else:
        st.sidebar.warning("No checkpoints found in 'checkpoints/' directory.")
        # NOTE(review): None is forwarded to the pipeline as-is; confirm that
        # VisionExtractPipeline accepts a missing model_path.
        model_path = None

    device = st.sidebar.radio("Device", ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"])

    st.sidebar.markdown("---")
    st.sidebar.markdown("### πΌοΈ Background Style")
    # Display label -> background spec understood by _apply_background.
    bg_options = {
        "Deep Black": "black",
        "Modern Office": "docs/images/backgrounds/office.png",
        "Lush Nature": "docs/images/backgrounds/nature.png",
        "Photo Studio": "docs/images/backgrounds/studio.png",
        "Soft Blur": "blur",
    }
    selected_bg = st.sidebar.selectbox("Virtual Background", list(bg_options.keys()))

    st.sidebar.markdown("---")
    st.sidebar.markdown("### π¬ Architecture: ResNet-UNet")
    st.sidebar.caption("High-performance segmentation with pre-trained ResNet34 backbone for precise subject isolation.")

    # --- Header ---
    st.markdown('<h1 class="gradient-text">VisionExtract AI</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-text">Intelligent Subject Isolation & Background Extraction</p>', unsafe_allow_html=True)

    # --- Tabs ---
    tab_extract, tab_tech = st.tabs(["β¨ Extraction Engine", "π Technical Dashboard"])

    with tab_extract:
        # --- Upload Logic ---
        st.markdown('<div class="glass-card">', unsafe_allow_html=True)
        uploaded_files = st.file_uploader("Drop images here (Multiple supported)", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
        st.markdown('</div>', unsafe_allow_html=True)

        if uploaded_files:
            st.write(f"π **{len(uploaded_files)}** files queued for isolation.")
            # Action bar: a narrow button column next to an empty spacer.
            col_btn, _ = st.columns([1, 4])
            if col_btn.button("β¨ START EXTRACTION"):
                # Initialise the pipeline once per batch (standard 256 mode).
                pipeline = VisionExtractPipeline(model_path=model_path, device=device, image_size=256)
                # Pass the resolved spec, not the display label, so the
                # "black"/"blur" modes are actually recognised.
                _run_extraction(uploaded_files, pipeline, bg_options[selected_bg], selected_bg)

    # --- Technical Dashboard ---
    with tab_tech:
        _render_technical_dashboard()
# Build the page only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()