# gaze_test/web/gaze_server.py
# Author: Olof Astrand
# Added web inference option (commit 47bec77)
from flask import Flask, request, jsonify
from flask_cors import CORS
import cv2
import numpy as np
import tensorflow as tf
import base64
import time
from io import BytesIO
from PIL import Image
import logging
app = Flask(__name__)
CORS(app) # Enable CORS for all routes so browser clients on other origins can POST frames
# Configure module-wide logging; INFO level surfaces startup/shutdown messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GazeInferenceServer:
    """Serves gaze-point predictions from a trained Keras model.

    Loads the model and the OpenCV Haar cascades once at construction,
    then exposes :meth:`predict_gaze` for per-frame inference on base64
    encoded webcam images.
    """

    def __init__(self, model_path):
        """Initialize the gaze inference server.

        Args:
            model_path: Filesystem path to the trained Keras model.

        Raises:
            Exception: Propagated from model loading if the file is
                missing or incompatible (failures should surface at
                startup, not on the first request).
        """
        self.model_path = model_path
        self.model = None
        self.face_cascade = None
        self.eye_cascade = None
        # Model input sizes as (width, height) — the order cv2.resize expects.
        self.face_size = (224, 224)
        self.eye_size = (80, 60)
        # Load eagerly so a bad model path fails fast at construction.
        self._load_model()
        self._load_cascades()
        logger.info("Gaze inference server initialized")

    def _load_model(self):
        """Load the TensorFlow model, recompiling as a fallback.

        First tries loading with the custom objects the model may have
        been saved with; if that fails, loads with ``compile=False`` and
        recompiles with equivalent loss/metrics.
        """
        try:
            # Custom objects referenced by the saved training config.
            custom_objects = {
                'euclidean_distance_metric': self._euclidean_distance_metric,
                'mse': tf.keras.losses.MeanSquaredError(),
            }
            try:
                self.model = tf.keras.models.load_model(
                    self.model_path,
                    custom_objects=custom_objects
                )
            except Exception as load_err:
                # Was a bare `except:`; keep the fallback but record why
                # the primary load path failed instead of hiding it.
                logger.warning(
                    "Primary model load failed (%s); retrying with compile=False",
                    load_err
                )
                self.model = tf.keras.models.load_model(
                    self.model_path,
                    compile=False
                )
                self.model.compile(
                    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                    loss='mse',
                    metrics=['mae', self._euclidean_distance_metric]
                )
            logger.info(f"Model loaded successfully from {self.model_path}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    @staticmethod
    def _euclidean_distance_metric(y_true, y_pred):
        """Custom metric: Euclidean distance between target and prediction."""
        return tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1))

    def _load_cascades(self):
        """Load Haar cascades for face and eye detection."""
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )
        self.eye_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_eye.xml'
        )
        logger.info("Haar cascades loaded")

    def extract_eye_regions(self, face_image):
        """Extract left and right eye regions from a face image.

        Args:
            face_image: BGR face crop (numpy array).

        Returns:
            Tuple ``(left_eye, right_eye, eyes_found)`` where the eyes are
            resized to ``self.eye_size`` and ``eyes_found`` is False when
            the cascade found fewer than two eyes and fixed fallback
            regions were used instead.
        """
        gray = cv2.cvtColor(face_image, cv2.COLOR_BGR2GRAY)
        eyes = self.eye_cascade.detectMultiScale(gray, 1.1, 4)
        if len(eyes) >= 2:
            # Sort by x-coordinate: leftmost detection is the left eye.
            eyes = sorted(eyes, key=lambda e: e[0])
            lx, ly, lw, lh = eyes[0]
            left_eye = face_image[ly:ly+lh, lx:lx+lw]
            left_eye = cv2.resize(left_eye, self.eye_size)
            rx, ry, rw, rh = eyes[1]
            right_eye = face_image[ry:ry+rh, rx:rx+rw]
            right_eye = cv2.resize(right_eye, self.eye_size)
            return left_eye, right_eye, True
        else:
            # Fallback: approximate eye regions from face geometry
            # (upper-middle quadrants of the face crop).
            h, w = face_image.shape[:2]
            left_region = face_image[h//4:h//2, w//4:w//2]
            right_region = face_image[h//4:h//2, w//2:3*w//4]
            left_eye = cv2.resize(left_region, self.eye_size)
            right_eye = cv2.resize(right_region, self.eye_size)
            return left_eye, right_eye, False

    def preprocess_inputs(self, face, left_eye, right_eye):
        """Normalize images to [0, 1] float32 and add a batch dimension.

        Returns:
            ``[face, left_eye, right_eye]`` batched arrays, in the input
            order the model expects.
        """
        face = face.astype(np.float32) / 255.0
        left_eye = left_eye.astype(np.float32) / 255.0
        right_eye = right_eye.astype(np.float32) / 255.0
        face = np.expand_dims(face, axis=0)
        left_eye = np.expand_dims(left_eye, axis=0)
        right_eye = np.expand_dims(right_eye, axis=0)
        return [face, left_eye, right_eye]

    def predict_gaze(self, image_data, screen_width, screen_height):
        """Predict an on-screen gaze position from a base64 image.

        Args:
            image_data: Base64-encoded image bytes (RGB, any size).
            screen_width: Client screen width in pixels.
            screen_height: Client screen height in pixels.

        Returns:
            Dict with ``success``; on success also ``gaze_position``
            (clamped to screen bounds), ``eyes_found`` and
            ``inference_time`` in ms; on failure an ``error`` message.
        """
        start_time = time.time()
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes))
            image_np = np.array(image)
            # PIL decodes as RGB; OpenCV operations below expect BGR.
            if len(image_np.shape) == 3 and image_np.shape[2] == 3:
                image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
            face_resized = cv2.resize(image_np, self.face_size)
            left_eye, right_eye, eyes_found = self.extract_eye_regions(face_resized)
            inputs = self.preprocess_inputs(face_resized, left_eye, right_eye)
            gaze_pred = self.model.predict(inputs, verbose=0)[0]
            # Was a bare print(); route debug output through the logger.
            logger.debug("Raw gaze prediction: %s", gaze_pred)
            # Model outputs normalized [0, 1] coordinates; scale to pixels.
            gaze_x = float(gaze_pred[0] * screen_width)
            gaze_y = float(gaze_pred[1] * screen_height)
            # Clamp to screen bounds.
            gaze_x = max(0, min(gaze_x, screen_width))
            gaze_y = max(0, min(gaze_y, screen_height))
            logger.debug("Predicted gaze position: (%s, %s)", gaze_x, gaze_y)
            inference_time = (time.time() - start_time) * 1000  # ms
            return {
                'success': True,
                'gaze_position': {
                    'x': gaze_x,
                    'y': gaze_y
                },
                'eyes_found': eyes_found,
                'inference_time': inference_time
            }
        except Exception as e:
            # logger.exception keeps the traceback, unlike logger.error.
            logger.exception(f"Prediction error: {e}")
            return {
                'success': False,
                'error': str(e)
            }
