Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import sys | |
| import os | |
| import tempfile | |
| import time | |
| from streamlit_option_menu import option_menu | |
| st.set_page_config(page_title="Facial Analysis", page_icon="π€", layout="wide") | |
| try: | |
| from src.cnnClassifier.pipeline.prediction import PredictionPipeline | |
| except ImportError: | |
| src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) | |
| if src_path not in sys.path: sys.path.append(src_path) | |
| from cnnClassifier.pipeline.prediction import PredictionPipeline | |
| def load_pipeline(): | |
| return PredictionPipeline() | |
| pipeline = load_pipeline() | |
| if 'webcam_running' not in st.session_state: st.session_state.webcam_running = False | |
| def start_webcam(): st.session_state.webcam_running = True | |
| def stop_webcam(): st.session_state.webcam_running = False | |
| with st.sidebar: | |
| st.markdown("## βοΈ Controls") | |
| app_mode = option_menu(None, ["Image", "Video", "Live Feed"], | |
| icons=['image', 'film', 'camera-video'], menu_icon="cast", default_index=0) | |
| if not pipeline: | |
| st.error("AI Pipeline failed to load. Check terminal logs.") | |
| else: | |
| st.title("π€ Facial Demographics Analysis") | |
| st.header(f"Mode: {app_mode}") | |
| st.divider() | |
| if app_mode == "Image": | |
| uploaded_file = st.file_uploader("Upload an image for analysis", type=["jpg", "jpeg", "png"]) | |
| if uploaded_file: | |
| image = Image.open(uploaded_file).convert("RGB") | |
| col1, col2 = st.columns(2) | |
| with col1: st.image(image, caption='Original Image', use_column_width=True) | |
| with col2: | |
| with st.spinner('π¬ Analyzing with high-quality detector...'): | |
| # --- THE FIX: Call the HQ method --- | |
| annotated_image, predictions = pipeline.predict_hq(np.array(image)) | |
| st.image(annotated_image, caption='Processed Image', use_column_width=True) | |
| if predictions: | |
| with st.expander("View Details", expanded=True): | |
| for i, p in enumerate(predictions): | |
| st.write(f"**Face {i+1}:** Gender: `{p['gender']}`, Age Group: `{p['age']}`") | |
| else: st.warning("No faces detected.") | |
| elif app_mode == "Video": | |
| uploaded_file = st.file_uploader("Upload a video for analysis", type=["mp4", "mov", "avi"]) | |
| if uploaded_file: | |
| tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') | |
| tfile.write(uploaded_file.read()) | |
| cap = cv2.VideoCapture(tfile.name) | |
| frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| st.info(f"Video has {frame_count} frames. This will be slow but high-quality.") | |
| if st.button("Process Video", type="primary", use_container_width=True): | |
| progress_bar = st.progress(0, text="Initializing...") | |
| out_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') | |
| h, w = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| out = cv2.VideoWriter(out_tfile.name, cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), (w, h)) | |
| for i in range(frame_count): | |
| ret, frame = cap.read() | |
| if not ret: break | |
| # --- THE FIX: Call the HQ method --- | |
| annotated_frame_rgb, _ = pipeline.predict_hq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) | |
| out.write(cv2.cvtColor(annotated_frame_rgb, cv2.COLOR_RGB2BGR)) | |
| progress_bar.progress((i + 1) / frame_count, text=f"Processing Frame {i+1}/{frame_count}") | |
| cap.release(), out.release() | |
| st.success("Video processing complete!") | |
| st.video(out_tfile.name) | |
| with open(out_tfile.name, "rb") as f: | |
| st.download_button("Download Processed Video", f, "output.mp4", "video/mp4", use_container_width=True) | |
| elif app_mode == "Live Feed": | |
| st.info("Live feed uses a lightweight face detector for performance.") | |
| col1, col2 = st.columns(2) | |
| with col1: st.button("Start Feed", on_click=start_webcam, use_container_width=True, type="primary") | |
| with col2: st.button("Stop Feed", on_click=stop_webcam, use_container_width=True) | |
| _, center_col, _ = st.columns([1, 2, 1]) | |
| with center_col: | |
| FRAME_WINDOW = st.image([]) | |
| fps_display = st.empty() | |
| if st.session_state.webcam_running: | |
| cap = cv2.VideoCapture(0) | |
| while st.session_state.webcam_running: | |
| start_time = time.time() | |
| ret, frame = cap.read() | |
| if not ret: | |
| st.warning("Could not read frame from webcam. Stopping.") | |
| stop_webcam() | |
| break | |
| frame = cv2.flip(frame, 1) | |
| annotated_frame, _ = pipeline.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) | |
| FRAME_WINDOW.image(annotated_frame, channels="RGB") | |
| fps = 1.0 / (time.time() - start_time) if (time.time() - start_time) > 0 else 0 | |
| fps_display.markdown(f"<p style='text-align: center;'><b>FPS: {fps:.2f}</b></p>", unsafe_allow_html=True) | |
| cap.release() | |
| # --- THE FIX --- | |
| # cv2.destroyAllWindows() # This line is removed | |
| # --- END FIX --- | |
| if st.session_state.webcam_running: | |
| st.session_state.webcam_running = False | |
| st.rerun() | |