Olof Astrand committed on
Commit ·
c17a15c
1
Parent(s): 0eda6ef
Minor updates
Browse files- README.md +51 -2
- inference_claude.py +319 -0
- training_faces.py +349 -0
README.md
CHANGED
|
@@ -13,8 +13,11 @@ Creating a dataset
|
|
| 13 |
==================
|
| 14 |
collector.py
|
| 15 |
collector.html
|
| 16 |
-
When creating a dataset in the browser you will have to convert
|
| 17 |
convert.py
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
Training from web based dataset
|
|
@@ -23,10 +26,56 @@ training.py
|
|
| 23 |
|
| 24 |
Training from OpenCV created dataset
|
| 25 |
==============
|
|
|
|
|
|
|
| 26 |
training_deepseek.py
|
| 27 |
|
| 28 |
|
| 29 |
Inference
|
| 30 |
==========
|
| 31 |
inference.py
|
| 32 |
-
This does not work in a wsl environment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
==================
|
| 14 |
collector.py
|
| 15 |
collector.html
|
| 16 |
+
When creating a dataset in the browser you will have to convert it with
|
| 17 |
convert.py
|
| 18 |
+
When creating dataset with collector you have to preprocess it with
|
| 19 |
+
preprocess.py
|
| 20 |
+
|
| 21 |
|
| 22 |
|
| 23 |
Training from web based dataset
|
|
|
|
| 26 |
|
| 27 |
Training from OpenCV created dataset
|
| 28 |
==============
|
| 29 |
+
training_faces.py
|
| 30 |
+
|
| 31 |
training_deepseek.py
|
| 32 |
|
| 33 |
|
| 34 |
Inference
|
| 35 |
==========
|
| 36 |
inference.py
|
| 37 |
+
This does not work in a wsl environment as we cannot access the camera.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
Mobilenet used
|
| 41 |
+
Model architecture:
|
| 42 |
+
Model: "functional"
|
| 43 |
+
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
|
| 44 |
+
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
|
| 45 |
+
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
|
| 46 |
+
│ input_layer (InputLayer) │ (None, 60, 80, 3) │ 0 │
|
| 47 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 48 |
+
│ conv2d (Conv2D) │ (None, 60, 80, 32) │ 896 │
|
| 49 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 50 |
+
│ max_pooling2d (MaxPooling2D) │ (None, 30, 40, 32) │ 0 │
|
| 51 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 52 |
+
│ dropout (Dropout) │ (None, 30, 40, 32) │ 0 │
|
| 53 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 54 |
+
│ conv2d_1 (Conv2D) │ (None, 30, 40, 64) │ 18,496 │
|
| 55 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 56 |
+
│ max_pooling2d_1 (MaxPooling2D) │ (None, 15, 20, 64) │ 0 │
|
| 57 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 58 |
+
│ dropout_1 (Dropout) │ (None, 15, 20, 64) │ 0 │
|
| 59 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 60 |
+
│ conv2d_2 (Conv2D) │ (None, 15, 20, 128) │ 73,856 │
|
| 61 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 62 |
+
│ max_pooling2d_2 (MaxPooling2D) │ (None, 7, 10, 128) │ 0 │
|
| 63 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 64 |
+
│ dropout_2 (Dropout) │ (None, 7, 10, 128) │ 0 │
|
| 65 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 66 |
+
│ global_average_pooling2d │ (None, 128) │ 0 │
|
| 67 |
+
│ (GlobalAveragePooling2D) │ │ │
|
| 68 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 69 |
+
│ dense (Dense) │ (None, 128) │ 16,512 │
|
| 70 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 71 |
+
│ dropout_3 (Dropout) │ (None, 128) │ 0 │
|
| 72 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 73 |
+
│ dense_1 (Dense) │ (None, 64) │ 8,256 │
|
| 74 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 75 |
+
│ dropout_4 (Dropout) │ (None, 64) │ 0 │
|
| 76 |
+
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
|
| 77 |
+
│ dense_2 (Dense) │ (None, 2) │ 130 │
|
| 78 |
+
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
|
| 79 |
+
Total params: 118,146 (461.51 KB)
|
| 80 |
+
Trainable params: 118,146 (461.51 KB)
|
| 81 |
+
|
inference_claude.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from tensorflow import keras
|
| 5 |
+
import tkinter as tk
|
| 6 |
+
|
| 7 |
+
class GazeInference:
|
    def __init__(self, model_path):
        """Set up the gaze-inference pipeline.

        Loads the trained Keras gaze model, prepares the OpenCV Haar face
        detector, queries the screen size, and initializes smoothing,
        visualization, and crop-tuning state.

        Args:
            model_path: Path to a saved Keras gaze-estimation model.
        """
        # Load model
        self.model = keras.models.load_model(model_path)

        # Initialize face cascade classifier (frontal-face Haar cascade
        # shipped with OpenCV)
        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        # Get actual screen resolution using tkinter
        self._get_screen_resolution()

        # Gaze smoothing: exponential moving average over successive
        # predictions (higher factor = more responsive, less smooth)
        self.smoothing = True
        self.smoothing_factor = 0.2
        self.last_gaze = None

        # Face visualization (thumbnails drawn by _draw_face_roi)
        self.show_face = False
        self.face_roi = None
        self.upper_face_roi = None

        # Face detection parameters for detectMultiScale
        self.face_min_size = None  # placeholder comment removed below
        self.min_face_size = (50, 50)
        self.scale_factor = 1.1
        self.min_neighbors = 5

        # Adjustable crop parameters (can be modified with keys at runtime)
        self.crop_top = 0.05  # Start at 5% from top of face
        self.crop_bottom = 0.80  # End at 80% from top of face
        self.crop_sides = 0.05  # Crop 5% from each side
