# JAI / app.py
# Author: Deevyankar — last commit a0d17ed ("Update app.py", verified)
import streamlit as st
import cv2
import numpy as np
import pickle
import torch
import time
import pandas as pd
import sys
import os
import matplotlib.pyplot as plt
import seaborn as sns
# Add YOLOv7 repository to the system path
sys.path.append(os.path.join(os.getcwd(), 'yolov7'))
from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_coords
from utils.datasets import letterbox
# Load ML model
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# ever load a trusted model.pkl that ships with this app.
# NOTE(review): `model` is not referenced anywhere else in this script —
# presumably it was meant to classify detections as Attentive/Inattentive; confirm.
with open('model.pkl', 'rb') as model_file:  # `with` closes the handle (original leaked it)
    model = pickle.load(model_file)
# Load YOLOv7 model on the GPU when available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
yolo_model = attempt_load('yolov7.pt', map_location=device)
yolo_model.eval()  # inference mode: disables dropout / batch-norm statistic updates
# Streamlit UI setup
st.set_page_config(page_title="Multi-Face Attention Detector", layout='wide')
# Title emoji repaired: it was mis-encoded as "πŸŽ₯".
st.title("🎥 Real-Time Multi-Face Attention Detector")
run = st.checkbox('Start Webcam')
# Empty image placeholder; the capture loop overwrites it with each new frame.
FRAME_WINDOW = st.image([])
# Rows for the statistics dashboard; the dashboard code reads 'time' and 'state'
# keys from each row.
# NOTE(review): nothing in this script appends to attention_log, so the
# dashboard never renders — confirm where rows are supposed to come from.
attention_log = []
start_time = time.time()  # NOTE(review): not read anywhere in this script
def _preprocess_frame(frame):
    """Convert a BGR webcam frame into a normalized YOLOv7 input batch.

    Letterboxes ``frame`` to 640 px, flips BGR -> RGB, reorders HWC -> CHW,
    scales pixel values into [0, 1] and adds a leading batch dimension.

    Returns:
        A float tensor on ``device`` ready to be passed to ``yolo_model``.
    """
    img = letterbox(frame, new_shape=640)[0]
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img)  # torch.from_numpy rejects negative strides
    img = torch.from_numpy(img).to(device).float() / 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    return img


def _draw_detections(frame, det, model_input_shape):
    """Rescale ``det`` boxes to ``frame`` coordinates and draw them in place.

    Args:
        frame: original BGR frame; annotated in place.
        det: NMS output for one image; columns 0-3 are box corners, followed
            by confidence and class id. Its box columns are overwritten with
            the rescaled coordinates.
        model_input_shape: spatial (H, W) of the letterboxed model input.
    """
    det[:, :4] = scale_coords(model_input_shape, det[:, :4], frame.shape).round()
    for *xyxy, conf, cls in reversed(det):
        x1, y1, x2, y2 = (int(v) for v in xyxy)
        label = f'{int(cls)} {conf:.2f}'
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Clamp so labels of boxes touching the top edge stay inside the frame.
        cv2.putText(frame, label, (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)


if run:
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        # Without this the loop below silently never runs.
        st.warning("⚠️ Cannot access webcam.")
    try:
        # The loop ends when the camera stops delivering frames, or when
        # Streamlit reruns the script (e.g. the checkbox is unticked). There is
        # no cv2.waitKey: a browser app has no OpenCV window to receive keys.
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                st.warning("⚠️ Cannot access webcam.")
                break
            img = _preprocess_frame(frame)
            # Inference only: no_grad avoids building an autograd graph per frame.
            with torch.no_grad():
                pred = yolo_model(img)[0]
            # NOTE(review): classes=None keeps every class the YOLOv7 weights
            # predict, not only faces/persons — confirm this is intended.
            pred = non_max_suppression(pred, 0.25, 0.45, classes=None, agnostic=False)
            for det in pred:
                if len(det):
                    _draw_detections(frame, det, img.shape[2:])
            FRAME_WINDOW.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the camera on every exit path, including when a Streamlit
        # rerun interrupts the loop.
        cap.release()
# Process log for dashboard: summary counts, the latest rows, a per-timestamp
# attentiveness chart and a CSV export. Only rendered when the log has rows.
if attention_log:
    df = pd.DataFrame(attention_log)
    # One boolean per logged row; assumes every row carries a 'state' key.
    is_attentive = df['state'] == 'Attentive'
    attentive = int(is_attentive.sum())
    inattentive = int((df['state'] == 'Inattentive').sum())
    # Emoji repaired: they were mis-encoded as "πŸ“Š" and "βœ…".
    st.markdown("### 📊 Attention Statistics")
    st.write(f"✅ Attentive detections: {attentive}")
    st.write(f"⚠️ Inattentive detections: {inattentive}")
    st.dataframe(df.tail(10))
    # Fraction of rows marked Attentive at each distinct 'time' value.
    st.line_chart(is_attentive.groupby(df['time']).mean())
    st.download_button("Download Log as CSV", df.to_csv(index=False),
                       file_name="attention_log.csv", mime="text/csv")