# FEAT: Finalize code for Hugging Face deployment (commit 1f4e421, by ALYYAN)
# NOTE(review): the lines above were Hugging Face file-viewer chrome ("raw /
# history / blame", file size) captured along with the source; converted to a
# comment so the file parses as Python.
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import tensorflow as tf
import sys
import os
import tempfile
import time
from streamlit_option_menu import option_menu
# --- Page Config ---
# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Facial Analysis",
    page_icon="👤",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Path Setup & Model Loading ---
try:
    # Local development layout: make ./src importable, then import the
    # prediction pipeline through the package path used in development.
    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
    if src_path not in sys.path: sys.path.append(src_path)
    from cnnClassifier.pipeline.prediction import PredictionPipeline
except ImportError:
    # Fallback for Hugging Face Spaces, where the app runs from the repo root
    # and the same package resolves only under the explicit `src.` prefix.
    from src.cnnClassifier.pipeline.prediction import PredictionPipeline
# --- TF Config (for MTCNN in Image/Video modes) ---
# Enable on-demand GPU memory growth so TensorFlow does not grab all VRAM
# up front; best-effort only, so any failure is ignored.
try:
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
except Exception:
    pass
@st.cache_resource
def load_pipeline():
    """Build the PredictionPipeline once and cache it for the app's lifetime.

    Returns:
        The pipeline instance, or ``None`` when construction fails.

    FIX: the original let construction errors propagate as a raw Streamlit
    traceback, so the ``if not pipeline:`` guard further down could never
    fire. Catching here makes that guard meaningful.
    """
    try:
        return PredictionPipeline()
    except Exception as exc:  # e.g. missing/corrupt model weights
        st.error(f"Failed to load the AI pipeline: {exc}")
        return None
pipeline = load_pipeline()
# Persist the live-feed on/off flag across Streamlit reruns.
if 'webcam_running' not in st.session_state:
    st.session_state.webcam_running = False

def start_webcam():
    """Button callback: flag the live-feed loop to run."""
    st.session_state.webcam_running = True

def stop_webcam():
    """Button callback: flag the live-feed loop to stop."""
    st.session_state.webcam_running = False
# --- UI ---
with st.sidebar:
    st.markdown("## ⚙️ Controls")
    # Mode selector drives which of the three panels below is rendered.
    app_mode = option_menu(
        None,
        ["Image", "Video", "Live Feed"],
        icons=['image', 'film', 'camera-video'],
        menu_icon="cast",
        default_index=0,
    )
# Main panel: render the mode selected in the sidebar. Relies on
# load_pipeline() returning a falsy value when the model failed to load.
if not pipeline:
    st.error("AI Pipeline failed to load. Please check the terminal for errors.")
else:
    st.title("👤 Facial Demographics Analysis")
    st.header(f"Mode: {app_mode}")
    st.divider()

    if app_mode == "Image":
        uploaded_file = st.file_uploader("Upload an image for analysis", type=["jpg", "jpeg", "png"])
        if uploaded_file:
            image = Image.open(uploaded_file).convert("RGB")
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption='Original Image', use_column_width=True)
            with col2:
                with st.spinner('🔬 Analyzing with high-quality detector...'):
                    # High-quality (slower) detection path for still images.
                    annotated_image, predictions = pipeline.predict_hq(np.array(image))
                st.image(annotated_image, caption='Processed Image', use_column_width=True)
                if predictions:
                    with st.expander("View Details", expanded=True):
                        for i, p in enumerate(predictions):
                            st.write(f"**Face {i+1}:** Gender: `{p['gender']}`, Age Group: `{p['age']}`")
                else:
                    st.warning("No faces detected.")

    elif app_mode == "Video":
        uploaded_file = st.file_uploader("Upload a video for analysis", type=["mp4", "mov", "avi"])
        if uploaded_file:
            tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            tfile.write(uploaded_file.read())
            # FIX: close the handle so buffered bytes are flushed to disk
            # before OpenCV opens the file by name (mandatory on Windows,
            # where an open NamedTemporaryFile cannot be reopened).
            tfile.close()
            cap = cv2.VideoCapture(tfile.name)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            st.info(f"Video has {frame_count} frames. This will be slow but high-quality.")
            if st.button("Process Video", type="primary", use_container_width=True):
                progress_bar = st.progress(0, text="Initializing...")
                out_tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
                out_tfile.close()  # cv2.VideoWriter reopens the path itself
                h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                out = cv2.VideoWriter(out_tfile.name, cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), (w, h))
                # FIX: guard the progress division against a zero frame count
                # reported by a container with a bad/absent header.
                total = max(frame_count, 1)
                for i in range(frame_count):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # Pipeline works in RGB; OpenCV decodes/encodes BGR.
                    annotated_frame_rgb, _ = pipeline.predict_hq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    out.write(cv2.cvtColor(annotated_frame_rgb, cv2.COLOR_RGB2BGR))
                    progress_bar.progress((i + 1) / total, text=f"Processing Frame {i+1}/{frame_count}")
                cap.release()
                out.release()
                os.unlink(tfile.name)  # FIX: input temp file was never cleaned up
                st.success("Video processing complete!")
                st.video(out_tfile.name)
                with open(out_tfile.name, "rb") as f:
                    st.download_button("Download Processed Video", f, "output.mp4", "video/mp4", use_container_width=True)

    elif app_mode == "Live Feed":
        st.info("Live feed uses a lightweight face detector for higher FPS.")
        col1, col2 = st.columns(2)
        with col1:
            st.button("Start Feed", on_click=start_webcam, use_container_width=True, type="primary")
        with col2:
            st.button("Stop Feed", on_click=stop_webcam, use_container_width=True)
        _, center_col, _ = st.columns([1, 2, 1])
        with center_col:
            FRAME_WINDOW = st.image([])
            fps_display = st.empty()
        if st.session_state.webcam_running:
            cap = cv2.VideoCapture(0)
            # FIX: release the camera even if a frame raises mid-loop.
            try:
                while st.session_state.webcam_running:
                    start_time = time.time()
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame = cv2.flip(frame, 1)  # mirror for a natural selfie view
                    # Lightweight (faster) detection path for live video.
                    annotated_frame, _ = pipeline.predict_lq(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    FRAME_WINDOW.image(annotated_frame, channels="RGB")
                    # FIX: sample the clock once so the zero-guard and the
                    # displayed FPS use the same elapsed value.
                    elapsed = time.time() - start_time
                    fps = 1.0 / elapsed if elapsed > 0 else 0
                    fps_display.markdown(f"<p style='text-align: center;'><b>FPS: {fps:.2f}</b></p>", unsafe_allow_html=True)
            finally:
                cap.release()
                cv2.destroyAllWindows()
            st.session_state.webcam_running = False
            st.rerun()