Olof Astrand committed on
Commit
c17a15c
·
1 Parent(s): 0eda6ef

Minor updates

Browse files
Files changed (3) hide show
  1. README.md +51 -2
  2. inference_claude.py +319 -0
  3. training_faces.py +349 -0
README.md CHANGED
@@ -13,8 +13,11 @@ Creating a dataset
13
  ==================
14
  collector.py
15
  collector.html
16
- When creating a dataset in the browser you will have to convert ti with
17
  convert.py
 
 
 
18
 
19
 
20
  Training from web based dataset
@@ -23,10 +26,56 @@ training.py
23
 
24
  Training from OpenCV created dataset
25
  ==============
 
 
26
  training_deepseek.py
27
 
28
 
29
  Inference
30
  ==========
31
  inference.py
32
- This does not work in a wsl environment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  ==================
14
  collector.py
15
  collector.html
16
+ When creating a dataset in the browser you will have to convert it with
17
  convert.py
18
+ When creating dataset with collector you have to preprocess it with
19
+ preprocess.py
20
+
21
 
22
 
23
  Training from web based dataset
 
26
 
27
  Training from OpenCV created dataset
28
  ==============
29
+ training_faces.py
30
+
31
  training_deepseek.py
32
 
33
 
34
  Inference
35
  ==========
36
  inference.py