# Global server instance.
# Stays None until create_app() runs; route handlers must treat it as optional.
server = None
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reports whether the inference model is loaded."""
    model_ready = server is not None and server.model is not None
    payload = {
        'status': 'healthy',
        'model_loaded': model_ready,
    }
    return jsonify(payload)
@app.route('/predict', methods=['POST'])
def predict():
    """Predict gaze position from a posted base64 image.

    Expects JSON body ``{"image": <base64>, "screen_width": int,
    "screen_height": int}``; screen dimensions default to 1920x1080.

    Returns:
        200 with the prediction dict from ``predict_gaze`` on success,
        400 if the body is missing/invalid, 503 if the server is not
        initialized, 500 on unexpected errors.
    """
    try:
        # silent=True: a missing/invalid JSON body yields None instead of
        # Flask raising its own 400/415, keeping all errors in our envelope.
        data = request.get_json(silent=True)
        if not data or 'image' not in data:
            return jsonify({
                'success': False,
                'error': 'No image data provided'
            }), 400
        if server is None:
            # create_app() was never called — fail explicitly instead of
            # an opaque AttributeError-driven 500.
            return jsonify({
                'success': False,
                'error': 'Server not initialized'
            }), 503
        image_data = data['image']
        screen_width = data.get('screen_width', 1920)
        screen_height = data.get('screen_height', 1080)
        result = server.predict_gaze(image_data, screen_width, screen_height)
        return jsonify(result)
    except Exception as e:
        logger.error(f"Prediction endpoint error: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/calibrate', methods=['POST'])
def calibrate():
    """Placeholder calibration endpoint; always reports not-implemented."""
    response = {
        'success': True,
        'message': 'Calibration not yet implemented',
    }
    return jsonify(response)
def create_app(model_path='best_gaze_model.h5'):
    """Configure the Flask app by initializing the global inference server.

    Args:
        model_path: Path to the trained gaze model.

    Returns:
        The module-level Flask application, ready to serve.
    """
    global server
    server = GazeInferenceServer(model_path)
    return app
if __name__ == '__main__':
    import argparse
    import os
    import sys

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='Gaze Inference Server')
    parser.add_argument(
        '--model',
        type=str,
        default='best_gaze_model.h5',
        help='Path to the trained model'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=5000,
        help='Port to run the server on'
    )
    parser.add_argument(
        '--host',
        type=str,
        default='0.0.0.0',
        help='Host to run the server on'
    )
    args = parser.parse_args()

    # Fail fast if the model file is missing.
    if not os.path.exists(args.model):
        print(f"Error: Model file '{args.model}' not found!")
        # sys.exit over the site-injected exit(): guaranteed to exist.
        sys.exit(1)

    # NOTE(review): tensorflow is already imported at module load, so this
    # cannot suppress TF's import-time C++ logging — it only affects any
    # later subprocess. Set it before `import tensorflow` to fully work.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Build the app (loads the model; may take a few seconds).
    app = create_app(args.model)

    # Startup banner.
    print(f"\n{'='*50}")
    print("Starting Gaze Inference Server")
    print(f"Model: {args.model}")
    print(f"Server: http://{args.host}:{args.port}")
    print(f"{'='*50}\n")

    app.run(
        host=args.host,
        port=args.port,
        debug=False,
        threaded=True
    )