# Origin (captured from the hosting page): author ALYYAN, commit eacd6a2
# "Backend + Frontend done", raw file size 5.86 kB.
import logging
import os
import sys
import tempfile
import time

import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from streamlit_option_menu import option_menu
# --- Page config ---
# Called once, at the top of the script, before any other Streamlit output.
# The icon literal was mis-encoded ("πŸ‘€"); it is the "eyes" emoji.
st.set_page_config(
    page_title="Facial Analysis",
    page_icon="👀",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Backend loading ---
# Make the project's ``src`` directory importable and load the prediction
# pipeline class; the app cannot do anything without it, so stop the script
# if the import fails.
try:
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
    if src_path not in sys.path:
        sys.path.append(src_path)
    from cnnClassifier.pipeline.prediction import PredictionPipeline
except ImportError as exc:
    st.error("FATAL: Prediction pipeline not found. Check project structure.")
    # The ImportError may originate from a missing third-party dependency
    # inside the package rather than from the package itself, so show the
    # real cause instead of only the generic message.
    st.exception(exc)
    st.stop()
# --- GPU setup (best effort) ---
# Ask TensorFlow to allocate GPU memory on demand rather than reserving it all
# up front. set_memory_growth raises RuntimeError once the TF runtime has been
# initialized, which is expected when Streamlit re-executes this script after
# the model was loaded, and ValueError for an invalid device; both are
# harmless here, so only those are ignored.
try:
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except (RuntimeError, ValueError):
    pass
@st.cache_resource
def load_pipeline():
    """Construct the shared ``PredictionPipeline``.

    ``st.cache_resource`` keeps the returned object across reruns and sessions,
    so the model is built only once per server process. If construction
    raises, nothing is cached and the next rerun tries again.
    """
    return PredictionPipeline()


# Load outside the cached function so that a failure is logged to the
# terminal and turned into ``None`` (letting the page show its "failed to
# load" message) instead of crashing the script, and is not memoized.
try:
    pipeline = load_pipeline()
except Exception:
    logging.getLogger(__name__).exception("Failed to load the prediction pipeline")
    pipeline = None
# --- Webcam control state ---
# ``webcam_running`` survives reruns; the Start/Stop buttons flip it through
# the callbacks below.
if "webcam_running" not in st.session_state:
    st.session_state["webcam_running"] = False


def start_webcam():
    """Button callback: mark the live feed as running."""
    st.session_state["webcam_running"] = True


def stop_webcam():
    """Button callback: mark the live feed as stopped."""
    st.session_state["webcam_running"] = False
# --- Sidebar: mode selection ---
# ``app_mode`` receives the selected option label ("Image", "Video" or
# "Live Feed"), which the main page dispatches on.
# The heading emoji was mis-encoded ("βš™οΈ"); it is the gear emoji.
with st.sidebar:
    st.markdown("## ⚙️ Controls")
    app_mode = option_menu(
        menu_title=None,
        options=["Image", "Video", "Live Feed"],
        icons=["image", "film", "camera-video"],
        menu_icon="cast",
        default_index=0,
    )
    st.divider()
    st.info("This app uses a multi-task EfficientNet model to predict age and gender.")
# --- Main page header ---
# Title emoji repaired from its mis-encoded form ("πŸ‘€"); the title has no
# placeholders, so it does not need to be an f-string.
st.title("👀 Facial Demographics Analysis")
st.markdown(f"### Mode: {app_mode}")
st.divider()
# --- Mode renderers ---
# Streamlit re-executes this whole script on every widget interaction, so each
# renderer releases what it opens (captures, writers, temp files) on every run.


def _remove_quietly(path):
    """Delete a temporary file, ignoring OS errors (a failure only leaks disk)."""
    try:
        os.remove(path)
    except OSError:
        pass


def _render_image_mode(pipeline):
    """Annotate one uploaded image and list the per-face predictions."""
    uploaded_file = st.file_uploader("Upload an image for analysis", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        return
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except Image.UnidentifiedImageError:
        st.error("Could not read the uploaded file as an image.")
        return
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_column_width=True)
    with col2:
        with st.spinner("🔬 Analyzing..."):
            annotated_image, predictions = pipeline.predict_image(np.array(image))
        st.image(annotated_image, caption="Processed Image", use_column_width=True)
        if predictions:
            with st.expander("View Details", expanded=True):
                for i, p in enumerate(predictions, start=1):
                    st.write(f"**Face {i}:** Gender: `{p['gender']}`, Age Group: `{p['age']}`")
        else:
            st.warning("No faces detected.")


def _process_video(pipeline, cap, frame_count):
    """Annotate every frame read from ``cap``, then show and offer the result.

    Args:
        pipeline: object whose ``process_video_stream`` takes an iterable of
            RGB frames and yields annotated RGB frames (as used here).
        cap: an opened ``cv2.VideoCapture`` positioned at the first frame.
        frame_count: frame count reported by the container; it may be 0 or
            inaccurate, so it only drives the progress bar.
    """
    progress_bar = st.progress(0, text="Initializing...")
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # 0 or NaN from some containers; VideoWriter needs a positive rate
        fps = 30.0  # assumed fallback rate

    def frame_generator():
        # Read until the decoder runs out rather than trusting frame_count.
        while True:
            ok, frame = cap.read()
            if not ok:
                return
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as out_tfile:
        out_path = out_tfile.name
    try:
        # NOTE(review): 'mp4v' output is often not playable in browsers;
        # H.264 ('avc1') may be needed if the OpenCV build supports it.
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        processed = 0
        try:
            for i, annotated_frame_rgb in enumerate(
                pipeline.process_video_stream(frame_generator()), start=1
            ):
                out.write(cv2.cvtColor(annotated_frame_rgb, cv2.COLOR_RGB2BGR))
                processed = i
                if frame_count > 0:
                    progress_bar.progress(
                        min(i / frame_count, 1.0),
                        text=f"Processing Frame {i}/{frame_count}",
                    )
                else:
                    progress_bar.progress(0, text=f"Processing Frame {i}")
        finally:
            out.release()
        if processed == 0:
            st.error("No frames could be read from the uploaded video.")
            return
        # Read the result into memory so the temp file can be deleted now.
        with open(out_path, "rb") as f:
            video_bytes = f.read()
    finally:
        _remove_quietly(out_path)
    st.success("Video processing complete!")
    st.video(video_bytes, format="video/mp4")
    st.download_button(
        "Download Processed Video", video_bytes, "output.mp4", "video/mp4", use_container_width=True
    )


def _render_video_mode(pipeline):
    """Accept a video upload and annotate it when the user asks."""
    uploaded_file = st.file_uploader("Upload a video for analysis", type=["mp4", "mov", "avi"])
    if not uploaded_file:
        return
    suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
    # Closing the temp file flushes it before OpenCV opens it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        tfile.write(uploaded_file.getbuffer())
    try:
        cap = cv2.VideoCapture(tfile.name)
        try:
            if not cap.isOpened():
                st.error("Could not open the uploaded video.")
                return
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            st.info(f"Video has {frame_count} frames.")
            if st.button("Start Video Processing", type="primary", use_container_width=True):
                _process_video(pipeline, cap, frame_count)
        finally:
            cap.release()  # before deleting the file (Windows keeps it locked)
    finally:
        _remove_quietly(tfile.name)


def _render_live_feed(pipeline):
    """Stream annotated webcam frames while the feed is switched on."""
    col1, col2 = st.columns(2)
    with col1:
        st.button("Start Feed", on_click=start_webcam, use_container_width=True, type="primary")
    with col2:
        st.button("Stop Feed", on_click=stop_webcam, use_container_width=True)
    _, center_col, _ = st.columns([1, 2, 1])
    with center_col:
        frame_window = st.empty()
        fps_display = st.empty()
    if not st.session_state.webcam_running:
        return

    cap = cv2.VideoCapture(0)
    # When a button press triggers a rerun, Streamlit stops this run by raising
    # inside the script thread, so code after the loop is not guaranteed to
    # execute; the camera is therefore released in ``finally``.
    try:
        if not cap.isOpened():
            st.session_state.webcam_running = False
            st.error("Could not open the webcam.")
            return
        while st.session_state.webcam_running:
            start_time = time.perf_counter()
            ok, frame = cap.read()
            if not ok:
                st.session_state.webcam_running = False
                st.warning("The webcam stopped delivering frames.")
                return
            frame = cv2.flip(frame, 1)  # horizontal flip: mirror-style preview
            annotated_frame = pipeline.process_live_frame(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_window.image(annotated_frame, channels="RGB")
            elapsed = time.perf_counter() - start_time
            fps = 1.0 / elapsed if elapsed > 0 else 0
            fps_display.markdown(
                f"<p style='text-align: center;'><b>FPS: {fps:.2f}</b></p>",
                unsafe_allow_html=True,
            )
    finally:
        cap.release()
    # No cv2.destroyAllWindows(): no OpenCV windows are created here, and on
    # headless OpenCV builds that call can raise.
    st.session_state.webcam_running = False
    st.rerun()


# --- Dispatch on the selected mode ---
if not pipeline:
    st.error("AI Pipeline failed to load. Please check the terminal for errors.")
elif app_mode == "Image":
    _render_image_mode(pipeline)
elif app_mode == "Video":
    _render_video_mode(pipeline)
elif app_mode == "Live Feed":
    _render_live_feed(pipeline)