Olof Astrand commited on
Commit
47bec77
·
1 Parent(s): a29612e

Added web inference option

Browse files
web/Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# Runtime libraries needed by opencv-python (libgl*, libsm/libx*) and
# TensorFlow/OpenMP (libgomp1). The original list installed libglib2.0-0
# twice; deduplicated here. --no-install-recommends keeps the image small,
# and the apt cache is removed in the same layer so it never bloats the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    libgl1-mesa-glx \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the pip layer is cached across code-only changes.
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files (server code + trained model weights)
COPY gaze_server.py .
COPY best_gaze_model.h5 .

# Expose port
EXPOSE 5000

# Run the server
CMD ["python", "gaze_server.py", "--host", "0.0.0.0", "--port", "5000"]
web/gaze_server.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ import cv2
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ import base64
7
+ import time
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ import logging
11
+
12
app = Flask(__name__)
CORS(app)  # Enable CORS so the browser UI (opened from file:// or another origin) can call this API

# Configure logging: INFO level, module-scoped logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
18
+
19
+ class GazeInferenceServer:
20
+ def __init__(self, model_path):
21
+ """Initialize the gaze inference server."""
22
+ self.model_path = model_path
23
+ self.model = None
24
+ self.face_cascade = None
25
+ self.eye_cascade = None
26
+
27
+ # Model parameters
28
+ self.face_size = (224, 224)
29
+ self.eye_size = (80, 60)
30
+
31
+ # Load model and cascades
32
+ self._load_model()
33
+ self._load_cascades()
34
+
35
+ logger.info("Gaze inference server initialized")
36
+
37
+ def _load_model(self):
38
+ """Load the TensorFlow model."""
39
+ try:
40
+ # Define custom objects
41
+ custom_objects = {
42
+ 'euclidean_distance_metric': self._euclidean_distance_metric,
43
+ 'mse': tf.keras.losses.MeanSquaredError(),
44
+ }
45
+
46
+ # Try to load model
47
+ try:
48
+ self.model = tf.keras.models.load_model(
49
+ self.model_path,
50
+ custom_objects=custom_objects
51
+ )
52
+ except:
53
+ # Alternative loading method
54
+ self.model = tf.keras.models.load_model(
55
+ self.model_path,
56
+ compile=False
57
+ )
58
+ self.model.compile(
59
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
60
+ loss='mse',
61
+ metrics=['mae', self._euclidean_distance_metric]
62
+ )
63
+
64
+ logger.info(f"Model loaded successfully from {self.model_path}")
65
+
66
+ except Exception as e:
67
+ logger.error(f"Failed to load model: {e}")
68
+ raise
69
+
70
+ @staticmethod
71
+ def _euclidean_distance_metric(y_true, y_pred):
72
+ """Custom metric for model."""
73
+ return tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1))
74
+
75
+ def _load_cascades(self):
76
+ """Load Haar cascades for face and eye detection."""
77
+ self.face_cascade = cv2.CascadeClassifier(
78
+ cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
79
+ )
80
+ self.eye_cascade = cv2.CascadeClassifier(
81
+ cv2.data.haarcascades + 'haarcascade_eye.xml'
82
+ )
83
+ logger.info("Haar cascades loaded")
84
+
85
+ def extract_eye_regions(self, face_image):
86
+ """Extract left and right eye regions from face image."""
87
+ gray = cv2.cvtColor(face_image, cv2.COLOR_BGR2GRAY)
88
+ eyes = self.eye_cascade.detectMultiScale(gray, 1.1, 4)
89
+
90
+ if len(eyes) >= 2:
91
+ # Sort by x-coordinate
92
+ eyes = sorted(eyes, key=lambda e: e[0])
93
+
94
+ # Extract eyes
95
+ lx, ly, lw, lh = eyes[0]
96
+ left_eye = face_image[ly:ly+lh, lx:lx+lw]
97
+ left_eye = cv2.resize(left_eye, self.eye_size)
98
+
99
+ rx, ry, rw, rh = eyes[1]
100
+ right_eye = face_image[ry:ry+rh, rx:rx+rw]
101
+ right_eye = cv2.resize(right_eye, self.eye_size)
102
+
103
+ return left_eye, right_eye, True
104
+ else:
105
+ # Fallback to approximate eye regions
106
+ h, w = face_image.shape[:2]
107
+ left_region = face_image[h//4:h//2, w//4:w//2]
108
+ right_region = face_image[h//4:h//2, w//2:3*w//4]
109
+
110
+ left_eye = cv2.resize(left_region, self.eye_size)
111
+ right_eye = cv2.resize(right_region, self.eye_size)
112
+
113
+ return left_eye, right_eye, False
114
+
115
+ def preprocess_inputs(self, face, left_eye, right_eye):
116
+ """Preprocess images for model input."""
117
+ # Normalize to [0, 1]
118
+ face = face.astype(np.float32) / 255.0
119
+ left_eye = left_eye.astype(np.float32) / 255.0
120
+ right_eye = right_eye.astype(np.float32) / 255.0
121
+
122
+ # Add batch dimension
123
+ face = np.expand_dims(face, axis=0)
124
+ left_eye = np.expand_dims(left_eye, axis=0)
125
+ right_eye = np.expand_dims(right_eye, axis=0)
126
+
127
+ return [face, left_eye, right_eye]
128
+
129
+ def predict_gaze(self, image_data, screen_width, screen_height):
130
+ """Predict gaze position from image."""
131
+ start_time = time.time()
132
+
133
+ try:
134
+ # Decode base64 image
135
+ image_bytes = base64.b64decode(image_data)
136
+ image = Image.open(BytesIO(image_bytes))
137
+ image_np = np.array(image)
138
+
139
+ # Convert RGB to BGR for OpenCV
140
+ if len(image_np.shape) == 3 and image_np.shape[2] == 3:
141
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
142
+
143
+ # Resize face image
144
+ face_resized = cv2.resize(image_np, self.face_size)
145
+
146
+ # Extract eye regions
147
+ left_eye, right_eye, eyes_found = self.extract_eye_regions(face_resized)
148
+
149
+ # Preprocess for model
150
+ inputs = self.preprocess_inputs(face_resized, left_eye, right_eye)
151
+
152
+ # Predict gaze
153
+ gaze_pred = self.model.predict(inputs, verbose=0)[0]
154
+
155
+ print(f"Raw gaze prediction: {gaze_pred}") # Debugging output
156
+
157
+ # Convert to screen coordinates
158
+ gaze_x = float(gaze_pred[0] * screen_width)
159
+ gaze_y = float(gaze_pred[1] * screen_height)
160
+
161
+ # Ensure within bounds
162
+ gaze_x = max(0, min(gaze_x, screen_width))
163
+ gaze_y = max(0, min(gaze_y, screen_height))
164
+
165
+ print(f"Predicted gaze position: ({gaze_x}, {gaze_y})") # Debugging output
166
+
167
+ inference_time = (time.time() - start_time) * 1000 # Convert to ms
168
+
169
+ return {
170
+ 'success': True,
171
+ 'gaze_position': {
172
+ 'x': gaze_x,
173
+ 'y': gaze_y
174
+ },
175
+ 'eyes_found': eyes_found,
176
+ 'inference_time': inference_time
177
+ }
178
+
179
+ except Exception as e:
180
+ logger.error(f"Prediction error: {e}")
181
+ return {
182
+ 'success': False,
183
+ 'error': str(e)
184
+ }
185
+
186
# Global server instance; stays None until create_app() loads the model.
server = None
188
+
189
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reports service status and whether the model is loaded."""
    model_ready = server is not None and server.model is not None
    return jsonify({'status': 'healthy', 'model_loaded': model_ready})
196
+
197
@app.route('/predict', methods=['POST'])
def predict():
    """Predict a gaze position from a posted face image.

    Expects JSON: {"image": <base64>, "screen_width": int, "screen_height": int}
    (screen dimensions default to 1920x1080). Returns the dict produced by
    GazeInferenceServer.predict_gaze, or an error payload with an
    appropriate HTTP status (400 bad request, 503 not ready, 500 internal).
    """
    try:
        # Guard against requests arriving before create_app() has run;
        # previously this crashed with AttributeError -> opaque 500.
        if server is None:
            return jsonify({
                'success': False,
                'error': 'Inference server not initialized'
            }), 503

        # silent=True: malformed JSON yields None (handled below as a 400)
        # instead of Flask raising before our validation runs.
        data = request.get_json(silent=True)

        if not data or 'image' not in data:
            return jsonify({
                'success': False,
                'error': 'No image data provided'
            }), 400

        # Get parameters (screen size defaults match a 1080p display).
        image_data = data['image']
        screen_width = data.get('screen_width', 1920)
        screen_height = data.get('screen_height', 1080)

        # Predict gaze
        result = server.predict_gaze(image_data, screen_width, screen_height)

        return jsonify(result)

    except Exception as e:
        logger.error(f"Prediction endpoint error: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
225
+
226
@app.route('/calibrate', methods=['POST'])
def calibrate():
    """Calibration endpoint (placeholder for future implementation)."""
    payload = {
        'success': True,
        'message': 'Calibration not yet implemented',
    }
    return jsonify(payload)
233
+
234
def create_app(model_path='best_gaze_model.h5'):
    """Create and configure the Flask app.

    Loads the gaze model into the module-level `server` singleton so the
    route handlers can reach it, then returns the (already-created) app.
    """
    global server
    server = GazeInferenceServer(model_path)
    return app
242
+
243
if __name__ == '__main__':
    import argparse
    import os

    # NOTE(review): TF_CPP_MIN_LOG_LEVEL only suppresses TensorFlow's C++
    # logging when set BEFORE tensorflow is imported; tensorflow is imported
    # at module top, so setting it here (moved as early as possible) is
    # largely ineffective — prefer exporting it in the environment
    # (e.g. the Dockerfile).
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='Gaze Inference Server')
    parser.add_argument(
        '--model',
        type=str,
        default='best_gaze_model.h5',
        help='Path to the trained model'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=5000,
        help='Port to run the server on'
    )
    parser.add_argument(
        '--host',
        type=str,
        default='0.0.0.0',
        help='Host to run the server on'
    )
    args = parser.parse_args()

    # Fail fast with a clear message if the model file is missing.
    if not os.path.exists(args.model):
        print(f"Error: Model file '{args.model}' not found!")
        # raise SystemExit instead of the interactive-only `exit()` builtin.
        raise SystemExit(1)

    # Create app (loads the model; may take a while).
    app = create_app(args.model)

    # Startup banner.
    banner = '=' * 50
    print(f"\n{banner}")
    print("Starting Gaze Inference Server")
    print(f"Model: {args.model}")
    print(f"Server: http://{args.host}:{args.port}")
    print(f"{banner}\n")

    # threaded=True lets Flask's dev server overlap /predict requests.
    app.run(
        host=args.host,
        port=args.port,
        debug=False,
        threaded=True
    )
web/gaze_tracking.html ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Gaze Tracking Interface</title>
7
+ <style>
8
+ body {
9
+ margin: 0;
10
+ padding: 0;
11
+ font-family: Arial, sans-serif;
12
+ background-color: #1a1a1a;
13
+ color: white;
14
+ overflow: hidden;
15
+ }
16
+
17
+ #container {
18
+ display: flex;
19
+ height: 100vh;
20
+ }
21
+
22
+ #video-container {
23
+ position: relative;
24
+ width: 320px;
25
+ background-color: #2a2a2a;
26
+ padding: 20px;
27
+ }
28
+
29
+ #video {
30
+ width: 100%;
31
+ height: 240px;
32
+ background-color: #000;
33
+ border: 2px solid #444;
34
+ border-radius: 8px;
35
+ }
36
+
37
+ #canvas {
38
+ display: none;
39
+ }
40
+
41
+ #gaze-screen {
42
+ flex: 1;
43
+ position: relative;
44
+ background-color: #000;
45
+ cursor: none;
46
+ }
47
+
48
+ #gaze-cursor {
49
+ position: absolute;
50
+ width: 40px;
51
+ height: 40px;
52
+ pointer-events: none;
53
+ transition: transform 0.1s ease-out;
54
+ transform: translate(-50%, -50%);
55
+ }
56
+
57
+ .crosshair {
58
+ position: absolute;
59
+ background-color: #00ff00;
60
+ }
61
+
62
+ .crosshair-h {
63
+ width: 40px;
64
+ height: 3px;
65
+ top: 50%;
66
+ left: 0;
67
+ transform: translateY(-50%);
68
+ }
69
+
70
+ .crosshair-v {
71
+ width: 3px;
72
+ height: 40px;
73
+ left: 50%;
74
+ top: 0;
75
+ transform: translateX(-50%);
76
+ }
77
+
78
+ .center-dot {
79
+ position: absolute;
80
+ width: 10px;
81
+ height: 10px;
82
+ background-color: #ff0000;
83
+ border: 2px solid #fff;
84
+ border-radius: 50%;
85
+ top: 50%;
86
+ left: 50%;
87
+ transform: translate(-50%, -50%);
88
+ }
89
+
90
+ #trail {
91
+ position: absolute;
92
+ top: 0;
93
+ left: 0;
94
+ width: 100%;
95
+ height: 100%;
96
+ pointer-events: none;
97
+ }
98
+
99
+ .controls {
100
+ margin-top: 20px;
101
+ }
102
+
103
+ button {
104
+ background-color: #4CAF50;
105
+ border: none;
106
+ color: white;
107
+ padding: 10px 20px;
108
+ margin: 5px;
109
+ cursor: pointer;
110
+ border-radius: 4px;
111
+ font-size: 14px;
112
+ transition: background-color 0.3s;
113
+ }
114
+
115
+ button:hover {
116
+ background-color: #45a049;
117
+ }
118
+
119
+ button:disabled {
120
+ background-color: #666;
121
+ cursor: not-allowed;
122
+ }
123
+
124
+ #status {
125
+ margin-top: 20px;
126
+ padding: 10px;
127
+ background-color: #333;
128
+ border-radius: 4px;
129
+ font-size: 14px;
130
+ }
131
+
132
+ .status-connected {
133
+ color: #4CAF50;
134
+ }
135
+
136
+ .status-disconnected {
137
+ color: #f44336;
138
+ }
139
+
140
+ .info {
141
+ margin-top: 20px;
142
+ font-size: 12px;
143
+ color: #888;
144
+ }
145
+
146
+ #fps {
147
+ position: absolute;
148
+ top: 10px;
149
+ left: 10px;
150
+ background-color: rgba(0, 0, 0, 0.7);
151
+ padding: 5px 10px;
152
+ border-radius: 4px;
153
+ font-size: 14px;
154
+ }
155
+
156
+ #coordinates {
157
+ position: absolute;
158
+ top: 40px;
159
+ left: 10px;
160
+ background-color: rgba(0, 0, 0, 0.7);
161
+ padding: 5px 10px;
162
+ border-radius: 4px;
163
+ font-size: 14px;
164
+ }
165
+
166
+ .face-box {
167
+ position: absolute;
168
+ border: 2px solid #00ff00;
169
+ pointer-events: none;
170
+ }
171
+
172
+ .eye-box {
173
+ position: absolute;
174
+ border: 2px solid #ffff00;
175
+ pointer-events: none;
176
+ }
177
+
178
+ #smoothing-slider {
179
+ width: 100%;
180
+ margin-top: 10px;
181
+ }
182
+
183
+ .slider-container {
184
+ margin-top: 20px;
185
+ }
186
+
187
+ .slider-label {
188
+ font-size: 12px;
189
+ color: #888;
190
+ margin-bottom: 5px;
191
+ }
192
+ </style>
193
+ </head>
194
+ <body>
195
+ <div id="container">
196
+ <div id="video-container">
197
+ <video id="video" autoplay></video>
198
+ <canvas id="canvas"></canvas>
199
+
200
+ <div class="controls">
201
+ <button id="startBtn">Start Tracking</button>
202
+ <button id="stopBtn" disabled>Stop Tracking</button>
203
+ <button id="calibrateBtn">Calibrate</button>
204
+ </div>
205
+
206
+ <div id="status" class="status-disconnected">
207
+ Status: Not connected
208
+ </div>
209
+
210
+ <div class="slider-container">
211
+ <div class="slider-label">Smoothing: <span id="smoothing-value">5</span></div>
212
+ <input type="range" id="smoothing-slider" min="1" max="20" value="5">
213
+ </div>
214
+
215
+ <div class="info">
216
+ <p>Face Detection: <span id="face-status">Not detected</span></p>
217
+ <p>Model Inference: <span id="inference-time">0</span> ms</p>
218
+ <p>Server: <span id="server-url">http://localhost:5000</span></p>
219
+ </div>
220
+ </div>
221
+
222
+ <div id="gaze-screen">
223
+ <canvas id="trail"></canvas>
224
+ <div id="gaze-cursor">
225
+ <div class="crosshair crosshair-h"></div>
226
+ <div class="crosshair crosshair-v"></div>
227
+ <div class="center-dot"></div>
228
+ </div>
229
+ <div id="fps">FPS: 0</div>
230
+ <div id="coordinates">X: 0, Y: 0</div>
231
+ </div>
232
+ </div>
233
+
234
+ <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
235
+ <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/blazeface"></script>
236
+ <script>
237
class GazeTracker {
    /**
     * Browser-side controller: captures webcam frames, detects the face
     * locally with BlazeFace, posts the face crop to the Flask server for
     * gaze inference, and renders the smoothed gaze as a crosshair + trail.
     */
    constructor() {
        this.video = document.getElementById('video');
        this.canvas = document.getElementById('canvas');
        this.ctx = this.canvas.getContext('2d');
        this.trailCanvas = document.getElementById('trail');
        this.trailCtx = this.trailCanvas.getContext('2d');

        this.gazeCursor = document.getElementById('gaze-cursor');
        this.startBtn = document.getElementById('startBtn');
        this.stopBtn = document.getElementById('stopBtn');
        this.calibrateBtn = document.getElementById('calibrateBtn');
        this.smoothingSlider = document.getElementById('smoothing-slider');

        this.isTracking = false;
        this.faceModel = null;
        this.serverUrl = 'http://localhost:5000';

        // Gaze position and moving-average smoothing state.
        this.currentGaze = { x: window.innerWidth / 2, y: window.innerHeight / 2 };
        this.gazeHistory = [];
        this.smoothingWindow = 5;

        // Kalman filter created lazily so #gaze-screen has its final size.
        this.kalmanFilter = null;

        // Trail rendering state.
        this.trailPoints = [];
        this.maxTrailLength = 30;

        // FPS bookkeeping.
        this.lastTime = performance.now();
        this.frameCount = 0;
        this.fps = 0;

        this.setupEventListeners();
        this.resizeTrailCanvas();
        window.addEventListener('resize', () => this.resizeTrailCanvas());

        // Initialize Kalman filter after a short delay to ensure layout settled.
        setTimeout(() => {
            this.kalmanFilter = this.initKalmanFilter();
        }, 100);
    }

    /** Build a fresh per-axis 1-D Kalman filter, centred on the gaze screen. */
    initKalmanFilter() {
        const gazeScreen = document.getElementById('gaze-screen');
        const initialX = gazeScreen ? gazeScreen.offsetWidth / 2 : window.innerWidth / 2;
        const initialY = gazeScreen ? gazeScreen.offsetHeight / 2 : window.innerHeight / 2;

        return {
            x: { estimate: initialX, uncertainty: 1000 },
            y: { estimate: initialY, uncertainty: 1000 },
            processNoise: 1,
            measurementNoise: 25
        };
    }

    /**
     * One predict/update step of the per-axis scalar Kalman filter.
     * Returns the updated estimate for the given axis ('x' or 'y').
     */
    kalmanUpdate(axis, measurement) {
        const filter = this.kalmanFilter[axis];

        // Reject invalid measurements, keeping the last estimate.
        if (isNaN(measurement) || !isFinite(measurement)) {
            console.warn(`Invalid measurement for ${axis}: ${measurement}`);
            return filter.estimate;
        }

        // BUG FIX: processNoise/measurementNoise live on this.kalmanFilter,
        // not on the per-axis entry. The original read filter.processNoise
        // (undefined), so uncertainty became NaN every update and the filter
        // silently degenerated into a pass-through via the NaN reset below.
        filter.uncertainty += this.kalmanFilter.processNoise;

        // Update step.
        const gain = filter.uncertainty / (filter.uncertainty + this.kalmanFilter.measurementNoise);
        filter.estimate = filter.estimate + gain * (measurement - filter.estimate);
        filter.uncertainty = (1 - gain) * filter.uncertainty;

        // Defensive NaN reset (kept as a safety net).
        if (isNaN(filter.estimate) || !isFinite(filter.estimate)) {
            console.warn(`Kalman filter produced NaN for ${axis}, resetting...`);
            filter.estimate = measurement;
            filter.uncertainty = 1000;
        }

        return filter.estimate;
    }

    /** Match the trail canvas pixel size to the gaze screen element. */
    resizeTrailCanvas() {
        const gazeScreen = document.getElementById('gaze-screen');
        this.trailCanvas.width = gazeScreen.offsetWidth;
        this.trailCanvas.height = gazeScreen.offsetHeight;
    }

    /** Wire up buttons, debug hotkeys ('t' test jump, 'k' Kalman toggle), slider. */
    setupEventListeners() {
        this.startBtn.addEventListener('click', () => this.start());
        this.stopBtn.addEventListener('click', () => this.stop());
        this.calibrateBtn.addEventListener('click', () => this.calibrate());

        document.addEventListener('keypress', (e) => {
            if (e.key === 't' || e.key === 'T') {
                // 't': jump the cursor to a random point to test rendering.
                console.log('Testing cursor movement...');
                const testX = Math.random() * window.innerWidth;
                const testY = Math.random() * window.innerHeight;
                this.updateGazePosition({ x: testX, y: testY });
            } else if (e.key === 'k' || e.key === 'K') {
                // 'k': toggle Kalman filtering on/off for comparison.
                if (this.kalmanFilter) {
                    this.kalmanFilter = null;
                    console.log('Kalman filter disabled');
                    alert('Kalman filter disabled - using simple averaging only');
                } else {
                    this.kalmanFilter = this.initKalmanFilter();
                    console.log('Kalman filter enabled');
                    alert('Kalman filter enabled');
                }
            }
        });

        this.smoothingSlider.addEventListener('input', (e) => {
            this.smoothingWindow = parseInt(e.target.value);
            document.getElementById('smoothing-value').textContent = this.smoothingWindow;
            this.gazeHistory = [];  // reset so stale samples don't lag the new window
        });
    }

    /** Acquire the camera, load BlazeFace, verify the server, start the loop. */
    async start() {
        try {
            const stream = await navigator.mediaDevices.getUserMedia({
                video: { width: 640, height: 480 }
            });
            this.video.srcObject = stream;

            // Wait for metadata unless it already arrived (readyState >= HAVE_METADATA);
            // the original always awaited and could hang if the event had fired
            // before the handler was attached.
            if (this.video.readyState < 1) {
                await new Promise(resolve => {
                    this.video.onloadedmetadata = resolve;
                });
            }

            // Size the capture canvas to the actual video resolution.
            this.canvas.width = this.video.videoWidth;
            this.canvas.height = this.video.videoHeight;

            // Load the face detection model once and reuse it across restarts.
            if (!this.faceModel) {
                this.updateStatus('Loading face detection model...', false);
                this.faceModel = await blazeface.load();
            }

            // Check server connection before declaring tracking active.
            await this.checkServerConnection();

            this.isTracking = true;
            this.startBtn.disabled = true;
            this.stopBtn.disabled = false;

            this.updateStatus('Tracking active', true);
            this.trackGaze();

        } catch (error) {
            console.error('Error starting tracking:', error);
            this.updateStatus('Error: ' + error.message, false);
        }
    }

    /** Stop the loop and release the camera. */
    stop() {
        this.isTracking = false;

        if (this.video.srcObject) {
            this.video.srcObject.getTracks().forEach(track => track.stop());
            // Detach the dead stream so the element doesn't keep referencing it.
            this.video.srcObject = null;
        }

        this.startBtn.disabled = false;
        this.stopBtn.disabled = true;
        this.updateStatus('Tracking stopped', false);
    }

    /** Ping /health; throws a user-friendly error if the server is unreachable. */
    async checkServerConnection() {
        try {
            const response = await fetch(`${this.serverUrl}/health`);
            if (!response.ok) throw new Error('Server not responding');
            return true;
        } catch (error) {
            throw new Error('Cannot connect to inference server. Make sure the Python server is running.');
        }
    }

    /** Main loop: detect face, crop, send to server, update cursor; reschedules itself. */
    async trackGaze() {
        if (!this.isTracking) return;

        const startTime = performance.now();

        // Capture the current video frame onto the hidden canvas.
        this.ctx.drawImage(this.video, 0, 0);

        // Detect faces (second arg false = don't flip horizontally).
        const predictions = await this.faceModel.estimateFaces(
            this.canvas,
            false
        );

        if (predictions.length > 0) {
            const face = predictions[0];

            document.getElementById('face-status').textContent = 'Detected';

            // Face bounding box from BlazeFace.
            const [x1, y1] = face.topLeft;
            const [x2, y2] = face.bottomRight;
            const width = x2 - x1;
            const height = y2 - y1;

            // Pad the crop by 20% of the larger side, clamped to the frame.
            const padding = Math.max(width, height) * 0.2;
            const faceX = Math.max(0, x1 - padding);
            const faceY = Math.max(0, y1 - padding);
            const faceWidth = Math.min(this.canvas.width - faceX, width + 2 * padding);
            const faceHeight = Math.min(this.canvas.height - faceY, height + 2 * padding);

            const faceImageData = this.ctx.getImageData(faceX, faceY, faceWidth, faceHeight);

            // Server does the actual gaze inference on the face crop.
            const gazePosition = await this.sendToServer(faceImageData, {
                x: faceX,
                y: faceY,
                width: faceWidth,
                height: faceHeight
            });

            if (gazePosition) {
                this.updateGazePosition(gazePosition);
            }

        } else {
            document.getElementById('face-status').textContent = 'Not detected';
        }

        this.updatePerformanceMetrics(startTime);

        // Continue tracking on the next animation frame.
        requestAnimationFrame(() => this.trackGaze());
    }

    /**
     * POST the face crop (as base64 JPEG) + screen size to /predict.
     * Returns {x, y} on success, or null on any failure.
     */
    async sendToServer(imageData, faceRect) {
        try {
            // Convert ImageData to base64 via an off-screen canvas.
            const tempCanvas = document.createElement('canvas');
            tempCanvas.width = imageData.width;
            tempCanvas.height = imageData.height;
            const tempCtx = tempCanvas.getContext('2d');
            tempCtx.putImageData(imageData, 0, 0);

            // Strip the "data:image/jpeg;base64," prefix; server expects raw base64.
            const base64Image = tempCanvas.toDataURL('image/jpeg', 0.8).split(',')[1];

            // Use the gaze screen element's size, not the whole window.
            const gazeScreen = document.getElementById('gaze-screen');
            const screenWidth = gazeScreen.offsetWidth;
            const screenHeight = gazeScreen.offsetHeight;

            console.log('Sending screen dimensions:', { screenWidth, screenHeight });

            const response = await fetch(`${this.serverUrl}/predict`, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({
                    image: base64Image,
                    face_rect: faceRect,
                    screen_width: screenWidth,
                    screen_height: screenHeight
                })
            });

            if (!response.ok) throw new Error('Server error');

            const data = await response.json();

            console.log('Received gaze position:', data.gaze_position);

            document.getElementById('inference-time').textContent =
                data.inference_time ? data.inference_time.toFixed(1) : '0';

            return data.gaze_position;

        } catch (error) {
            console.error('Error sending to server:', error);
            return null;
        }
    }

    /** Smooth (moving average + optional Kalman), clamp, and render a new gaze sample. */
    updateGazePosition(position) {
        if (!position || isNaN(position.x) || isNaN(position.y)) {
            console.error('Invalid position received:', position);
            return;
        }

        // Maintain a sliding window of recent samples.
        this.gazeHistory.push(position);
        if (this.gazeHistory.length > this.smoothingWindow) {
            this.gazeHistory.shift();
        }

        let smoothedX, smoothedY;

        if (this.gazeHistory.length > 0) {
            // Moving average over the window.
            const avgX = this.gazeHistory.reduce((sum, p) => sum + p.x, 0) / this.gazeHistory.length;
            const avgY = this.gazeHistory.reduce((sum, p) => sum + p.y, 0) / this.gazeHistory.length;

            // Kalman filter on top of the average, when enabled.
            if (this.kalmanFilter) {
                smoothedX = this.kalmanUpdate('x', avgX);
                smoothedY = this.kalmanUpdate('y', avgY);

                // Fallback if Kalman still produces NaN.
                if (isNaN(smoothedX) || isNaN(smoothedY)) {
                    console.warn('Kalman filter failed, using average');
                    smoothedX = avgX;
                    smoothedY = avgY;
                }
            } else {
                smoothedX = avgX;
                smoothedY = avgY;
            }
        } else {
            smoothedX = position.x;
            smoothedY = position.y;
        }

        // Clamp to the gaze screen bounds.
        const gazeScreen = document.getElementById('gaze-screen');
        smoothedX = Math.max(0, Math.min(smoothedX, gazeScreen.offsetWidth));
        smoothedY = Math.max(0, Math.min(smoothedY, gazeScreen.offsetHeight));

        console.log('Updating gaze position:', {
            raw: position,
            smoothed: { x: smoothedX, y: smoothedY },
            screenBounds: {
                width: gazeScreen.offsetWidth,
                height: gazeScreen.offsetHeight
            }
        });

        // Move the crosshair cursor.
        this.currentGaze = { x: smoothedX, y: smoothedY };
        this.gazeCursor.style.left = `${smoothedX}px`;
        this.gazeCursor.style.top = `${smoothedY}px`;

        document.getElementById('coordinates').textContent =
            `X: ${Math.round(smoothedX)}, Y: ${Math.round(smoothedY)}`;

        this.updateTrail(smoothedX, smoothedY);
    }

    /** Append a trail point and redraw the fading green trail. */
    updateTrail(x, y) {
        this.trailPoints.push({ x, y, time: Date.now() });

        if (this.trailPoints.length > this.maxTrailLength) {
            this.trailPoints.shift();
        }

        this.trailCtx.clearRect(0, 0, this.trailCanvas.width, this.trailCanvas.height);

        // Each segment is stroked as its own path so it can carry its own
        // alpha (fading toward the tail). The original also opened an outer
        // path with a moveTo that was never stroked — removed as dead code.
        if (this.trailPoints.length > 1) {
            for (let i = 1; i < this.trailPoints.length; i++) {
                const point = this.trailPoints[i];
                const prevPoint = this.trailPoints[i - 1];

                const alpha = i / this.trailPoints.length;
                this.trailCtx.strokeStyle = `rgba(0, 255, 0, ${alpha * 0.5})`;
                this.trailCtx.lineWidth = 2;

                this.trailCtx.beginPath();
                this.trailCtx.moveTo(prevPoint.x, prevPoint.y);
                this.trailCtx.lineTo(point.x, point.y);
                this.trailCtx.stroke();
            }
        }
    }

    /** Refresh the FPS counter once per second. (Unused frameTime local removed.) */
    updatePerformanceMetrics(startTime) {
        const endTime = performance.now();

        this.frameCount++;
        if (endTime - this.lastTime >= 1000) {
            this.fps = this.frameCount;
            this.frameCount = 0;
            this.lastTime = endTime;

            document.getElementById('fps').textContent = `FPS: ${this.fps}`;
        }
    }

    /** Update the status line text and connected/disconnected styling. */
    updateStatus(message, isConnected) {
        const statusEl = document.getElementById('status');
        statusEl.textContent = `Status: ${message}`;
        statusEl.className = isConnected ? 'status-connected' : 'status-disconnected';
    }

    /** Calibration placeholder — mirrors the server's /calibrate stub. */
    async calibrate() {
        alert('Calibration feature coming soon!');
    }
}
658
+
659
// Construct the tracker once the DOM is ready; kept in a global for console debugging.
let tracker;
window.addEventListener('DOMContentLoaded', function () {
    tracker = new GazeTracker();
});
664
+ </script>
665
+ </body>
666
+ </html>
web/readme.md ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gaze Tracking Web Interface
2
+
3
+ This system provides a web-based interface for real-time gaze tracking using your trained TensorFlow model. It uses the browser's webcam for face detection and communicates with a Python Flask server for gaze inference.
4
+
5
+ ## Components
6
+
7
+ 1. **HTML Interface** (`gaze_tracking.html`): Web-based UI with webcam capture and gaze visualization
8
+ 2. **Flask Server** (`gaze_server.py`): Python backend that runs your TensorFlow model
9
+ 3. **Face Detection**: Uses TensorFlow.js BlazeFace in the browser + OpenCV Haar cascades on the server
10
+
11
+ ## Features
12
+
13
+ - Real-time face detection in the browser
14
+ - Smooth gaze tracking with Kalman filtering
15
+ - Visual gaze trail
16
+ - FPS and performance monitoring
17
+ - Adjustable smoothing parameters
18
+ - Full-screen gaze visualization
19
+
20
+ ## Setup Instructions
21
+
22
+ ### 1. Install Python Dependencies
23
+
24
+ ```bash
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ ### 2. Start the Flask Server
29
+
30
+ ```bash
31
+ python gaze_server.py --model best_gaze_model.h5 --port 5000
32
+ ```
33
+
34
+ Options:
35
+ - `--model`: Path to your trained model (default: `best_gaze_model.h5`)
36
+ - `--port`: Server port (default: 5000)
37
+ - `--host`: Server host (default: 0.0.0.0)
38
+
39
+ ### 3. Open the HTML Interface
40
+
41
+ 1. Open `gaze_tracking.html` in a modern web browser (Chrome/Firefox/Edge)
42
+ 2. Allow camera access when prompted
43
+ 3. Click "Start Tracking" to begin
44
+
45
+ ## How It Works
46
+
47
+ 1. **Face Detection**: The browser uses BlazeFace (TensorFlow.js) to detect faces in real-time
48
+ 2. **Face Extraction**: When a face is detected, the face region is extracted and sent to the server
49
+ 3. **Eye Detection**: The server uses OpenCV to detect eye regions within the face
50
+ 4. **Model Inference**: Your trained model processes the face and eye images to predict gaze coordinates
51
+ 5. **Smoothing**: The browser applies moving average and Kalman filtering for smooth cursor movement
52
+ 6. **Visualization**: The gaze position is displayed as a crosshair with a trail effect
53
+
54
+ ## Architecture
55
+
56
+ ```
57
+ Browser (Client) Python Server
58
+ ┌─────────────────┐ ┌──────────────────┐
59
+ │ │ │ │
60
+ │ Webcam Feed │ │ TensorFlow │
61
+ │ ↓ │ │ Gaze Model │
62
+ │ Face Detection │ HTTP POST │ ↑ │
63
+ │ (BlazeFace) │ →→→→→→→→→→→→ │ Face & Eyes │
64
+ │ ↓ │ (Base64 img) │ Processing │
65
+ │ Send Face ROI │ │ ↓ │
66
+ │ ↓ │ ←←←←←←←←←←←← │ Gaze Position │
67
+ │ Smoothing & │ (JSON resp) │ Prediction │
68
+ │ Visualization │ │ │
69
+ │ │ │ │
70
+ └─────────────────┘ └──────────────────┘
71
+ ```
72
+
73
+ ## Controls
74
+
75
+ - **Start/Stop Tracking**: Control gaze tracking
76
+ - **Smoothing Slider**: Adjust smoothing window (1-20 frames)
77
+ - **Calibrate**: (Coming soon) Calibration for improved accuracy
78
+
79
+ ## Performance Tips
80
+
81
+ 1. **Lighting**: Ensure good, even lighting on your face
82
+ 2. **Position**: Sit at a comfortable distance from the camera
83
+ 3. **Stability**: Keep your head relatively stable for best results
84
+ 4. **Browser**: Use Chrome or Firefox for best performance
85
+
86
+ ## Troubleshooting
87
+
88
+ ### Server Won't Start
89
+ - Check if the model file exists at the specified path
90
+ - Ensure all Python dependencies are installed
91
+ - Check if port 5000 is available
92
+
93
+ ### No Face Detection
94
+ - Ensure adequate lighting
95
+ - Check camera permissions in browser
96
+ - Try adjusting your distance from the camera
97
+
98
+ ### Poor Tracking Accuracy
99
+ - The model may need calibration for your specific setup
100
+ - Try adjusting the smoothing parameter
101
+ - Ensure eyes are clearly visible to the camera
102
+
103
+ ## API Endpoints
104
+
105
+ - `GET /health`: Health check
106
+ - `POST /predict`: Gaze prediction endpoint
107
+ - Request: `{ image: base64, screen_width: int, screen_height: int }`
108
+ - Response: `{ gaze_position: {x, y}, inference_time: float }`
109
+
110
+ ## Future Enhancements
111
+
112
+ - User-specific calibration system
113
+ - Multi-face tracking support
114
+ - Gaze heatmap visualization
115
+ - Recording and playback features
116
+ - WebSocket support for lower latency
117
+
118
+ ## Security Notes
119
+
120
+ - The server runs locally by default
121
+ - For remote access, consider adding authentication
122
+ - Use HTTPS in production environments
web/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Minimal requirements, compatible with Python 3.11+ (the Dockerfile uses python:3.11-slim; also works on 3.12)
2
+ flask>=3.0.0
3
+ flask-cors>=4.0.0
4
+ tensorflow>=2.15.0
5
+ opencv-python>=4.9.0.80
6
+ numpy>=1.26.2
7
+ pillow>=10.1.0