File size: 9,997 Bytes
8bbb872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
from __future__ import annotations

import os
from pathlib import Path

import cv2
import numpy as np
from ultralytics import YOLO

# MediaPipe is an optional dependency. If importing it fails for any reason,
# `mp` is set to None; process_video() checks for that and falls back to
# classifier-only mode.
try:
    import mediapipe as mp
except Exception:  # pragma: no cover
    mp = None


def find_weights(project_root: Path) -> Path | None:
    candidates = [
        project_root / "weights" / "best.pt",
        project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
        project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
        project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "best.pt",
        project_root / "runs_cls" / "eye_open_closed_cpu" / "weights" / "last.pt",
    ]
    return next((p for p in candidates if p.is_file()), None)


def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """Estimate the pupil centre in a single-channel eye image.

    The image is contrast-equalised with CLAHE and Gaussian-blurred. A window
    spanning 30 % of the width/height on either side of the image centre is
    then thresholded with inverted Otsu so that dark regions become
    foreground, and cleaned up with a morphological open and close. Each
    external contour with area >= 15 and circularity >= 0.3 is scored as
    ``circularity - distance_to_image_centre / max(w, h)``; the centroid of
    the best-scoring contour wins.

    Args:
        gray: Two-dimensional image array (``h, w = gray.shape``).

    Returns:
        ``(x, y)`` of the chosen centroid in full-image pixel coordinates, or
        ``None`` when no contour qualifies or the image is too small to yield
        a non-empty central window.
    """
    h, w = gray.shape

    cx, cy = w // 2, h // 2
    rx, ry = int(w * 0.3), int(h * 0.3)
    x0, x1 = max(cx - rx, 0), min(cx + rx, w)
    y0, y1 = max(cy - ry, 0), min(cy + ry, h)
    # Empty or tiny images give an empty window, which OpenCV would reject
    # with an error; report "no pupil" instead.
    if x1 <= x0 or y1 <= y0:
        return None

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    eq = clahe.apply(gray)
    blur = cv2.GaussianBlur(eq, (7, 7), 0)
    roi = blur[y0:y1, x0:x1]

    # Inverted Otsu: pixels darker than the automatic threshold become 255.
    _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    best = None
    best_score = -1.0
    for c in contours:
        area = cv2.contourArea(c)
        if area < 15:
            continue
        perimeter = cv2.arcLength(c, True)
        if perimeter <= 0:
            continue
        # 4*pi*A / P^2 equals 1.0 for a perfect circle and drops for
        # elongated or ragged shapes.
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:
            continue
        m = cv2.moments(c)
        if m["m00"] == 0:
            continue
        # Centroid is relative to the window; shift back to image coordinates.
        px = int(m["m10"] / m["m00"]) + x0
        py = int(m["m01"] / m["m00"]) + y0

        dist = np.hypot(px - cx, py - cy) / max(w, h)
        score = circularity - dist
        if score > best_score:
            best_score = score
            best = (px, py)

    return best


def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
    """Tell whether the pupil sits near the horizontal middle of the image.

    Args:
        pupil_center: ``(x, y)`` pixel position of the pupil; only ``x`` is used.
        img_shape: ``(height, width)`` of the image; only ``width`` is used.

    Returns:
        ``True`` when the horizontal distance between the pupil and the image
        midline, divided by the width (at least 1), is strictly below 0.12.
    """
    _height, width = img_shape
    pupil_x, _pupil_y = pupil_center
    midline = width // 2
    relative_offset = abs(pupil_x - midline) / max(width, 1)
    return relative_offset < 0.12


def classify_frame(model: YOLO, frame: np.ndarray) -> tuple[str, float]:
    """Run the classifier on a frame and return its top prediction.

    The frame is passed to the model as-is, so it is expected to already be
    an eye crop.

    Args:
        model: Classification model; ``predict`` is called with ``imgsz=224``
            on the CPU with verbose output disabled.
        frame: Image to classify.

    Returns:
        ``(label, confidence)`` where ``label`` is the entry of
        ``model.names`` for the top-1 class and ``confidence`` is the top-1
        probability as a float.
    """
    first_result = model.predict(frame, imgsz=224, device="cpu", verbose=False)[0]
    class_probs = first_result.probs
    winner = int(class_probs.top1)
    confidence = float(class_probs.top1conf)
    return model.names[winner], confidence


def annotate_frame(frame: np.ndarray, label: str, focused: bool, conf: float, time_sec: float):
    """Return a copy of ``frame`` with a one-line status caption drawn on it.

    The caption shows the label, the focus flag as 0/1, the confidence and the
    timestamp (both with two decimals). The input frame is not modified.
    """
    caption = f"{label} | focused={int(focused)} | conf={conf:.2f} | t={time_sec:.2f}s"
    anchor = (10, 30)
    green = (0, 255, 0)
    font_scale = 0.8
    thickness = 2

    canvas = frame.copy()
    cv2.putText(canvas, caption, anchor, cv2.FONT_HERSHEY_SIMPLEX, font_scale, green, thickness)
    return canvas


