Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| import pickle | |
| import torch | |
| import time | |
| import pandas as pd | |
| import sys | |
| import os | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
# Make the bundled YOLOv7 repository importable. It is inserted at the FRONT
# of sys.path so that its top-level `models` / `utils` packages take
# precedence; appended at the end, any already-installed package named
# `utils` or `models` would be found first and shadow the repo's modules.
# NOTE(review): the path is resolved against the current working directory,
# so the app must be launched from the directory that contains `yolov7/`
# (the model files loaded later are cwd-relative as well).
sys.path.insert(0, os.path.join(os.getcwd(), 'yolov7'))
from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_coords
from utils.datasets import letterbox
# Load the pickled ML model. The context manager closes the file handle
# once unpickling is done.
# NOTE(review): pickle.load can execute arbitrary code -- only ship a
# model.pkl you produced yourself. `model` is not used anywhere in the rest
# of this script; presumably it was meant to classify each detection as
# attentive / inattentive -- confirm.
with open('model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

# Load the YOLOv7 detector on the GPU when one is available, else the CPU.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so both models are reloaded each time the checkbox is
# toggled; consider st.cache_resource if the installed Streamlit has it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
yolo_model = attempt_load('yolov7.pt', map_location=device)
yolo_model.eval()  # switch the network to inference mode
# Streamlit UI setup. set_page_config must be the first Streamlit command.
st.set_page_config(page_title="Multi-Face Attention Detector", layout='wide')
# NOTE(review): the emoji at the start of this title looks mis-encoded
# (UTF-8 bytes rendered through a Greek code page); restore the intended glyph.
st.title("π₯ Real-Time Multi-Face Attention Detector")
run = st.checkbox('Start Webcam')
FRAME_WINDOW = st.image([])  # placeholder, updated in place with each frame

# Streamlit reruns this script from the top on every interaction, including
# the rerun triggered by unchecking "Start Webcam". A plain list would be
# reset to [] on that rerun, so the dashboard could never display what was
# collected. The log is therefore kept in session state and only cleared
# when a new webcam run starts.
if run or 'attention_log' not in st.session_state:
    st.session_state['attention_log'] = []
attention_log = st.session_state['attention_log']
start_time = time.time()
def _frame_to_tensor(frame, img_size=640):
    """Convert a BGR webcam frame into a batched YOLOv7 input tensor.

    The frame is letterboxed to ``img_size``, converted BGR -> RGB and
    HWC -> CHW, scaled from [0, 255] to [0, 1], and given a leading batch axis.

    Args:
        frame: image from ``cv2.VideoCapture.read`` (H x W x 3, BGR).
        img_size: target size passed to ``letterbox``.

    Returns:
        A float tensor on ``device`` of shape (1, 3, h, w).
    """
    img = letterbox(frame, new_shape=img_size)[0]
    # Reverse the channel axis (BGR -> RGB), then move channels first.
    img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
    tensor = torch.from_numpy(img).to(device).float() / 255.0
    if tensor.ndimension() == 3:
        tensor = tensor.unsqueeze(0)
    return tensor


def _draw_detections(frame, det, input_hw):
    """Draw one image's post-NMS detections onto ``frame`` in place.

    Args:
        frame: BGR frame to annotate (modified in place).
        det: detection tensor for this image. Columns 0-3 are box corners in
            model-input coordinates, and the last two columns are confidence
            and class index (as unpacked below). The box columns are rescaled
            in place to ``frame`` coordinates.
        input_hw: (height, width) of the letterboxed model input.
    """
    det[:, :4] = scale_coords(input_hw, det[:, :4], frame.shape).round()
    for *xyxy, conf, cls in reversed(det):
        x1, y1, x2, y2 = (int(v) for v in xyxy[:4])
        label = f'{int(cls)} {conf:.2f}'
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)


if run:
    cap = cv2.VideoCapture(0)
    try:
        # cap.read() reports failure both when the device could not be opened
        # and when a grab fails, so this single check covers both cases.
        while True:
            ret, frame = cap.read()
            if not ret:
                st.warning("⚠️ Cannot access webcam.")
                break

            img = _frame_to_tensor(frame)
            # Pure inference: no_grad avoids building an autograd graph per frame.
            with torch.no_grad():
                pred = yolo_model(img)[0]
            pred = non_max_suppression(pred, 0.25, 0.45, classes=None, agnostic=False)

            # One entry per image in the batch (a batch of 1 here).
            for det in pred:
                if len(det):
                    _draw_detections(frame, det, img.shape[2:])
            # TODO(review): the pickled `model` is never applied and nothing is
            # appended to `attention_log` in the visible code, so the dashboard
            # below never receives data. Per-detection attention classification
            # and logging appear to be missing.

            FRAME_WINDOW.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # There is no key polling: the user stops by unchecking "Start Webcam",
            # which makes Streamlit interrupt and rerun this script.
            # cv2.waitKey only sees keys sent to an OpenCV window, and none is
            # created here (on headless OpenCV builds it may even raise).
    finally:
        # Runs even when Streamlit interrupts the loop, so the camera is
        # always released for the next run.
        cap.release()
# Summarise the attention log for the dashboard.
# NOTE(review): nothing in the visible script appends to `attention_log`; the
# 'state' and 'time' keys used below are assumed record fields -- confirm them
# against whatever code produces the log.
if attention_log:
    df = pd.DataFrame(attention_log)
    is_attentive = df['state'] == 'Attentive'
    attentive = int(is_attentive.sum())
    inattentive = int((df['state'] == 'Inattentive').sum())
    # NOTE(review): the leading emoji in the next two strings look mis-encoded
    # (UTF-8 bytes rendered through a Greek code page); restore the intended glyphs.
    st.markdown("### π Attention Statistics")
    st.write(f"β Attentive detections: {attentive}")
    st.write(f"⚠️ Inattentive detections: {inattentive}")
    st.dataframe(df.tail(10))
    # Fraction of 'Attentive' records for each distinct 'time' value.
    st.line_chart(is_attentive.groupby(df['time']).mean())
    # An explicit CSV MIME type; str data would otherwise be served as text/plain.
    st.download_button("Download Log as CSV", df.to_csv(index=False),
                       file_name="attention_log.csv", mime="text/csv")