File size: 2,860 Bytes
53b616c
 
 
 
 
 
 
752e7ef
 
6afd62b
a0d17ed
752e7ef
 
 
 
 
 
 
53b616c
 
 
 
752e7ef
 
 
 
53b616c
 
752e7ef
 
2460901
752e7ef
53b616c
 
 
 
 
 
 
 
 
 
 
 
752e7ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53b616c
 
 
 
 
 
 
 
752e7ef
53b616c
 
 
 
 
752e7ef
 
53b616c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
import cv2
import numpy as np
import pickle
import torch
import time
import pandas as pd
import sys
import os
import matplotlib.pyplot as plt
import seaborn as sns

# Add YOLOv7 repository to the system path
sys.path.append(os.path.join(os.getcwd(), 'yolov7'))

from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_coords
from utils.datasets import letterbox

# Load the attention-classification model.
# SECURITY: pickle.load executes arbitrary code from the file — only load
# model.pkl from a trusted source.
with open('model.pkl', 'rb') as f:  # context manager: don't leak the file handle
    model = pickle.load(f)

# Load the YOLOv7 detector on GPU when available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
yolo_model = attempt_load('yolov7.pt', map_location=device)
yolo_model.eval()  # inference mode: freezes dropout / batch-norm statistics

# Streamlit UI setup
st.set_page_config(page_title="Multi-Face Attention Detector", layout='wide')
st.title("πŸŽ₯ Real-Time Multi-Face Attention Detector")
run = st.checkbox('Start Webcam')  # gates the capture loop below

FRAME_WINDOW = st.image([])  # placeholder repainted with each processed frame
attention_log = []  # per-detection records consumed by the stats dashboard
# NOTE(review): attention_log is never appended to anywhere in this file,
# so the dashboard section at the bottom never renders — confirm intended.
start_time = time.time()  # session start; currently unused — confirm needed

if run:
    cap = cv2.VideoCapture(0)  # default webcam
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                st.warning("⚠️ Cannot access webcam.")
                break

            # Preprocess: letterbox to 640px, BGR->RGB, HWC->CHW, scale to [0, 1].
            img = letterbox(frame, new_shape=640)[0]
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(device)
            img = img.float()
            img /= 255.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)  # add batch dimension

            # Inference — no_grad avoids building an autograd graph per frame
            # (pure inference; saves memory and time).
            with torch.no_grad():
                pred = yolo_model(img)[0]
            pred = non_max_suppression(pred, 0.25, 0.45, classes=None, agnostic=False)

            # Draw detections back onto the original-resolution frame.
            for det in pred:
                if len(det):
                    # Rescale boxes from the letterboxed size to the frame size.
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], frame.shape).round()
                    for *xyxy, conf, cls in reversed(det):
                        x1, y1, x2, y2 = (int(v) for v in xyxy)
                        label = f'{int(cls)} {conf:.2f}'
                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.putText(frame, label, (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

            # TODO(review): the pickled `model` is never applied and
            # attention_log is never populated, so the stats dashboard stays
            # empty — classify each detection and append {'time': ..., 'state': ...}.
            FRAME_WINDOW.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Removed dead `cv2.waitKey(1) & 0xFF == ord('q')` check: without a
            # HighGUI window, waitKey always returns -1, so it could never fire;
            # the loop is ended by unchecking the Streamlit checkbox (rerun).
    finally:
        cap.release()  # free the camera even if an exception escapes the loop

# Render the attention-statistics dashboard from the accumulated log.
if attention_log:
    df = pd.DataFrame(attention_log)
    # Tally detections per attention state in a single pass.
    state_counts = df['state'].value_counts()
    attentive = int(state_counts.get('Attentive', 0))
    inattentive = int(state_counts.get('Inattentive', 0))
    st.markdown("### πŸ“Š Attention Statistics")
    st.write(f"βœ… Attentive detections: {attentive}")
    st.write(f"⚠️ Inattentive detections: {inattentive}")
    st.dataframe(df.tail(10))
    # Fraction of 'Attentive' detections at each timestamp.
    attentive_ratio = df.groupby('time')['state'].apply(lambda s: (s == 'Attentive').mean())
    st.line_chart(attentive_ratio)
    st.download_button("Download Log as CSV", df.to_csv(index=False), file_name="attention_log.csv")