| 38 |
+
def _get_screen_resolution(self):
|
| 39 |
+
"""Get primary screen resolution using tkinter"""
|
| 40 |
+
root = tk.Tk()
|
| 41 |
+
self.screen_width = root.winfo_screenwidth()
|
| 42 |
+
self.screen_height = root.winfo_screenheight()
|
| 43 |
+
root.destroy()
|
| 44 |
+
print(f"Screen resolution: {self.screen_width}x{self.screen_height}")
|
| 45 |
+
|
    def _extract_upper_face_region(self, frame):
        """Extract upper half of face region from frame.

        Detects the largest frontal face in *frame*, caches a padded
        full-face crop (``self.face_roi``) and the 80x60 eye-region crop
        (``self.upper_face_roi``) for visualization, and returns the
        eye-region crop sized for the model.

        Args:
            frame: BGR webcam image (numpy array, H x W x 3).

        Returns:
            80x60 BGR eye-region patch, or None when no face is detected.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Detect faces
        faces = self.face_cascade.detectMultiScale(
            gray,
            scaleFactor=self.scale_factor,
            minNeighbors=self.min_neighbors,
            minSize=self.min_face_size
        )

        if len(faces) > 0:
            # Get the largest face (assuming it's the closest/main face)
            face = max(faces, key=lambda x: x[2] * x[3])
            x, y, w, h = face

            # Store full face for visualization (padded, clamped to frame)
            face_padding = 10
            x_start = max(0, x - face_padding)
            y_start = max(0, y - face_padding)
            x_end = min(frame.shape[1], x + w + face_padding)
            y_end = min(frame.shape[0], y + h + face_padding)
            self.face_roi = frame[y_start:y_end, x_start:x_end].copy()

            # Extract tight crop around eyes region to match training data
            # Based on typical face proportions:
            # - Eyes are typically at 40-50% down from top of face
            # - Eye region height is about 20-30% of face height

            # Use adjustable crop parameters (fractions of detected face box)
            eye_region_start = int(h * self.crop_top)
            eye_region_end = int(h * self.crop_bottom)

            # For width, focus on central portion of face
            width_crop = int(w * self.crop_sides)

            # Calculate bounds for eye region, clamped to the frame
            uf_x_start = max(0, x + width_crop)
            uf_y_start = max(0, y + eye_region_start)
            uf_x_end = min(frame.shape[1], x + w - width_crop)
            uf_y_end = min(frame.shape[0], y + eye_region_end)

            # Extract and resize upper face region
            upper_face = frame[uf_y_start:uf_y_end, uf_x_start:uf_x_end]

            # Resize to model input size (80x60, i.e. width x height)
            upper_face_resized = cv2.resize(upper_face, (80, 60))
            self.upper_face_roi = upper_face_resized.copy()

            return upper_face_resized

        return None
| 100 |
+
def _predict_gaze(self, upper_face_region):
|
| 101 |
+
"""Predict gaze position from upper face region"""
|
| 102 |
+
# Preprocess
|
| 103 |
+
face_input = cv2.cvtColor(upper_face_region, cv2.COLOR_BGR2RGB)
|
| 104 |
+
face_input = face_input.astype('float32') / 255.0
|
| 105 |
+
face_input = np.expand_dims(face_input, axis=0)
|
| 106 |
+
|
| 107 |
+
# Predict
|
| 108 |
+
pred = self.model.predict(face_input, verbose=0)[0]
|
| 109 |
+
|
| 110 |
+
# Convert to screen coordinates
|
| 111 |
+
screen_x = int(pred[0] * self.screen_width)
|
| 112 |
+
screen_y = int(pred[1] * self.screen_height)
|
| 113 |
+
|
| 114 |
+
# Fix left-right inversion: flip X coordinate
|
| 115 |
+
# Since camera is mirrored, we need to invert the X prediction
|
| 116 |
+
screen_x = self.screen_width - screen_x
|
| 117 |
+
|
| 118 |
+
# Clamp to screen bounds
|
| 119 |
+
screen_x = max(0, min(self.screen_width - 1, screen_x))
|
| 120 |
+
screen_y = max(0, min(self.screen_height - 1, screen_y))
|
| 121 |
+
|
| 122 |
+
# Apply smoothing if enabled
|
| 123 |
+
if self.smoothing and self.last_gaze is not None:
|
| 124 |
+
screen_x = int(self.smoothing_factor * screen_x +
|
| 125 |
+
(1 - self.smoothing_factor) * self.last_gaze[0])
|
| 126 |
+
screen_y = int(self.smoothing_factor * screen_y +
|
| 127 |
+
(1 - self.smoothing_factor) * self.last_gaze[1])
|
| 128 |
+
|
| 129 |
+
self.last_gaze = (screen_x, screen_y)
|
| 130 |
+
return screen_x, screen_y
|
| 131 |
+
|
| 132 |
+
def _draw_gaze_cross(self, frame, x, y):
|
| 133 |
+
"""Draw crosshair at gaze position"""
|
| 134 |
+
color = (0, 255, 0) # Green
|
| 135 |
+
size = 30
|
| 136 |
+
thickness = 3
|
| 137 |
+
|
| 138 |
+
# Horizontal line
|
| 139 |
+
cv2.line(frame, (x - size, y), (x + size, y), color, thickness)
|
| 140 |
+
# Vertical line
|
| 141 |
+
cv2.line(frame, (x, y - size), (x, y + size), color, thickness)
|
| 142 |
+
# Center circle
|
| 143 |
+
cv2.circle(frame, (x, y), 5, color, -1)
|
| 144 |
+
|
| 145 |
+
def _draw_face_roi(self, frame):
|
| 146 |
+
"""Draw face region visualization in bottom left"""
|
| 147 |
+
if self.show_face and self.face_roi is not None and self.upper_face_roi is not None:
|
| 148 |
+
# Calculate display sizes
|
| 149 |
+
max_height = 200
|
| 150 |
+
|
| 151 |
+
# Display full face
|
| 152 |
+
face_h, face_w = self.face_roi.shape[:2]
|
| 153 |
+
display_h = min(max_height, self.screen_height // 4)
|
| 154 |
+
display_w = int(display_h * (face_w / face_h))
|
| 155 |
+
|
| 156 |
+
# Ensure we don't exceed screen dimensions
|
| 157 |
+
display_w = min(display_w, self.screen_width // 3)
|
| 158 |
+
display_h = min(display_h, self.screen_height // 3)
|
| 159 |
+
|
| 160 |
+
# Resize face for display
|
| 161 |
+
face_display = cv2.resize(self.face_roi, (display_w, display_h))
|
| 162 |
+
|
| 163 |
+
# Position for full face (bottom left)
|
| 164 |
+
face_y = self.screen_height - display_h - 10
|
| 165 |
+
face_x = 10
|
| 166 |
+
|
| 167 |
+
# Draw full face
|
| 168 |
+
try:
|
| 169 |
+
frame[face_y:face_y + display_h, face_x:face_x + display_w] = face_display
|
| 170 |
+
cv2.rectangle(frame, (face_x, face_y),
|
| 171 |
+
(face_x + display_w, face_y + display_h), (255, 255, 255), 2)
|
| 172 |
+
cv2.putText(frame, "Full Face", (face_x + 5, face_y - 5),
|
| 173 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 174 |
+
except:
|
| 175 |
+
pass
|
| 176 |
+
|
| 177 |
+
# Display upper face region
|
| 178 |
+
uf_h = int(display_h * 0.75) # 60:80 aspect ratio
|
| 179 |
+
uf_w = display_w
|
| 180 |
+
upper_face_display = cv2.resize(self.upper_face_roi, (uf_w, uf_h))
|
| 181 |
+
|
| 182 |
+
# Position for upper face (next to full face)
|
| 183 |
+
uf_x = face_x + display_w + 20
|
| 184 |
+
uf_y = self.screen_height - uf_h - 10
|
| 185 |
+
|
| 186 |
+
# Draw upper face
|
| 187 |
+
try:
|
| 188 |
+
frame[uf_y:uf_y + uf_h, uf_x:uf_x + uf_w] = upper_face_display
|
| 189 |
+
cv2.rectangle(frame, (uf_x, uf_y),
|
| 190 |
+
(uf_x + uf_w, uf_y + uf_h), (0, 255, 255), 2)
|
| 191 |
+
cv2.putText(frame, "Upper Face (Model Input)", (uf_x + 5, uf_y - 5),
|
| 192 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
| 193 |
+
except:
|
| 194 |
+
pass
|
| 195 |
+
|
    def run(self):
        """Main inference loop.

        Captures webcam frames, predicts the gaze point for each frame,
        and renders a fullscreen canvas with a crosshair plus status text.
        Keyboard controls toggle smoothing / face view and tune the face
        crop fractions; 'q' exits. Blocks until quit or camera failure.
        """
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print("Error: Could not open webcam")
            return

        # Create fullscreen window
        cv2.namedWindow('Gaze Prediction', cv2.WND_PROP_FULLSCREEN)
        cv2.setWindowProperty('Gaze Prediction', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)

        print("Controls:")
        print("'s': Toggle Smoothing")
        print("'f': Toggle Face View")
        print("'q': Quit")
        print("Crop Adjustment:")
        print("'u'/'j': Move crop top up/down")
        print("'i'/'k': Move crop bottom up/down")
        print("'o'/'l': Decrease/increase side crop")
        print(f"Using screen resolution: {self.screen_width}x{self.screen_height}")

        while True:
            ret, frame = cap.read()
            if not ret:
                # Camera stream ended or failed; leave the loop.
                break

            # Create black canvas matching screen size
            canvas = np.zeros((self.screen_height, self.screen_width, 3), dtype=np.uint8)

            # Mirror the frame for more natural interaction
            frame = cv2.flip(frame, 1)

            # Process frame (face detection + eye-region crop)
            upper_face = self._extract_upper_face_region(frame)

            if upper_face is not None:
                # Predict gaze
                gaze_x, gaze_y = self._predict_gaze(upper_face)

                # Draw gaze cross on canvas
                self._draw_gaze_cross(canvas, gaze_x, gaze_y)

                # Show coordinates for debugging
                cv2.putText(canvas, f"Gaze: ({gaze_x}, {gaze_y})",
                            (20, self.screen_height - 30), cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, (255, 255, 0), 2)

                # Draw face regions if enabled
                if self.show_face:
                    self._draw_face_roi(canvas)
            else:
                # Show "no face detected" message
                cv2.putText(canvas, "No face detected",
                            (20, self.screen_height - 30), cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, (0, 0, 255), 2)

            # Show instructions
            cv2.putText(canvas, "'s': Smoothing | 'f': Face View | 'u/j': Top | 'i/k': Bottom | 'o/l': Sides | 'q': Quit",
                        (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Show smoothing status
            status = "ON" if self.smoothing else "OFF"
            cv2.putText(canvas, f"Smoothing: {status}",
                        (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                        (0, 255, 0) if self.smoothing else (0, 0, 255), 2)

            # Show face view status
            status = "ON" if self.show_face else "OFF"
            cv2.putText(canvas, f"Face View: {status}",
                        (20, 110), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                        (0, 255, 0) if self.show_face else (0, 0, 255), 2)

            # Show screen resolution info and crop parameters
            cv2.putText(canvas, f"Screen: {self.screen_width}x{self.screen_height}",
                        (20, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(canvas, f"Crop: Top={self.crop_top:.2f} Bottom={self.crop_bottom:.2f} Sides={self.crop_sides:.2f}",
                        (20, 190), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2)

            # Display
            cv2.imshow('Gaze Prediction', canvas)

            # Handle key presses (1 ms poll keeps the loop near camera rate)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('s'):
                self.smoothing = not self.smoothing
                print(f"Smoothing: {'ON' if self.smoothing else 'OFF'}")
            elif key == ord('f'):
                self.show_face = not self.show_face
                print(f"Face View: {'ON' if self.show_face else 'OFF'}")
            # Crop adjustment keys; bounds keep top < bottom by >= 0.1
            elif key == ord('u'):  # Move top up
                self.crop_top = max(0.0, self.crop_top - 0.05)
                print(f"Crop top: {self.crop_top:.2f}")
            elif key == ord('j'):  # Move top down
                self.crop_top = min(self.crop_bottom - 0.1, self.crop_top + 0.05)
                print(f"Crop top: {self.crop_top:.2f}")
            elif key == ord('i'):  # Move bottom up
                self.crop_bottom = max(self.crop_top + 0.1, self.crop_bottom - 0.05)
                print(f"Crop bottom: {self.crop_bottom:.2f}")
            elif key == ord('k'):  # Move bottom down
                self.crop_bottom = min(1.0, self.crop_bottom + 0.05)
                print(f"Crop bottom: {self.crop_bottom:.2f}")
            elif key == ord('o'):  # Decrease side crop
                self.crop_sides = max(0.0, self.crop_sides - 0.05)
                print(f"Crop sides: {self.crop_sides:.2f}")
            elif key == ord('l'):  # Increase side crop
                self.crop_sides = min(0.4, self.crop_sides + 0.05)
                print(f"Crop sides: {self.crop_sides:.2f}")

        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    import argparse

    # Command-line entry point: requires the path to a trained model.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str, required=True,
                     help='Path to trained gaze estimation model')
    options = cli.parse_args()

    app = GazeInference(options.model)
    app.run()
training_faces.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow import keras
|
| 3 |
+
from tensorflow.keras import layers
|
| 4 |
+
import numpy as np
|
| 5 |
+
import json
|
| 6 |
+
import cv2
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
# Set memory growth for GPU so TensorFlow allocates VRAM on demand
# instead of reserving the entire device at startup.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized;
        # report the failure but continue with default allocation.
        print(e)
class ImprovedGazeModel:
    """Small CNN that regresses normalized (x, y) gaze coordinates
    from an upper-face image patch."""

    def __init__(self, input_shape=(60, 80, 3)):
        # (height, width, channels) of the input patch
        self.input_shape = input_shape
        # Built lazily by build_simple_model()
        self.model = None

    def build_simple_model(self):
        """Build a simpler, more effective model.

        Three Conv-BN-ReLU-MaxPool blocks followed by two dense layers
        with dropout; the head is a linear 2-unit regressor (x, y).

        Returns:
            The constructed (uncompiled) keras.Model, also stored on
            ``self.model``.
        """
        inputs = keras.Input(shape=self.input_shape)

        # First conv block
        x = layers.Conv2D(16, (5, 5), padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.MaxPooling2D((2, 2))(x)

        # Second conv block
        x = layers.Conv2D(32, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.MaxPooling2D((2, 2))(x)

        # Third conv block
        x = layers.Conv2D(64, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.MaxPooling2D((2, 2))(x)

        # Flatten and dense layers
        x = layers.Flatten()(x)
        x = layers.Dense(128)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Dropout(0.3)(x)

        x = layers.Dense(64)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Dropout(0.3)(x)

        # Output layer - no activation (linear regression)
        outputs = layers.Dense(2)(x)

        self.model = keras.Model(inputs, outputs, name='gaze_model')
        return self.model

    def compile_model(self, learning_rate=0.0001):
        """Compile with better optimizer settings.

        Args:
            learning_rate: Adam step size. Uses MSE loss with an MAE
                metric, matching the regression head.
        """
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss='mse',
            metrics=['mae']
        )
def load_preprocessed_data(data_dir):
    """Load the preprocessed face dataset.

    Expects ``data_dir`` to contain ``metadata.json`` (with keys
    ``screen_width``, ``screen_height``, ``data_points``) and an
    ``images/`` subdirectory holding the face crops referenced by each
    data point. Missing or unreadable images are skipped silently.

    Args:
        data_dir: Path (str or Path) to the dataset directory.

    Returns:
        Tuple of (images, gaze_coords, screen_width, screen_height) where
        images is a float32 array in [0, 1] (RGB) and gaze_coords is an
        (N, 2) array of screen positions normalized to [0, 1].
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / 'metadata.json', 'r') as f:
        metadata = json.load(f)

    screen_width = metadata['screen_width']
    screen_height = metadata['screen_height']
    data_points = metadata['data_points']

    print(f"Loading {len(data_points)} data points...")
    print(f"Screen dimensions: {screen_width}x{screen_height}")

    # Load images and gaze coordinates
    images = []
    gaze_coords = []

    print("Loading images...")
    for i, point in enumerate(data_points):
        if i % 500 == 0:
            print(f"Progress: {i}/{len(data_points)}")

        img_path = data_dir / 'images' / point['image']
        if img_path.exists():
            img = cv2.imread(str(img_path))
            if img is not None:
                # Convert to RGB and normalize
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = img.astype(np.float32) / 255.0
                images.append(img)

                # Normalize gaze coordinates to [0, 1]
                norm_x = point['screen_x'] / screen_width
                norm_y = point['screen_y'] / screen_height
                gaze_coords.append([norm_x, norm_y])

    images = np.array(images)
    gaze_coords = np.array(gaze_coords)

    print(f"\nSuccessfully loaded {len(images)} images")
    print(f"Image shape: {images[0].shape}")
    print(f"Gaze X range: [{gaze_coords[:, 0].min():.3f}, {gaze_coords[:, 0].max():.3f}]")
    print(f"Gaze Y range: [{gaze_coords[:, 1].min():.3f}, {gaze_coords[:, 1].max():.3f}]")

    # Check for any outliers (targets recorded off-screen)
    x_outliers = np.sum((gaze_coords[:, 0] < 0) | (gaze_coords[:, 0] > 1))
    y_outliers = np.sum((gaze_coords[:, 1] < 0) | (gaze_coords[:, 1] > 1))
    if x_outliers > 0 or y_outliers > 0:
        print(f"WARNING: Found {x_outliers} X outliers and {y_outliers} Y outliers")
        # Clip to valid range
        gaze_coords = np.clip(gaze_coords, 0, 1)

    return images, gaze_coords, screen_width, screen_height