def write_segments(path: Path, segments: list[tuple[float, float, str]]):
    """Write label segments to ``path``, one ``start,end,label`` line each.

    Start and end times are formatted with two decimals. An existing file is
    overwritten; an empty ``segments`` list produces an empty file.
    """
    with path.open("w") as handle:
        handle.writelines(
            f"{begin:.2f},{finish:.2f},{name}\n" for begin, finish, name in segments
        )


# FaceMesh landmark indices used by process_video(): six contour points per
# eye, in the p1..p6 order the eye-aspect-ratio formula expects, and four iris
# points per eye.
# NOTE(review): the iris lists hold four consecutive indices each (468-471 and
# 473-476); confirm against the FaceMesh landmark map that this is the
# intended subset for averaging an iris centre.
_LEFT_EYE_IDX = (33, 160, 158, 133, 153, 144)
_RIGHT_EYE_IDX = (362, 385, 387, 263, 373, 380)
_LEFT_IRIS_IDX = (468, 469, 470, 471)
_RIGHT_IRIS_IDX = (473, 474, 475, 476)

# Eyes count as open when the mean eye aspect ratio exceeds this value.
_EAR_OPEN_THRESHOLD = 0.22
# Maximum horizontal iris offset (as a fraction of eye width) to count as focused.
_IRIS_OFFSET_THRESHOLD = 0.18


def _landmark_pixels(landmarks, indices, width: int, height: int) -> np.ndarray:
    """Convert normalised landmarks at ``indices`` to an (N, 2) int32 pixel array."""
    return np.array(
        [(int(landmarks[i].x * width), int(landmarks[i].y * height)) for i in indices],
        dtype=np.int32,  # the point type OpenCV drawing functions expect
    )


def _eye_aspect_ratio(eye_pts: np.ndarray) -> float:
    """Return the eye aspect ratio (vertical openings over eye width) of six eye points."""
    p1, p2, p3, p4, p5, p6 = eye_pts
    vertical = np.linalg.norm(p2 - p6) + np.linalg.norm(p3 - p5)
    horizontal = np.linalg.norm(p1 - p4)
    # Small epsilon keeps the ratio finite when both corners coincide.
    return vertical / (2.0 * horizontal + 1e-6)


def _horizontal_iris_offset(iris_center: np.ndarray, eye_pts: np.ndarray) -> float:
    """Return |iris.x - eye-corner midpoint.x| divided by the eye width (at least 1)."""
    eye_center = ((eye_pts[0] + eye_pts[3]) / 2).astype(int)
    eye_width = max(np.linalg.norm(eye_pts[0] - eye_pts[3]), 1)
    return abs(iris_center[0] - eye_center[0]) / eye_width


def _analyze_face(landmarks, frame: np.ndarray) -> tuple[str, bool]:
    """Derive ("open"/"closed", focused) from face landmarks and draw eye overlays on ``frame``."""
    height, width = frame.shape[:2]

    left_eye = _landmark_pixels(landmarks, _LEFT_EYE_IDX, width, height)
    right_eye = _landmark_pixels(landmarks, _RIGHT_EYE_IDX, width, height)
    ear_avg = (_eye_aspect_ratio(left_eye) + _eye_aspect_ratio(right_eye)) / 2.0
    pred_label = "open" if ear_avg > _EAR_OPEN_THRESHOLD else "closed"

    left_iris = _landmark_pixels(landmarks, _LEFT_IRIS_IDX, width, height).mean(axis=0).astype(int)
    right_iris = _landmark_pixels(landmarks, _RIGHT_IRIS_IDX, width, height).mean(axis=0).astype(int)

    # Focused = eyes open and both irises close to their eye's horizontal centre.
    focused = bool(
        pred_label == "open"
        and _horizontal_iris_offset(left_iris, left_eye) < _IRIS_OFFSET_THRESHOLD
        and _horizontal_iris_offset(right_iris, right_eye) < _IRIS_OFFSET_THRESHOLD
    )

    # Overlays are drawn in place so they end up in the annotated output video.
    cv2.polylines(frame, [left_eye], True, (0, 255, 255), 1)
    cv2.polylines(frame, [right_eye], True, (0, 255, 255), 1)
    for iris in (left_iris, right_iris):
        # Plain Python ints avoid OpenCV rejecting NumPy scalar coordinates.
        cv2.circle(frame, (int(iris[0]), int(iris[1])), 3, (0, 0, 255), -1)

    return pred_label, focused


