File size: 5,632 Bytes
bdb70cc
 
 
 
eacd6a2
 
 
 
 
bdb70cc
38d1816
bdb70cc
 
38d1816
 
eacd6a2
 
 
1f4e421
eacd6a2
 
8cf2c66
eacd6a2
 
 
 
 
 
 
 
1f4e421
 
eacd6a2
 
8cf2c66
eacd6a2
1f4e421
 
 
 
eacd6a2
 
 
 
 
 
 
1f4e421
 
 
eacd6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
1f4e421
 
eacd6a2
 
 
 
1f4e421
 
 
 
 
eacd6a2
 
 
 
 
 
 
 
 
8cf2c66
eacd6a2
 
 
 
 
 
 
 
 
 
 
 
8cf2c66
 
 
 
eacd6a2
8cf2c66
eacd6a2
 
 
8cf2c66
eacd6a2
8cf2c66
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import sys
import os
import tempfile
import time
from streamlit_option_menu import option_menu

# Configure the Streamlit page; this must be the first st.* call in the script.
st.set_page_config(page_title="Facial Analysis", page_icon="👤", layout="wide")

# Import the prediction pipeline. First try the package path used when the app
# is launched from the repository root; if that fails, append the local ./src
# directory to sys.path and import the module directly (covers running the
# script from inside the project folder).
try:
    from src.cnnClassifier.pipeline.prediction import PredictionPipeline
except ImportError:
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
    if src_path not in sys.path: sys.path.append(src_path)
    from cnnClassifier.pipeline.prediction import PredictionPipeline

@st.cache_resource
def load_pipeline():
    """Build the PredictionPipeline once per server process (cached by Streamlit).

    Returns:
        The PredictionPipeline instance, or None when construction fails
        (e.g. missing model weights).

    Returning None instead of letting the exception propagate makes the
    `if not pipeline:` guard below actually reachable — previously a
    constructor failure crashed the script before the friendly error message
    could ever render.
    """
    try:
        return PredictionPipeline()
    except Exception as exc:  # top-level boundary: log and degrade gracefully
        print(f"Failed to initialise PredictionPipeline: {exc}", file=sys.stderr)
        return None

pipeline = load_pipeline()

# Persist the live-feed on/off flag across Streamlit reruns.
if "webcam_running" not in st.session_state:
    st.session_state["webcam_running"] = False

def start_webcam():
    """Button callback: mark the webcam loop as active."""
    st.session_state["webcam_running"] = True

def stop_webcam():
    """Button callback: signal the webcam loop to exit."""
    st.session_state["webcam_running"] = False

# Sidebar navigation: the selected entry drives which mode renders below.
with st.sidebar:
    st.markdown("## ⚙️ Controls")
    app_mode = option_menu(
        menu_title=None,
        options=["Image", "Video", "Live Feed"],
        icons=["image", "film", "camera-video"],
        menu_icon="cast",
        default_index=0,
    )

def _render_image_mode():
    """Image mode: analyse a single uploaded photo with the high-quality detector."""
    uploaded_file = st.file_uploader("Upload an image for analysis", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        return
    image = Image.open(uploaded_file).convert("RGB")
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption='Original Image', use_column_width=True)
    with col2:
        with st.spinner('🔬 Analyzing with high-quality detector...'):
            # HQ (slow) detector is fine here: only one frame to process.
            annotated_image, predictions = pipeline.predict_hq(np.array(image))
        st.image(annotated_image, caption='Processed Image', use_column_width=True)
        if predictions:
            with st.expander("View Details", expanded=True):
                for i, p in enumerate(predictions):
                    st.write(f"**Face {i+1}:** Gender: `{p['gender']}`, Age Group: `{p['age']}`")
        else:
            st.warning("No faces detected.")


def _render_video_mode():
    """Video mode: run the HQ detector frame-by-frame over an uploaded clip."""
    uploaded_file = st.file_uploader("Upload a video for analysis", type=["mp4", "mov", "avi"])
    if not uploaded_file:
        return
    # Persist the upload to disk: cv2.VideoCapture needs a real file path.
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    tfile.write(uploaded_file.read())
    tfile.close()  # close before cv2 opens it (required on Windows)
    cap = cv2.VideoCapture(tfile.name)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        st.info(f"Video has {frame_count} frames. This will be slow but high-quality.")
        if st.button("Process Video", type="primary", use_container_width=True):
            progress_bar = st.progress(0, text="Initializing...")
            # Output file stays on disk (delete=False): st.video serves it by path.
            out_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            h, w = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            out = cv2.VideoWriter(out_tfile.name, cv2.VideoWriter_fourcc(*'mp4v'),
                                  cap.get(cv2.CAP_PROP_FPS), (w, h))
            try:
                for i in range(frame_count):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # OpenCV delivers BGR; the pipeline works in RGB.
                    annotated_frame_rgb, _ = pipeline.predict_hq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    out.write(cv2.cvtColor(annotated_frame_rgb, cv2.COLOR_RGB2BGR))
                    progress_bar.progress((i + 1) / frame_count, text=f"Processing Frame {i+1}/{frame_count}")
            finally:
                out.release()
            st.success("Video processing complete!")
            st.video(out_tfile.name)
            with open(out_tfile.name, "rb") as f:
                st.download_button("Download Processed Video", f, "output.mp4", "video/mp4",
                                   use_container_width=True)
    finally:
        # Always release the capture (previously leaked on every rerun where
        # the button was not pressed) and remove the uploaded temp file, which
        # is recreated on each rerun anyway.
        cap.release()
        try:
            os.remove(tfile.name)
        except OSError:
            pass


def _render_live_mode():
    """Live Feed mode: stream webcam frames through the fast detector."""
    st.info("Live feed uses a lightweight face detector for performance.")
    col1, col2 = st.columns(2)
    with col1:
        st.button("Start Feed", on_click=start_webcam, use_container_width=True, type="primary")
    with col2:
        st.button("Stop Feed", on_click=stop_webcam, use_container_width=True)
    _, center_col, _ = st.columns([1, 2, 1])
    with center_col:
        FRAME_WINDOW = st.image([])
        fps_display = st.empty()
    if not st.session_state.webcam_running:
        return
    cap = cv2.VideoCapture(0)
    try:
        while st.session_state.webcam_running:
            start_time = time.time()
            ret, frame = cap.read()
            if not ret:
                st.warning("Could not read frame from webcam. Stopping.")
                stop_webcam()
                break
            frame = cv2.flip(frame, 1)  # mirror for a natural selfie view
            annotated_frame, _ = pipeline.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            FRAME_WINDOW.image(annotated_frame, channels="RGB")
            # Sample the clock once: the original called time.time() twice, so
            # the zero-guard and the divisor could disagree.
            elapsed = time.time() - start_time
            fps = 1.0 / elapsed if elapsed > 0 else 0
            fps_display.markdown(f"<p style='text-align: center;'><b>FPS: {fps:.2f}</b></p>",
                                 unsafe_allow_html=True)
    finally:
        cap.release()
    # Loop exited via the Stop callback or a read failure; normalise the flag
    # and rerun so the UI redraws in its idle state.
    if st.session_state.webcam_running:
        st.session_state.webcam_running = False
        st.rerun()


if not pipeline:
    st.error("AI Pipeline failed to load. Check terminal logs.")
else:
    st.title("👤 Facial Demographics Analysis")
    st.header(f"Mode: {app_mode}")
    st.divider()
    # Dispatch to one renderer per sidebar mode.
    if app_mode == "Image":
        _render_image_mode()
    elif app_mode == "Video":
        _render_video_mode()
    elif app_mode == "Live Feed":
        _render_live_mode()