def create_augmented_generator(X, y, batch_size=32, augment=True):
    """Yield shuffled (images, targets) mini-batches forever.

    When *augment* is true, each image may independently receive random
    brightness, contrast, and additive-noise perturbations; pixel values
    are kept within [0, 1]. The final batch of each pass may be smaller
    than *batch_size*.
    """
    order = np.arange(len(X))

    while True:
        # Fresh sample order every pass over the dataset.
        np.random.shuffle(order)

        for offset in range(0, len(X), batch_size):
            chosen = order[offset:offset + batch_size]

            # Copy so augmentation never mutates the source arrays.
            imgs = X[chosen].copy()
            targets = y[chosen].copy()

            if augment:
                for idx in range(len(imgs)):
                    # Random brightness adjustment (50% chance)
                    if np.random.random() > 0.5:
                        gain = np.random.uniform(0.7, 1.3)
                        imgs[idx] = np.clip(imgs[idx] * gain, 0, 1)

                    # Random contrast adjustment about the image mean
                    if np.random.random() > 0.5:
                        factor = np.random.uniform(0.8, 1.2)
                        center = imgs[idx].mean()
                        imgs[idx] = np.clip((imgs[idx] - center) * factor + center, 0, 1)

                    # Small additive Gaussian noise
                    if np.random.random() > 0.5:
                        jitter = np.random.normal(0, 0.01, imgs[idx].shape)
                        imgs[idx] = np.clip(imgs[idx] + jitter, 0, 1)

            yield imgs, targets