def process_video(video_path: Path, model: YOLO | None):
    """Label every frame of a video as open/closed and focused/not focused.

    Three files are written next to ``video_path``:

    * ``<stem>_pred.mp4`` - the video with a status caption (and, in
      MediaPipe mode, eye outlines and iris dots) on every frame;
    * ``<stem>_predictions.csv`` - one ``time_sec,label,focused,conf`` row
      per frame;
    * ``<stem>_segments.txt`` - runs of consecutive identical labels as
      ``start,end,label`` lines.

    When MediaPipe is importable, eye state comes from FaceMesh landmarks
    (eye aspect ratio and iris position); ``model`` is not used, ``conf`` is
    always 0.0 and a frame without a detected face counts as closed.
    Otherwise the whole frame is classified with ``model`` (label defaults to
    "open" when ``model`` is ``None``) and focus is estimated from the pupil
    position found by ``detect_pupil_center``.

    Args:
        video_path: Video file to read.
        model: Classifier for the fallback mode, or ``None``.

    Returns:
        ``None``. If the video cannot be opened a message is printed and
        nothing is written.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        print(f"Failed to open {video_path}")
        return

    out_path = video_path.with_name(video_path.stem + "_pred.mp4")
    csv_path = video_path.with_name(video_path.stem + "_predictions.csv")
    seg_path = video_path.with_name(video_path.stem + "_segments.txt")

    writer = None
    face_mesh = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Containers without usable frame-rate metadata report 0, NaN or a
        # negative value; fall back to 30 fps so timestamps stay finite.
        if not fps > 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))

        frame_idx = 0
        last_label = None
        seg_start = 0.0
        segments: list[tuple[float, float, str]] = []

        with csv_path.open("w") as fcsv:
            fcsv.write("time_sec,label,focused,conf\n")
            if mp is None:
                print("mediapipe is not installed. Falling back to classifier-only mode.")
            else:
                face_mesh = mp.solutions.face_mesh.FaceMesh(
                    static_image_mode=False,
                    max_num_faces=1,
                    refine_landmarks=True,
                    min_detection_confidence=0.5,
                    min_tracking_confidence=0.5,
                )

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                time_sec = frame_idx / fps
                conf = 0.0
                pred_label = "open"
                focused = False

                if face_mesh is not None:
                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    res = face_mesh.process(rgb)
                    if res.multi_face_landmarks:
                        pred_label, focused = _analyze_face(
                            res.multi_face_landmarks[0].landmark, frame
                        )
                    else:
                        # No face found: treated the same as closed eyes.
                        pred_label = "closed"
                else:
                    if model is not None:
                        pred_label, conf = classify_frame(model, frame)
                    if pred_label.lower() == "open":
                        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                        pupil_center = detect_pupil_center(gray)
                        if pupil_center is not None:
                            focused = is_focused(pupil_center, gray.shape)

                is_open = pred_label.lower() == "open"
                focused = bool(focused) and is_open
                if not is_open:
                    label = "closed_not_focused"
                elif focused:
                    label = "open_focused"
                else:
                    label = "open_not_focused"

                fcsv.write(f"{time_sec:.2f},{label},{int(focused)},{conf:.4f}\n")

                # Close the running segment whenever the label changes.
                if last_label is None:
                    last_label = label
                    seg_start = time_sec
                elif label != last_label:
                    segments.append((seg_start, time_sec, last_label))
                    seg_start = time_sec
                    last_label = label

                writer.write(annotate_frame(frame, label, focused, conf, time_sec))
                frame_idx += 1

        # Flush the segment still open at end of video.
        if last_label is not None:
            segments.append((seg_start, frame_idx / fps, last_label))
        write_segments(seg_path, segments)
    finally:
        # Release native resources even if a frame fails mid-way.
        cap.release()
        if writer is not None:
            writer.release()
        if face_mesh is not None:
            face_mesh.close()

    print(f"Saved: {out_path}")
    print(f"CSV: {csv_path}")
    print(f"Segments: {seg_path}")


def main():
    """Process the default videos plus any listed in the ``VIDEOS`` env var.

    The project root is the parent of this file's directory. ``1.mp4`` and
    ``2.mp4`` in the project root are used when present. ``VIDEOS`` may hold
    a comma-separated list of additional paths; relative entries are resolved
    against the project root. Entries that do not exist are skipped with a
    message, and a video reachable through several entries is processed only
    once. Classifier weights are looked up with ``find_weights`` and loaded
    only when there is at least one video to process.
    """
    project_root = Path(__file__).resolve().parent.parent

    videos: list[Path] = []
    seen: set[Path] = set()

    def add_video(path: Path) -> None:
        # De-duplicate on the resolved path so "1.mp4" and "./1.mp4" collapse.
        key = path.resolve()
        if key not in seen:
            seen.add(key)
            videos.append(path)

    # Default to 1.mp4 and 2.mp4 in project root
    for name in ("1.mp4", "2.mp4"):
        p = project_root / name
        if p.exists():
            add_video(p)

    # Also allow passing paths via env var
    extra = os.getenv("VIDEOS", "")
    for v in (x.strip() for x in extra.split(",")):
        if not v:
            continue
        vp = Path(v)
        if not vp.is_absolute():
            vp = project_root / vp
        if vp.exists():
            add_video(vp)
        else:
            print(f"Skipping missing video from VIDEOS: {vp}")

    if not videos:
        print("No videos found. Expected 1.mp4 / 2.mp4 in project root.")
        return

    weights = find_weights(project_root)
    model = YOLO(str(weights)) if weights is not None else None

    for v in videos:
        process_video(v, model)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()