37
+ This does not work in a wsl environment as we cannot access the camera.
38
+
39
+
40
+ Mobilenet used
41
+ Model architecture:
42
+ Model: "functional"
43
+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
44
+ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃
45
+ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
46
+ │ input_layer (InputLayer) │ (None, 60, 80, 3) │ 0 │
47
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
48
+ │ conv2d (Conv2D) │ (None, 60, 80, 32) │ 896 │
49
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
50
+ │ max_pooling2d (MaxPooling2D) │ (None, 30, 40, 32) │ 0 │
51
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
52
+ │ dropout (Dropout) │ (None, 30, 40, 32) │ 0 │
53
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
54
+ │ conv2d_1 (Conv2D) │ (None, 30, 40, 64) │ 18,496 │
55
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
56
+ │ max_pooling2d_1 (MaxPooling2D) │ (None, 15, 20, 64) │ 0 │
57
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
58
+ │ dropout_1 (Dropout) │ (None, 15, 20, 64) │ 0 │
59
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
60
+ │ conv2d_2 (Conv2D) │ (None, 15, 20, 128) │ 73,856 │
61
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
62
+ │ max_pooling2d_2 (MaxPooling2D) │ (None, 7, 10, 128) │ 0 │
63
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
64
+ │ dropout_2 (Dropout) │ (None, 7, 10, 128) │ 0 │
65
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
66
+ │ global_average_pooling2d │ (None, 128) │ 0 │
67
+ │ (GlobalAveragePooling2D) │ │ │
68
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
69
+ │ dense (Dense) │ (None, 128) │ 16,512 │
70
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
71
+ │ dropout_3 (Dropout) │ (None, 128) │ 0 │
72
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
73
+ │ dense_1 (Dense) │ (None, 64) │ 8,256 │
74
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
75
+ │ dropout_4 (Dropout) │ (None, 64) │ 0 │
76
+ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
77
+ │ dense_2 (Dense) │ (None, 2) │ 130 │
78
+ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
79
+ Total params: 118,146 (461.51 KB)
80
+ Trainable params: 118,146 (461.51 KB)
81
+
inference_claude.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ import tkinter as tk
6
+
7
+ class GazeInference:
8
+ def __init__(self, model_path):
9
+ # Load model
10
+ self.model = keras.models.load_model(model_path)
11
+
12
+ # Initialize face cascade classifier
13
+ self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
14
+
15
+ # Get actual screen resolution using tkinter
16
+ self._get_screen_resolution()
17
+
18
+ # Gaze smoothing
19
+ self.smoothing = True
20
+ self.smoothing_factor = 0.2
21
+ self.last_gaze = None
22
+
23
+ # Face visualization
24
+ self.show_face = False
25
+ self.face_roi = None
26
+ self.upper_face_roi = None
27
+
28
+ # Face detection parameters
29
+ self.min_face_size = (50, 50)
30
+ self.scale_factor = 1.1
31
+ self.min_neighbors = 5
32
+
33
+ # Adjustable crop parameters (can be modified with keys)
34
+ self.crop_top = 0.05 # Start at 25% from top of face
35
+ self.crop_bottom = 0.80 # End at 55% from top of face
36
+ self.crop_sides = 0.05 # Crop 15% from each side
37
+
38
+ def _get_screen_resolution(self):
39
+ """Get primary screen resolution using tkinter"""
40
+ root = tk.Tk()
41
+ self.screen_width = root.winfo_screenwidth()
42
+ self.screen_height = root.winfo_screenheight()
43
+ root.destroy()
44
+ print(f"Screen resolution: {self.screen_width}x{self.screen_height}")
45
+
46
+ def _extract_upper_face_region(self, frame):
47
+ """Extract upper half of face region from frame"""
48
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
49
+
50
+ # Detect faces
51
+ faces = self.face_cascade.detectMultiScale(
52
+ gray,
53
+ scaleFactor=self.scale_factor,
54
+ minNeighbors=self.min_neighbors,
55
+ minSize=self.min_face_size
56
+ )
57
+
58
+ if len(faces) > 0:
59
+ # Get the largest face (assuming it's the closest/main face)
60
+ face = max(faces, key=lambda x: x[2] * x[3])
61
+ x, y, w, h = face
62
+
63
+ # Store full face for visualization
64
+ face_padding = 10
65
+ x_start = max(0, x - face_padding)
66
+ y_start = max(0, y - face_padding)
67
+ x_end = min(frame.shape[1], x + w + face_padding)
68
+ y_end = min(frame.shape[0], y + h + face_padding)
69
+ self.face_roi = frame[y_start:y_end, x_start:x_end].copy()
70
+
71
+ # Extract tight crop around eyes region to match training data
72
+ # Based on typical face proportions:
73
+ # - Eyes are typically at 40-50% down from top of face
74
+ # - Eye region height is about 20-30% of face height
75
+
76
+ # Use adjustable crop parameters
77
+ eye_region_start = int(h * self.crop_top)
78
+ eye_region_end = int(h * self.crop_bottom)
79
+
80
+ # For width, focus on central portion of face
81
+ width_crop = int(w * self.crop_sides)
82
+
83
+ # Calculate bounds for eye region
84
+ uf_x_start = max(0, x + width_crop)
85
+ uf_y_start = max(0, y + eye_region_start)
86
+ uf_x_end = min(frame.shape[1], x + w - width_crop)
87
+ uf_y_end = min(frame.shape[0], y + eye_region_end)
88
+
89
+ # Extract and resize upper face region
90
+ upper_face = frame[uf_y_start:uf_y_end, uf_x_start:uf_x_end]
91
+
92
+ # Resize to model input size (80x60)
93
+ upper_face_resized = cv2.resize(upper_face, (80, 60))
94
+ self.upper_face_roi = upper_face_resized.copy()
95
+
96
+ return upper_face_resized
97
+
98
+ return None
99
+
100
+ def _predict_gaze(self, upper_face_region):
101
+ """Predict gaze position from upper face region"""
102
+ # Preprocess
103
+ face_input = cv2.cvtColor(upper_face_region, cv2.COLOR_BGR2RGB)
104
+ face_input = face_input.astype('float32') / 255.0
105
+ face_input = np.expand_dims(face_input, axis=0)
106
+
107
+ # Predict
108
+ pred = self.model.predict(face_input, verbose=0)[0]
109
+
110
+ # Convert to screen coordinates
111
+ screen_x = int(pred[0] * self.screen_width)
112
+ screen_y = int(pred[1] * self.screen_height)
113
+
114
+ # Fix left-right inversion: flip X coordinate
115
+ # Since camera is mirrored, we need to invert the X prediction
116
+ screen_x = self.screen_width - screen_x
117
+
118
+ # Clamp to screen bounds
119
+ screen_x = max(0, min(self.screen_width - 1, screen_x))
120
+ screen_y = max(0, min(self.screen_height - 1, screen_y))
121
+
122
+ # Apply smoothing if enabled
123
+ if self.smoothing and self.last_gaze is not None:
124
+ screen_x = int(self.smoothing_factor * screen_x +
125
+ (1 - self.smoothing_factor) * self.last_gaze[0])
126
+ screen_y = int(self.smoothing_factor * screen_y +
127
+ (1 - self.smoothing_factor) * self.last_gaze[1])
128
+
129
+ self.last_gaze = (screen_x, screen_y)
130
+ return screen_x, screen_y
131
+
132
+ def _draw_gaze_cross(self, frame, x, y):
133
+ """Draw crosshair at gaze position"""
134
+ color = (0, 255, 0) # Green
135
+ size = 30
136
+ thickness = 3
137
+
138
+ # Horizontal line
139
+ cv2.line(frame, (x - size, y), (x + size, y), color, thickness)
140
+ # Vertical line
141
+ cv2.line(frame, (x, y - size), (x, y + size), color, thickness)
142
+ # Center circle
143
+ cv2.circle(frame, (x, y), 5, color, -1)
144
+
145
+ def _draw_face_roi(self, frame):
146
+ """Draw face region visualization in bottom left"""
147
+ if self.show_face and self.face_roi is not None and self.upper_face_roi is not None:
148
+ # Calculate display sizes
149
+ max_height = 200
150
+
151
+ # Display full face
152
+ face_h, face_w = self.face_roi.shape[:2]
153
+ display_h = min(max_height, self.screen_height // 4)
154
+ display_w = int(display_h * (face_w / face_h))
155
+
156
+ # Ensure we don't exceed screen dimensions
157
+ display_w = min(display_w, self.screen_width // 3)
158
+ display_h = min(display_h, self.screen_height // 3)
159
+
160
+ # Resize face for display
161
+ face_display = cv2.resize(self.face_roi, (display_w, display_h))
162
+
163
+ # Position for full face (bottom left)
164
+ face_y = self.screen_height - display_h - 10
165
+ face_x = 10
166
+
167
+ # Draw full face
168
+ try:
169
+ frame[face_y:face_y + display_h, face_x:face_x + display_w] = face_display
170
+ cv2.rectangle(frame, (face_x, face_y),
171
+ (face_x + display_w, face_y + display_h), (255, 255, 255), 2)
172
+ cv2.putText(frame, "Full Face", (face_x + 5, face_y - 5),
173
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
174
+ except:
175
+ pass
176
+
177
+ # Display upper face region
178
+ uf_h = int(display_h * 0.75) # 60:80 aspect ratio
179
+ uf_w = display_w
180
+ upper_face_display = cv2.resize(self.upper_face_roi, (uf_w, uf_h))
181
+
182
+ # Position for upper face (next to full face)
183
+ uf_x = face_x + display_w + 20
184
+ uf_y = self.screen_height - uf_h - 10
185
+
186
+ # Draw upper face
187
+ try:
188
+ frame[uf_y:uf_y + uf_h, uf_x:uf_x + uf_w] = upper_face_display
189
+ cv2.rectangle(frame, (uf_x, uf_y),
190
+ (uf_x + uf_w, uf_y + uf_h), (0, 255, 255), 2)
191
+ cv2.putText(frame, "Upper Face (Model Input)", (uf_x + 5, uf_y - 5),
192
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
193
+ except:
194
+ pass
195
+
196
+ def run(self):
197
+ """Main inference loop"""
198
+ cap = cv2.VideoCapture(0)
199
+ if not cap.isOpened():
200
+ print("Error: Could not open webcam")
201
+ return
202
+
203
+ # Create fullscreen window
204
+ cv2.namedWindow('Gaze Prediction', cv2.WND_PROP_FULLSCREEN)
205
+ cv2.setWindowProperty('Gaze Prediction', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
206
+
207
+ print("Controls:")
208
+ print("'s': Toggle Smoothing")
209
+ print("'f': Toggle Face View")
210
+ print("'q': Quit")
211
+ print("Crop Adjustment:")
212
+ print("'u'/'j': Move crop top up/down")
213
+ print("'i'/'k': Move crop bottom up/down")
214
+ print("'o'/'l': Decrease/increase side crop")
215
+ print(f"Using screen resolution: {self.screen_width}x{self.screen_height}")
216
+
217
+ while True:
218
+ ret, frame = cap.read()
219
+ if not ret:
220
+ break
221
+
222
+ # Create black canvas matching screen size
223
+ canvas = np.zeros((self.screen_height, self.screen_width, 3), dtype=np.uint8)
224
+
225
+ # Mirror the frame for more natural interaction
226
+ frame = cv2.flip(frame, 1)
227
+
228
+ # Process frame
229
+ upper_face = self._extract_upper_face_region(frame)
230
+
231
+ if upper_face is not None:
232
+ # Predict gaze
233
+ gaze_x, gaze_y = self._predict_gaze(upper_face)
234
+
235
+ # Draw gaze cross on canvas
236
+ self._draw_gaze_cross(canvas, gaze_x, gaze_y)
237
+
238
+ # Show coordinates for debugging
239
+ cv2.putText(canvas, f"Gaze: ({gaze_x}, {gaze_y})",
240
+ (20, self.screen_height - 30), cv2.FONT_HERSHEY_SIMPLEX,
241
+ 0.7, (255, 255, 0), 2)
242
+
243
+ # Draw face regions if enabled
244
+ if self.show_face:
245
+ self._draw_face_roi(canvas)
246
+ else:
247
+ # Show "no face detected" message
248
+ cv2.putText(canvas, "No face detected",
249
+ (20, self.screen_height - 30), cv2.FONT_HERSHEY_SIMPLEX,
250
+ 0.7, (0, 0, 255), 2)
251
+
252
+ # Show instructions
253
+ cv2.putText(canvas, "'s': Smoothing | 'f': Face View | 'u/j': Top | 'i/k': Bottom | 'o/l': Sides | 'q': Quit",
254
+ (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
255
+
256
+ # Show smoothing status
257
+ status = "ON" if self.smoothing else "OFF"
258
+ cv2.putText(canvas, f"Smoothing: {status}",
259
+ (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
260
+ (0, 255, 0) if self.smoothing else (0, 0, 255), 2)
261
+
262
+ # Show face view status
263
+ status = "ON" if self.show_face else "OFF"
264
+ cv2.putText(canvas, f"Face View: {status}",
265
+ (20, 110), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
266
+ (0, 255, 0) if self.show_face else (0, 0, 255), 2)
267
+
268
+ # Show screen resolution info and crop parameters
269
+ cv2.putText(canvas, f"Screen: {self.screen_width}x{self.screen_height}",
270
+ (20, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
271
+ cv2.putText(canvas, f"Crop: Top={self.crop_top:.2f} Bottom={self.crop_bottom:.2f} Sides={self.crop_sides:.2f}",
272
+ (20, 190), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 2)
273
+
274
+ # Display
275
+ cv2.imshow('Gaze Prediction', canvas)
276
+
277
+ # Handle key presses
278
+ key = cv2.waitKey(1) & 0xFF
279
+ if key == ord('q'):
280
+ break
281
+ elif key == ord('s'):
282
+ self.smoothing = not self.smoothing
283
+ print(f"Smoothing: {'ON' if self.smoothing else 'OFF'}")
284
+ elif key == ord('f'):
285
+ self.show_face = not self.show_face
286
+ print(f"Face View: {'ON' if self.show_face else 'OFF'}")
287
+ # Crop adjustment keys
288
+ elif key == ord('u'): # Move top up
289
+ self.crop_top = max(0.0, self.crop_top - 0.05)
290
+ print(f"Crop top: {self.crop_top:.2f}")
291
+ elif key == ord('j'): # Move top down
292
+ self.crop_top = min(self.crop_bottom - 0.1, self.crop_top + 0.05)
293
+ print(f"Crop top: {self.crop_top:.2f}")
294
+ elif key == ord('i'): # Move bottom up
295
+ self.crop_bottom = max(self.crop_top + 0.1, self.crop_bottom - 0.05)
296
+ print(f"Crop bottom: {self.crop_bottom:.2f}")
297
+ elif key == ord('k'): # Move bottom down
298
+ self.crop_bottom = min(1.0, self.crop_bottom + 0.05)
299
+ print(f"Crop bottom: {self.crop_bottom:.2f}")
300
+ elif key == ord('o'): # Decrease side crop
301
+ self.crop_sides = max(0.0, self.crop_sides - 0.05)
302
+ print(f"Crop sides: {self.crop_sides:.2f}")
303
+ elif key == ord('l'): # Increase side crop
304
+ self.crop_sides = min(0.4, self.crop_sides + 0.05)
305
+ print(f"Crop sides: {self.crop_sides:.2f}")
306
+
307
+ cap.release()
308
+ cv2.destroyAllWindows()
309
+
310
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--model', type=str, required=True,
        help='Path to trained gaze estimation model')
    cli_args = arg_parser.parse_args()

    # Build the estimator from the supplied model and start the loop.
    GazeInference(cli_args.model).run()
training_faces.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow import keras
3
+ from tensorflow.keras import layers
4
+ import numpy as np
5
+ import json
6
+ import cv2
7
+ from pathlib import Path
8
+ from sklearn.model_selection import train_test_split
9
+ import matplotlib.pyplot as plt
10
+
11
# Enable on-demand GPU memory allocation so TensorFlow does not reserve
# all VRAM up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Raised when devices were already initialized; report and continue.
    print(e)
19
+
20
class ImprovedGazeModel:
    """Compact CNN that regresses normalized (x, y) gaze coordinates."""

    def __init__(self, input_shape=(60, 80, 3)):
        # (height, width, channels) of the face crops fed to the network
        self.input_shape = input_shape
        self.model = None

    def build_simple_model(self):
        """Build a simpler, more effective model."""
        inputs = keras.Input(shape=self.input_shape)

        net = inputs
        # Three convolutional stages: Conv -> BatchNorm -> ReLU -> 2x2 pool
        for filters, kernel in ((16, (5, 5)), (32, (3, 3)), (64, (3, 3))):
            net = layers.Conv2D(filters, kernel, padding='same')(net)
            net = layers.BatchNormalization()(net)
            net = layers.Activation('relu')(net)
            net = layers.MaxPooling2D((2, 2))(net)

        # Regression head: two dense stages with dropout regularization
        net = layers.Flatten()(net)
        for units in (128, 64):
            net = layers.Dense(units)(net)
            net = layers.BatchNormalization()(net)
            net = layers.Activation('relu')(net)
            net = layers.Dropout(0.3)(net)

        # Linear output (no activation): normalized (x, y) coordinates
        outputs = layers.Dense(2)(net)

        self.model = keras.Model(inputs, outputs, name='gaze_model')
        return self.model

    def compile_model(self, learning_rate=0.0001):
        """Compile with better optimizer settings."""
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss='mse',
            metrics=['mae'],
        )
72
+
73
def load_preprocessed_data(data_dir):
    """Load the preprocessed face dataset.

    Reads ``metadata.json`` from *data_dir*, loads each referenced image
    from the ``images`` subdirectory, converts BGR -> RGB and scales pixels
    to [0, 1], and normalizes gaze targets to [0, 1] by the recorded
    screen size.

    Returns:
        Tuple ``(images, gaze_coords, screen_width, screen_height)`` where
        ``images`` is a float32 array of shape (N, H, W, 3) and
        ``gaze_coords`` is (N, 2).

    Raises:
        RuntimeError: if no image could be loaded (previously this crashed
            later with an opaque IndexError on ``images[0]``).
    """
    data_dir = Path(data_dir)

    # Load metadata describing the capture session
    with open(data_dir / 'metadata.json', 'r') as f:
        metadata = json.load(f)

    screen_width = metadata['screen_width']
    screen_height = metadata['screen_height']
    data_points = metadata['data_points']

    print(f"Loading {len(data_points)} data points...")
    print(f"Screen dimensions: {screen_width}x{screen_height}")

    # Load images and gaze coordinates
    images = []
    gaze_coords = []
    skipped = 0

    print("Loading images...")
    for i, point in enumerate(data_points):
        if i % 500 == 0:
            print(f"Progress: {i}/{len(data_points)}")

        img_path = data_dir / 'images' / point['image']
        img = cv2.imread(str(img_path)) if img_path.exists() else None
        if img is None:
            # Missing or unreadable file: keep going, but count it so a
            # partially broken dataset is visible instead of silent.
            skipped += 1
            continue

        # Convert to RGB and normalize
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        images.append(img.astype(np.float32) / 255.0)

        # Normalize gaze coordinates to [0, 1]
        norm_x = point['screen_x'] / screen_width
        norm_y = point['screen_y'] / screen_height
        gaze_coords.append([norm_x, norm_y])

    if not images:
        raise RuntimeError(
            f"No images could be loaded from {data_dir / 'images'}")
    if skipped:
        print(f"WARNING: skipped {skipped} missing/unreadable images")

    images = np.array(images)
    gaze_coords = np.array(gaze_coords)

    print(f"\nSuccessfully loaded {len(images)} images")
    print(f"Image shape: {images[0].shape}")
    print(f"Gaze X range: [{gaze_coords[:, 0].min():.3f}, {gaze_coords[:, 0].max():.3f}]")
    print(f"Gaze Y range: [{gaze_coords[:, 1].min():.3f}, {gaze_coords[:, 1].max():.3f}]")

    # Flag gaze targets outside the screen and clip them to [0, 1]
    x_outliers = np.sum((gaze_coords[:, 0] < 0) | (gaze_coords[:, 0] > 1))
    y_outliers = np.sum((gaze_coords[:, 1] < 0) | (gaze_coords[:, 1] > 1))
    if x_outliers > 0 or y_outliers > 0:
        print(f"WARNING: Found {x_outliers} X outliers and {y_outliers} Y outliers")
        # Clip to valid range
        gaze_coords = np.clip(gaze_coords, 0, 1)

    return images, gaze_coords, screen_width, screen_height
128
+
129
def create_augmented_generator(X, y, batch_size=32, augment=True):
    """Endless batch generator with optional photometric augmentation.

    Reshuffles the sample order once per pass over the data, then yields
    ``(batch_X, batch_y)`` copies so augmentation never mutates the
    source arrays.
    """
    order = np.arange(len(X))

    while True:
        # Fresh random ordering for every epoch
        np.random.shuffle(order)

        for lo in range(0, len(X), batch_size):
            picked = order[lo:lo + batch_size]
            xb = X[picked].copy()
            yb = y[picked].copy()

            if augment:
                for k in range(len(xb)):
                    sample = xb[k]
                    # Each augmentation fires independently with p = 0.5.
                    if np.random.random() > 0.5:
                        # Brightness scaling
                        sample = np.clip(sample * np.random.uniform(0.7, 1.3), 0, 1)
                    if np.random.random() > 0.5:
                        # Contrast stretch around the sample mean
                        gain = np.random.uniform(0.8, 1.2)
                        mu = sample.mean()
                        sample = np.clip((sample - mu) * gain + mu, 0, 1)
                    if np.random.random() > 0.5:
                        # Small additive Gaussian noise
                        sample = np.clip(
                            sample + np.random.normal(0, 0.01, sample.shape), 0, 1)
                    xb[k] = sample

            yield xb, yb
163
+
164
def visualize_data_distribution(gaze_coords, screen_width, screen_height, save_path='gaze_distribution.png'):
    """Save a two-panel figure of the gaze-point distribution to *save_path*."""
    plt.figure(figsize=(12, 6))

    # Back to pixel space so the plot matches the physical screen
    xs_px = gaze_coords[:, 0] * screen_width
    ys_px = gaze_coords[:, 1] * screen_height

    # Left panel: 2D density heatmap
    plt.subplot(1, 2, 1)
    plt.hist2d(xs_px, ys_px, bins=50, cmap='hot')
    plt.colorbar(label='Count')
    plt.xlabel('X (pixels)')
    plt.ylabel('Y (pixels)')
    plt.title('Gaze Point Distribution')
    # Screen coordinates grow downward, so flip the Y axis
    plt.gca().invert_yaxis()

    # Right panel: marginal 1D histograms
    plt.subplot(1, 2, 2)
    plt.hist(xs_px, bins=50, alpha=0.5, label='X distribution', density=True)
    plt.hist(ys_px, bins=50, alpha=0.5, label='Y distribution', density=True)
    plt.xlabel('Position (pixels)')
    plt.ylabel('Density')
    plt.title('X and Y Distributions')
    plt.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
193
+
194
def _split_data(images, gaze_coords):
    """Split into train/val/test ≈ 70/15/15 with a fixed seed."""
    X_temp, X_test, y_temp, y_test = train_test_split(
        images, gaze_coords, test_size=0.15, random_state=42, shuffle=True
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.176, random_state=42, shuffle=True  # 0.176 ≈ 0.15/(1-0.15)
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


def _build_callbacks():
    """Checkpointing, early stopping, LR schedule, and CSV logging."""
    return [
        keras.callbacks.ModelCheckpoint(
            'best_gaze_model_improved.keras',
            monitor='val_loss',
            save_best_only=True,
            verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=20,
            restore_best_weights=True,
            verbose=1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-7,
            verbose=1
        ),
        keras.callbacks.CSVLogger('training_log.csv')
    ]


def _evaluate(net, X_test, y_test, batch_size, screen_width, screen_height):
    """Evaluate on the test set, print pixel-space error statistics.

    Returns (test_loss, test_mae, euclidean_errors).
    """
    print("\nEvaluating on test set...")
    test_loss, test_mae = net.evaluate(X_test, y_test, batch_size=batch_size)

    predictions = net.predict(X_test, batch_size=batch_size)

    # Denormalize so errors are interpretable in pixels
    pred_pixels = predictions * np.array([screen_width, screen_height])
    actual_pixels = y_test * np.array([screen_width, screen_height])

    pixel_errors = np.abs(pred_pixels - actual_pixels)
    euclidean_errors = np.sqrt(np.sum((pred_pixels - actual_pixels)**2, axis=1))

    print(f"\nTest Results:")
    print(f"Loss: {test_loss:.6f}")
    print(f"MAE (normalized): {test_mae:.6f}")
    print(f"Mean X error: {pixel_errors[:, 0].mean():.1f} pixels")
    print(f"Mean Y error: {pixel_errors[:, 1].mean():.1f} pixels")
    print(f"Mean Euclidean error: {euclidean_errors.mean():.1f} pixels")
    print(f"Median Euclidean error: {np.median(euclidean_errors):.1f} pixels")
    print(f"95th percentile error: {np.percentile(euclidean_errors, 95):.1f} pixels")

    return test_loss, test_mae, euclidean_errors


def _plot_history(history):
    """Plot loss/MAE curves to improved_training_history.png."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Model Loss')
    plt.legend()
    plt.yscale('log')

    plt.subplot(1, 2, 2)
    plt.plot(history.history['mae'], label='Training MAE')
    plt.plot(history.history['val_mae'], label='Validation MAE')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.title('Model MAE')
    plt.legend()

    plt.tight_layout()
    plt.savefig('improved_training_history.png')
    plt.close()


def _save_config(model, screen_width, screen_height, test_loss, test_mae, euclidean_errors):
    """Write model/evaluation metadata to model_config_improved.json."""
    config = {
        'model_path': 'best_gaze_model_improved.keras',
        'input_shape': list(model.input_shape),
        'screen_width': int(screen_width),
        'screen_height': int(screen_height),
        'test_loss': float(test_loss),
        'test_mae': float(test_mae),
        'mean_euclidean_error': float(euclidean_errors.mean()),
        # NOTE(review): these crop fractions differ from the runtime
        # defaults in inference_claude.py (0.05/0.80/0.05) — confirm which
        # preprocessing the dataset actually used.
        'preprocessing': {
            'crop_top': 0.25,
            'crop_bottom': 0.55,
            'crop_sides': 0.15
        }
    }

    with open('model_config_improved.json', 'w') as f:
        json.dump(config, f, indent=2)


def main():
    """CLI entry point: load data, train, evaluate, and save artifacts."""
    import argparse

    parser = argparse.ArgumentParser(description='Train improved gaze model')
    parser.add_argument('--data', type=str, default='gaze_data_faces',
                        help='Preprocessed face dataset directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate')

    args = parser.parse_args()

    # Load data and visualize its distribution for sanity checking
    images, gaze_coords, screen_width, screen_height = load_preprocessed_data(args.data)
    visualize_data_distribution(gaze_coords, screen_width, screen_height)

    # Split data
    X_train, X_val, X_test, y_train, y_val, y_test = _split_data(images, gaze_coords)

    print(f"\nDataset splits:")
    print(f"Training: {len(X_train)} samples")
    print(f"Validation: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")

    # Build and compile model
    model = ImprovedGazeModel(input_shape=X_train.shape[1:])
    model.build_simple_model()
    model.compile_model(learning_rate=args.lr)

    print("\nModel architecture:")
    model.model.summary()

    # Generators: augment only the training stream
    train_gen = create_augmented_generator(X_train, y_train, args.batch_size, augment=True)
    val_gen = create_augmented_generator(X_val, y_val, args.batch_size, augment=False)

    # Guard against datasets smaller than one batch (Keras rejects 0 steps)
    steps_per_epoch = max(1, len(X_train) // args.batch_size)
    validation_steps = max(1, len(X_val) // args.batch_size)

    # Train
    print("\nStarting training...")
    history = model.model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_gen,
        validation_steps=validation_steps,
        epochs=args.epochs,
        callbacks=_build_callbacks(),
        verbose=1
    )

    # Evaluate on test set and report pixel errors
    test_loss, test_mae, euclidean_errors = _evaluate(
        model.model, X_test, y_test, args.batch_size, screen_width, screen_height)

    # Plot training history and persist the run configuration
    _plot_history(history)
    _save_config(model, screen_width, screen_height, test_loss, test_mae, euclidean_errors)

    print(f"\nModel saved to: best_gaze_model_improved.keras")
    print(f"Config saved to: model_config_improved.json")


if __name__ == "__main__":
    main()