def visualize_data_distribution(gaze_coords, screen_width, screen_height, save_path='gaze_distribution.png'):
    """Save a two-panel figure of the gaze-sample distribution: a 2D
    heatmap in screen pixels and overlaid 1D histograms per axis."""
    # Denormalize [0, 1] coordinates back to screen pixels.
    xs = gaze_coords[:, 0] * screen_width
    ys = gaze_coords[:, 1] * screen_height

    plt.figure(figsize=(12, 6))

    # Left panel: 2D heatmap of gaze positions.
    plt.subplot(1, 2, 1)
    plt.hist2d(xs, ys, bins=50, cmap='hot')
    plt.colorbar(label='Count')
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title('Gaze Point Distribution')
    plt.gca().invert_yaxis()  # screen origin is top-left

    # Right panel: marginal densities for each axis.
    plt.subplot(1, 2, 2)
    plt.hist(xs, bins=50, alpha=0.5, label='X distribution', density=True)
    plt.hist(ys, bins=50, alpha=0.5, label='Y distribution', density=True)
    plt.xlabel('Position (pixels)')
    plt.ylabel('Density')
    plt.title('X and Y Distributions')
    plt.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
| 194 |
+
def _build_callbacks():
    """Return the training callbacks: best-model checkpoint, early stopping,
    LR reduction on plateau, and a CSV log of per-epoch metrics."""
    return [
        keras.callbacks.ModelCheckpoint(
            'best_gaze_model_improved.keras',
            monitor='val_loss',
            save_best_only=True,
            verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=20,
            restore_best_weights=True,
            verbose=1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-7,
            verbose=1
        ),
        keras.callbacks.CSVLogger('training_log.csv')
    ]


