Spaces:
Runtime error
Runtime error
File size: 5,632 Bytes
bdb70cc eacd6a2 bdb70cc 38d1816 bdb70cc 38d1816 eacd6a2 1f4e421 eacd6a2 8cf2c66 eacd6a2 1f4e421 eacd6a2 8cf2c66 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 1f4e421 eacd6a2 8cf2c66 eacd6a2 8cf2c66 eacd6a2 8cf2c66 eacd6a2 8cf2c66 eacd6a2 8cf2c66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import sys
import os
import tempfile
import time
from streamlit_option_menu import option_menu
st.set_page_config(page_title="Facial Analysis", page_icon="👤", layout="wide")
try:
from src.cnnClassifier.pipeline.prediction import PredictionPipeline
except ImportError:
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
if src_path not in sys.path: sys.path.append(src_path)
from cnnClassifier.pipeline.prediction import PredictionPipeline
@st.cache_resource
def load_pipeline():
return PredictionPipeline()
pipeline = load_pipeline()
if 'webcam_running' not in st.session_state: st.session_state.webcam_running = False
def start_webcam(): st.session_state.webcam_running = True
def stop_webcam(): st.session_state.webcam_running = False
with st.sidebar:
st.markdown("## ⚙️ Controls")
app_mode = option_menu(None, ["Image", "Video", "Live Feed"],
icons=['image', 'film', 'camera-video'], menu_icon="cast", default_index=0)
if not pipeline:
st.error("AI Pipeline failed to load. Check terminal logs.")
else:
st.title("👤 Facial Demographics Analysis")
st.header(f"Mode: {app_mode}")
st.divider()
if app_mode == "Image":
uploaded_file = st.file_uploader("Upload an image for analysis", type=["jpg", "jpeg", "png"])
if uploaded_file:
image = Image.open(uploaded_file).convert("RGB")
col1, col2 = st.columns(2)
with col1: st.image(image, caption='Original Image', use_column_width=True)
with col2:
with st.spinner('🔬 Analyzing with high-quality detector...'):
# --- THE FIX: Call the HQ method ---
annotated_image, predictions = pipeline.predict_hq(np.array(image))
st.image(annotated_image, caption='Processed Image', use_column_width=True)
if predictions:
with st.expander("View Details", expanded=True):
for i, p in enumerate(predictions):
st.write(f"**Face {i+1}:** Gender: `{p['gender']}`, Age Group: `{p['age']}`")
else: st.warning("No faces detected.")
elif app_mode == "Video":
uploaded_file = st.file_uploader("Upload a video for analysis", type=["mp4", "mov", "avi"])
if uploaded_file:
tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
tfile.write(uploaded_file.read())
cap = cv2.VideoCapture(tfile.name)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
st.info(f"Video has {frame_count} frames. This will be slow but high-quality.")
if st.button("Process Video", type="primary", use_container_width=True):
progress_bar = st.progress(0, text="Initializing...")
out_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
h, w = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
out = cv2.VideoWriter(out_tfile.name, cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), (w, h))
for i in range(frame_count):
ret, frame = cap.read()
if not ret: break
# --- THE FIX: Call the HQ method ---
annotated_frame_rgb, _ = pipeline.predict_hq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
out.write(cv2.cvtColor(annotated_frame_rgb, cv2.COLOR_RGB2BGR))
progress_bar.progress((i + 1) / frame_count, text=f"Processing Frame {i+1}/{frame_count}")
cap.release(), out.release()
st.success("Video processing complete!")
st.video(out_tfile.name)
with open(out_tfile.name, "rb") as f:
st.download_button("Download Processed Video", f, "output.mp4", "video/mp4", use_container_width=True)
elif app_mode == "Live Feed":
st.info("Live feed uses a lightweight face detector for performance.")
col1, col2 = st.columns(2)
with col1: st.button("Start Feed", on_click=start_webcam, use_container_width=True, type="primary")
with col2: st.button("Stop Feed", on_click=stop_webcam, use_container_width=True)
_, center_col, _ = st.columns([1, 2, 1])
with center_col:
FRAME_WINDOW = st.image([])
fps_display = st.empty()
if st.session_state.webcam_running:
cap = cv2.VideoCapture(0)
while st.session_state.webcam_running:
start_time = time.time()
ret, frame = cap.read()
if not ret:
st.warning("Could not read frame from webcam. Stopping.")
stop_webcam()
break
frame = cv2.flip(frame, 1)
annotated_frame, _ = pipeline.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
FRAME_WINDOW.image(annotated_frame, channels="RGB")
fps = 1.0 / (time.time() - start_time) if (time.time() - start_time) > 0 else 0
fps_display.markdown(f"<p style='text-align: center;'><b>FPS: {fps:.2f}</b></p>", unsafe_allow_html=True)
cap.release()
# --- THE FIX ---
# cv2.destroyAllWindows() # This line is removed
# --- END FIX ---
if st.session_state.webcam_running:
st.session_state.webcam_running = False
st.rerun()
|