import base64
import logging
import time
from io import BytesIO

import cv2
import numpy as np
import tensorflow as tf
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Screen size assumed by /predict when the client does not send one.
DEFAULT_SCREEN_WIDTH = 1920
DEFAULT_SCREEN_HEIGHT = 1080


class GazeInferenceServer:
    """Load a Keras gaze model plus OpenCV Haar cascades and serve predictions.

    Attributes:
        model_path: Path handed to ``tf.keras.models.load_model``.
        model: The loaded Keras model (``None`` until ``_load_model`` runs).
        face_cascade: Frontal-face Haar cascade (loaded, but not used by the
            prediction path in this module).
        eye_cascade: Eye Haar cascade used by ``extract_eye_regions``.
        face_size: Size the whole input image is resized to, in
            ``cv2.resize``'s (width, height) order.
        eye_size: Size each eye crop is resized to, (width, height).
    """

    def __init__(self, model_path):
        """Initialize the gaze inference server.

        Args:
            model_path: Filesystem path of the saved Keras model.

        Raises:
            Exception: Propagated from ``_load_model`` when the model cannot
                be loaded.
            RuntimeError: If the eye Haar cascade fails to load.
        """
        self.model_path = model_path
        self.model = None
        self.face_cascade = None
        self.eye_cascade = None

        # Model input sizes, (width, height) as cv2.resize expects.
        self.face_size = (224, 224)
        self.eye_size = (80, 60)

        # Load model and cascades
        self._load_model()
        self._load_cascades()

        logger.info("Gaze inference server initialized")

    def _load_model(self):
        """Load the Keras model from ``self.model_path``.

        First attempts a full load with the custom metric registered. If that
        fails, the model is reloaded with ``compile=False`` and recompiled
        with Adam/MSE; ``predict`` itself does not require the compile step.

        Raises:
            Exception: Whatever the fallback ``load_model`` call raises.
        """
        custom_objects = {
            'euclidean_distance_metric': self._euclidean_distance_metric,
            'mse': tf.keras.losses.MeanSquaredError(),
        }
        try:
            try:
                self.model = tf.keras.models.load_model(
                    self.model_path, custom_objects=custom_objects
                )
            except Exception as exc:  # was a bare except (caught Ctrl-C too)
                logger.warning(
                    "Full model load failed (%s); retrying with compile=False",
                    exc,
                )
                self.model = tf.keras.models.load_model(
                    self.model_path, compile=False
                )
                self.model.compile(
                    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                    loss='mse',
                    metrics=['mae', self._euclidean_distance_metric],
                )
            logger.info("Model loaded successfully from %s", self.model_path)
        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            raise

    @staticmethod
    def _euclidean_distance_metric(y_true, y_pred):
        """Return the per-sample Euclidean distance along the last axis."""
        return tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1))

    def _load_cascades(self):
        """Load OpenCV's bundled Haar cascades for face and eye detection.

        Raises:
            RuntimeError: If the eye cascade, which every prediction uses,
                fails to load.
        """
        cascade_dir = cv2.data.haarcascades
        self.face_cascade = cv2.CascadeClassifier(
            cascade_dir + 'haarcascade_frontalface_default.xml'
        )
        self.eye_cascade = cv2.CascadeClassifier(
            cascade_dir + 'haarcascade_eye.xml'
        )
        # CascadeClassifier does not raise on a missing/corrupt file; it
        # yields an empty classifier that only fails later inside
        # detectMultiScale, so check explicitly and fail fast.
        if self.face_cascade.empty():
            logger.warning("Face Haar cascade failed to load")
        if self.eye_cascade.empty():
            raise RuntimeError("Eye Haar cascade failed to load")
        logger.info("Haar cascades loaded")

    def extract_eye_regions(self, face_image):
        """Extract left and right eye crops from a face image.

        The eye cascade runs on a grayscale copy. When it finds at least two
        eyes, the two detections with the smallest x-coordinate are used,
        ordered left-to-right in image coordinates. Otherwise two fixed
        regions are used: rows h/4..h/2, columns w/4..w/2 and w/2..3w/4.

        Args:
            face_image: BGR image array (3 channels).

        Returns:
            tuple: ``(left_eye, right_eye, eyes_found)``. Both crops are
            resized to ``self.eye_size``; ``eyes_found`` is True only when
            the cascade detected at least two eyes.
        """
        gray = cv2.cvtColor(face_image, cv2.COLOR_BGR2GRAY)
        eyes = self.eye_cascade.detectMultiScale(gray, 1.1, 4)

        if len(eyes) >= 2:
            # Sort by x-coordinate so index 0 is the left-most detection.
            eyes = sorted(eyes, key=lambda e: e[0])

            lx, ly, lw, lh = eyes[0]
            left_eye = cv2.resize(face_image[ly:ly + lh, lx:lx + lw], self.eye_size)

            rx, ry, rw, rh = eyes[1]
            right_eye = cv2.resize(face_image[ry:ry + rh, rx:rx + rw], self.eye_size)

            return left_eye, right_eye, True

        # Fallback to approximate eye regions
        h, w = face_image.shape[:2]
        left_region = face_image[h // 4:h // 2, w // 4:w // 2]
        right_region = face_image[h // 4:h // 2, w // 2:3 * w // 4]
        left_eye = cv2.resize(left_region, self.eye_size)
        right_eye = cv2.resize(right_region, self.eye_size)
        return left_eye, right_eye, False

    def preprocess_inputs(self, face, left_eye, right_eye):
        """Scale images to [0, 1] float32 and add a leading batch axis of 1.

        Returns:
            list: ``[face, left_eye, right_eye]`` in that order. NOTE(review):
            presumably this matches the model's input order; confirm against
            the training code.
        """
        return [
            np.expand_dims(img.astype(np.float32) / 255.0, axis=0)
            for img in (face, left_eye, right_eye)
        ]

    @staticmethod
    def _decode_image(image_data):
        """Decode a base64 image into a 3-channel BGR ``uint8`` array.

        Accepts raw base64 or a ``data:<mime>;base64,<payload>`` URL. The
        image is forced to RGB, so RGBA, grayscale and palette images all
        yield exactly three channels before the BGR conversion.
        """
        if isinstance(image_data, str) and image_data.startswith('data:'):
            # Strip the data-URL header. b64decode would otherwise silently
            # discard its non-alphabet characters and decode the remaining
            # letters ("dataimagepngbase64") as garbage leading bytes.
            _, _, image_data = image_data.partition(',')
        image_bytes = base64.b64decode(image_data)
        with Image.open(BytesIO(image_bytes)) as image:
            rgb = np.array(image.convert('RGB'))
        # Convert RGB to BGR for OpenCV
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    def predict_gaze(self, image_data, screen_width, screen_height):
        """Predict the on-screen gaze position from a base64-encoded image.

        Args:
            image_data: Base64 string (optionally a ``data:`` URL).
            screen_width: Screen width; anything ``float()`` accepts.
            screen_height: Screen height; anything ``float()`` accepts.

        Returns:
            dict: On success ``{'success': True, 'gaze_position': {'x', 'y'},
            'eyes_found', 'inference_time'}``. ``x`` and ``y`` are the first
            two model outputs scaled by the screen size and clamped to
            ``[0, size]``; ``inference_time`` is in milliseconds. On any
            error ``{'success': False, 'error': <message>}``; this method
            does not raise.
        """
        start_time = time.perf_counter()

        try:
            width = float(screen_width)
            height = float(screen_height)

            image_np = self._decode_image(image_data)

            # The whole frame is treated as the face input.
            face_resized = cv2.resize(image_np, self.face_size)

            left_eye, right_eye, eyes_found = self.extract_eye_regions(face_resized)

            inputs = self.preprocess_inputs(face_resized, left_eye, right_eye)

            gaze_pred = self.model.predict(inputs, verbose=0)[0]
            logger.debug("Raw gaze prediction: %s", gaze_pred)

            # Convert to screen coordinates
            gaze_x = float(gaze_pred[0]) * width
            gaze_y = float(gaze_pred[1]) * height

            # Ensure within bounds
            gaze_x = max(0.0, min(gaze_x, width))
            gaze_y = max(0.0, min(gaze_y, height))
            logger.debug("Predicted gaze position: (%s, %s)", gaze_x, gaze_y)

            inference_time = (time.perf_counter() - start_time) * 1000  # ms

            return {
                'success': True,
                'gaze_position': {
                    'x': gaze_x,
                    'y': gaze_y
                },
                'eyes_found': eyes_found,
                'inference_time': inference_time
            }

        except Exception as e:
            logger.exception("Prediction error: %s", e)
            return {
                'success': False,
                'error': str(e)
            }


# Global server instance, set by create_app().
server = None


def _parse_screen_dimension(value, default):
    """Return *value* as a positive finite float, or *default* if ``None``.

    Raises:
        TypeError: If *value* is not convertible with ``float()`` (e.g. a list).
        ValueError: If *value* is non-numeric text, non-finite or not > 0.
    """
    if value is None:
        return float(default)
    number = float(value)
    if not (np.isfinite(number) and number > 0):
        raise ValueError(f"invalid screen dimension: {value!r}")
    return number


@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness and whether the model has been loaded."""
    return jsonify({
        'status': 'healthy',
        'model_loaded': server is not None and server.model is not None
    })


@app.route('/predict', methods=['POST'])
def predict():
    """Predict gaze position from a JSON request.

    Expects a JSON object with ``image`` (base64, optionally a data URL) and
    optional ``screen_width`` / ``screen_height`` (default 1920x1080).
    Returns the ``predict_gaze`` result dict. Responds 400 for a malformed
    request, 503 when no model is loaded, and 500 on an unexpected error.
    """
    if server is None:
        return jsonify({
            'success': False,
            'error': 'Model not loaded'
        }), 503

    try:
        # silent=True: a missing or invalid JSON body becomes None (-> 400)
        # instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({
                'success': False,
                'error': 'No image data provided'
            }), 400

        try:
            screen_width = _parse_screen_dimension(
                data.get('screen_width'), DEFAULT_SCREEN_WIDTH
            )
            screen_height = _parse_screen_dimension(
                data.get('screen_height'), DEFAULT_SCREEN_HEIGHT
            )
        except (TypeError, ValueError):
            return jsonify({
                'success': False,
                'error': 'screen_width and screen_height must be positive numbers'
            }), 400

        result = server.predict_gaze(data['image'], screen_width, screen_height)
        return jsonify(result)

    except Exception as e:
        logger.exception("Prediction endpoint error: %s", e)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/calibrate', methods=['POST'])
def calibrate():
    """Calibration endpoint (placeholder for future implementation)."""
    return jsonify({
        'success': True,
        'message': 'Calibration not yet implemented'
    })


def create_app(model_path='best_gaze_model.h5'):
    """Create the global ``GazeInferenceServer`` and return the Flask app.

    Args:
        model_path: Path of the Keras model to load.

    Returns:
        Flask: The module-level ``app`` instance.
    """
    global server
    server = GazeInferenceServer(model_path)
    return app


if __name__ == '__main__':
    import argparse
    import os
    import sys

    parser = argparse.ArgumentParser(description='Gaze Inference Server')
    parser.add_argument(
        '--model',
        type=str,
        default='best_gaze_model.h5',
        help='Path to the trained model'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=5000,
        help='Port to run the server on'
    )
    parser.add_argument(
        '--host',
        type=str,
        default='0.0.0.0',  # listens on all interfaces
        help='Host to run the server on'
    )
    args = parser.parse_args()

    if not os.path.exists(args.model):
        print(f"Error: Model file '{args.model}' not found!", file=sys.stderr)
        sys.exit(1)

    # Suppress TensorFlow warnings.
    # NOTE(review): tensorflow is already imported at module top, so this may
    # not affect logging TensorFlow has already initialised; exporting the
    # variable before launching the process is more reliable.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    app = create_app(args.model)

    print(f"\n{'='*50}")
    print("Starting Gaze Inference Server")
    print(f"Model: {args.model}")
    print(f"Server: http://{args.host}:{args.port}")
    print(f"{'='*50}\n")

    app.run(
        host=args.host,
        port=args.port,
        debug=False,
        threaded=True
    )