def _report_test_errors(predictions, y_test, screen_width, screen_height, test_loss, test_mae):
    """Print pixel-space error statistics for the test set.

    Returns the per-sample Euclidean errors (in pixels) so the caller can
    persist summary statistics.
    """
    scale = np.array([screen_width, screen_height])
    pred_pixels = predictions * scale
    actual_pixels = y_test * scale

    pixel_errors = np.abs(pred_pixels - actual_pixels)
    euclidean_errors = np.sqrt(np.sum((pred_pixels - actual_pixels) ** 2, axis=1))

    print("\nTest Results:")
    print(f"Loss: {test_loss:.6f}")
    print(f"MAE (normalized): {test_mae:.6f}")
    print(f"Mean X error: {pixel_errors[:, 0].mean():.1f} pixels")
    print(f"Mean Y error: {pixel_errors[:, 1].mean():.1f} pixels")
    print(f"Mean Euclidean error: {euclidean_errors.mean():.1f} pixels")
    print(f"Median Euclidean error: {np.median(euclidean_errors):.1f} pixels")
    print(f"95th percentile error: {np.percentile(euclidean_errors, 95):.1f} pixels")
    return euclidean_errors


def _plot_history(history, save_path='improved_training_history.png'):
    """Plot training/validation loss (log scale) and MAE curves to a PNG."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Model Loss')
    plt.legend()
    plt.yscale('log')

    plt.subplot(1, 2, 2)
    plt.plot(history.history['mae'], label='Training MAE')
    plt.plot(history.history['val_mae'], label='Validation MAE')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.title('Model MAE')
    plt.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def _save_config(model, screen_width, screen_height, test_loss, test_mae, euclidean_errors):
    """Write model/evaluation metadata needed by inference to a JSON file.

    The 'preprocessing' crop ratios must match those used when the face
    dataset was created, so inference can reproduce the same crops.
    """
    config = {
        'model_path': 'best_gaze_model_improved.keras',
        # NOTE(review): assumes ImprovedGazeModel exposes an `input_shape`
        # attribute — confirm against the class definition.
        'input_shape': list(model.input_shape),
        'screen_width': int(screen_width),
        'screen_height': int(screen_height),
        'test_loss': float(test_loss),
        'test_mae': float(test_mae),
        'mean_euclidean_error': float(euclidean_errors.mean()),
        'preprocessing': {
            'crop_top': 0.25,
            'crop_bottom': 0.55,
            'crop_sides': 0.15
        }
    }

    with open('model_config_improved.json', 'w') as f:
        json.dump(config, f, indent=2)


def main():
    """Train the improved gaze-estimation model end to end.

    Loads the preprocessed face dataset, splits it 70/15/15 into
    train/val/test, trains with light on-the-fly augmentation, then
    evaluates on the test set and saves the model, training curves,
    and a JSON config for inference.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Train improved gaze model')
    parser.add_argument('--data', type=str, default='gaze_data_faces',
                        help='Preprocessed face dataset directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate')

    args = parser.parse_args()

    # Load data
    images, gaze_coords, screen_width, screen_height = load_preprocessed_data(args.data)

    # Visualize data distribution before training.
    visualize_data_distribution(gaze_coords, screen_width, screen_height)

    # Split: 15% held-out test, then 0.176 of the remainder for validation
    # (0.176 ≈ 0.15 / (1 - 0.15), i.e. 15% of the full dataset).
    X_temp, X_test, y_temp, y_test = train_test_split(
        images, gaze_coords, test_size=0.15, random_state=42, shuffle=True
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.176, random_state=42, shuffle=True
    )

    print("\nDataset splits:")
    print(f"Training: {len(X_train)} samples")
    print(f"Validation: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")

    # Build and compile model
    model = ImprovedGazeModel(input_shape=X_train.shape[1:])
    model.build_simple_model()
    model.compile_model(learning_rate=args.lr)

    print("\nModel architecture:")
    model.model.summary()

    # Create data generators (augmentation only for training).
    train_gen = create_augmented_generator(X_train, y_train, args.batch_size, augment=True)
    val_gen = create_augmented_generator(X_val, y_val, args.batch_size, augment=False)

    # Guard against datasets smaller than one batch: zero steps would make
    # model.fit fail / never consume the generators.
    steps_per_epoch = max(1, len(X_train) // args.batch_size)
    validation_steps = max(1, len(X_val) // args.batch_size)

    # Train
    print("\nStarting training...")
    history = model.model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_gen,
        validation_steps=validation_steps,
        epochs=args.epochs,
        callbacks=_build_callbacks(),
        verbose=1
    )

    # Evaluate on the held-out test set.
    print("\nEvaluating on test set...")
    test_loss, test_mae = model.model.evaluate(X_test, y_test, batch_size=args.batch_size)
    predictions = model.model.predict(X_test, batch_size=args.batch_size)

    euclidean_errors = _report_test_errors(
        predictions, y_test, screen_width, screen_height, test_loss, test_mae
    )

    # Persist training curves and inference configuration.
    _plot_history(history)
    _save_config(model, screen_width, screen_height, test_loss, test_mae, euclidean_errors)

    print("\nModel saved to: best_gaze_model_improved.keras")
    print("Config saved to: model_config_improved.json")
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()