| | from flask import Flask, request, jsonify |
| | from flask_cors import CORS |
| | import cv2 |
| | import numpy as np |
| | import tensorflow as tf |
| | import base64 |
| | import time |
| | from io import BytesIO |
| | from PIL import Image |
| | import logging |
| |
|
# Flask application with CORS enabled so the browser-based client
# (served from a different origin) can call the prediction endpoints.
app = Flask(__name__)
CORS(app)

# Module-level logger; INFO level so startup/prediction events are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
class GazeInferenceServer:
    """Serves gaze predictions from a trained Keras model.

    Decodes base64 webcam frames, extracts face/eye crops with OpenCV Haar
    cascades, and maps the model output (normalized coordinates -- assumed
    to be in [0, 1]; TODO confirm against the training pipeline) to screen
    pixel positions.
    """

    def __init__(self, model_path):
        """Initialize the gaze inference server.

        Args:
            model_path: Path to the trained Keras model file.

        Raises:
            Exception: Propagated from TensorFlow if the model cannot be
                loaded even with ``compile=False``.
        """
        self.model_path = model_path
        self.model = None
        self.face_cascade = None
        self.eye_cascade = None

        # Input sizes expected by the model, as (width, height) -- the
        # argument order cv2.resize takes.
        self.face_size = (224, 224)
        self.eye_size = (80, 60)

        self._load_model()
        self._load_cascades()

        logger.info("Gaze inference server initialized")

    def _load_model(self):
        """Load the TensorFlow model, recompiling if custom objects fail."""
        try:
            custom_objects = {
                'euclidean_distance_metric': self._euclidean_distance_metric,
                'mse': tf.keras.losses.MeanSquaredError(),
            }

            try:
                self.model = tf.keras.models.load_model(
                    self.model_path,
                    custom_objects=custom_objects
                )
            except Exception as load_err:
                # Fall back to loading without the saved optimizer/metric
                # state and compile fresh. (fix: was a bare `except:` that
                # also swallowed KeyboardInterrupt/SystemExit and hid the
                # original load error)
                logger.warning(
                    "Loading with custom objects failed (%s); "
                    "retrying with compile=False", load_err
                )
                self.model = tf.keras.models.load_model(
                    self.model_path,
                    compile=False
                )
                self.model.compile(
                    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                    loss='mse',
                    metrics=['mae', self._euclidean_distance_metric]
                )

            logger.info(f"Model loaded successfully from {self.model_path}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    @staticmethod
    def _euclidean_distance_metric(y_true, y_pred):
        """Euclidean distance between true and predicted gaze points."""
        return tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1))

    def _load_cascades(self):
        """Load Haar cascades for face and eye detection."""
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )
        self.eye_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_eye.xml'
        )
        # fix: cv2.CascadeClassifier yields an *empty* classifier (no
        # exception) when the XML file is missing; surface that now instead
        # of silently detecting nothing later.
        if self.face_cascade.empty() or self.eye_cascade.empty():
            logger.warning("One or more Haar cascades failed to load")
        logger.info("Haar cascades loaded")

    def extract_eye_regions(self, face_image):
        """Extract left and right eye regions from a BGR face image.

        Args:
            face_image: BGR face crop (already resized to ``self.face_size``).

        Returns:
            Tuple ``(left_eye, right_eye, eyes_found)`` where the eye images
            are resized to ``self.eye_size`` and ``eyes_found`` is True only
            when the cascade actually detected two eyes (False means fixed
            fallback regions were used).
        """
        gray = cv2.cvtColor(face_image, cv2.COLOR_BGR2GRAY)
        eyes = self.eye_cascade.detectMultiScale(gray, 1.1, 4)

        if len(eyes) >= 2:
            # Sort left-to-right by x so eyes[0] is the image-left eye.
            eyes = sorted(eyes, key=lambda e: e[0])

            lx, ly, lw, lh = eyes[0]
            left_eye = cv2.resize(
                face_image[ly:ly + lh, lx:lx + lw], self.eye_size
            )

            rx, ry, rw, rh = eyes[1]
            right_eye = cv2.resize(
                face_image[ry:ry + rh, rx:rx + rw], self.eye_size
            )

            return left_eye, right_eye, True

        # Fallback: crop the upper-middle quadrants where eyes usually sit.
        h, w = face_image.shape[:2]
        left_eye = cv2.resize(
            face_image[h // 4:h // 2, w // 4:w // 2], self.eye_size
        )
        right_eye = cv2.resize(
            face_image[h // 4:h // 2, w // 2:3 * w // 4], self.eye_size
        )

        return left_eye, right_eye, False

    def preprocess_inputs(self, face, left_eye, right_eye):
        """Normalize images to [0, 1] float32 and add a batch dimension.

        Returns:
            List ``[face, left_eye, right_eye]`` of arrays shaped
            ``(1, H, W, C)``, matching the model's three-input signature.
        """
        batch = []
        for image in (face, left_eye, right_eye):
            scaled = image.astype(np.float32) / 255.0
            batch.append(np.expand_dims(scaled, axis=0))
        return batch

    def predict_gaze(self, image_data, screen_width, screen_height):
        """Predict the on-screen gaze position from a base64-encoded image.

        Args:
            image_data: Base64-encoded image bytes (any PIL-readable format).
            screen_width: Screen width in pixels, used to scale the output.
            screen_height: Screen height in pixels, used to scale the output.

        Returns:
            On success: dict with ``'success': True``, ``'gaze_position'``
            (``{'x', 'y'}`` clamped to the screen bounds), ``'eyes_found'``,
            and ``'inference_time'`` in milliseconds.
            On failure: ``{'success': False, 'error': <message>}``.
        """
        start_time = time.time()

        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes))
            image_np = np.array(image)

            # Normalize channel layout to 3-channel BGR for OpenCV.
            # (fix: the original only handled 3-channel RGB, so RGBA PNGs
            # and grayscale frames crashed later inside cvtColor)
            if image_np.ndim == 2:
                image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
            elif image_np.shape[2] == 4:
                image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGR)
            elif image_np.shape[2] == 3:
                image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

            face_resized = cv2.resize(image_np, self.face_size)

            left_eye, right_eye, eyes_found = self.extract_eye_regions(
                face_resized
            )

            inputs = self.preprocess_inputs(face_resized, left_eye, right_eye)

            gaze_pred = self.model.predict(inputs, verbose=0)[0]

            # fix: debug prints replaced with logger.debug so production
            # stdout stays clean and verbosity follows the logging config.
            logger.debug(f"Raw gaze prediction: {gaze_pred}")

            # Scale normalized model output to pixel coordinates, then clamp
            # to the screen so downstream consumers never see out-of-bounds
            # positions.
            gaze_x = float(gaze_pred[0] * screen_width)
            gaze_y = float(gaze_pred[1] * screen_height)

            gaze_x = max(0, min(gaze_x, screen_width))
            gaze_y = max(0, min(gaze_y, screen_height))

            logger.debug(f"Predicted gaze position: ({gaze_x}, {gaze_y})")

            inference_time = (time.time() - start_time) * 1000

            return {
                'success': True,
                'gaze_position': {
                    'x': gaze_x,
                    'y': gaze_y
                },
                'eyes_found': eyes_found,
                'inference_time': inference_time
            }

        except Exception as e:
            logger.error(f"Prediction error: {e}")
            return {
                'success': False,
                'error': str(e)
            }
| |
|
# Global inference-server instance; populated by create_app().
server = None
| |
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report server liveness and whether the model is ready to predict."""
    model_ready = server is not None and server.model is not None
    return jsonify({'status': 'healthy', 'model_loaded': model_ready})
| |
|
@app.route('/predict', methods=['POST'])
def predict():
    """Predict gaze position from a base64-encoded image.

    Expects a JSON body with:
        image: base64-encoded image data (required).
        screen_width / screen_height: screen size in pixels used to scale
            the prediction (optional, defaults 1920x1080).

    Returns the JSON result of ``GazeInferenceServer.predict_gaze``, or an
    error payload with HTTP 400/500/503.
    """
    try:
        # fix: without this guard a request before create_app() produced a
        # raw AttributeError/500 instead of a meaningful JSON error.
        if server is None:
            return jsonify({
                'success': False,
                'error': 'Server not initialized'
            }), 503

        # fix: request.json raises on non-JSON bodies, bypassing the 400
        # response below; silent=True turns a bad body into None instead.
        data = request.get_json(silent=True)

        if not data or 'image' not in data:
            return jsonify({
                'success': False,
                'error': 'No image data provided'
            }), 400

        image_data = data['image']
        screen_width = data.get('screen_width', 1920)
        screen_height = data.get('screen_height', 1080)

        result = server.predict_gaze(image_data, screen_width, screen_height)

        return jsonify(result)

    except Exception as e:
        logger.error(f"Prediction endpoint error: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
| |
|
@app.route('/calibrate', methods=['POST'])
def calibrate():
    """Calibration endpoint (placeholder for future implementation)."""
    response = {
        'success': True,
        'message': 'Calibration not yet implemented'
    }
    return jsonify(response)
| |
|
def create_app(model_path='best_gaze_model.h5'):
    """Build the Flask app, initializing the module-level inference server.

    Args:
        model_path: Path to the trained Keras gaze model.

    Returns:
        The configured Flask application instance.
    """
    global server

    # The model is loaded here (not at import time) so merely importing
    # this module stays cheap.
    server = GazeInferenceServer(model_path)

    return app
| |
|
if __name__ == '__main__':
    import argparse
    import os
    import sys

    parser = argparse.ArgumentParser(description='Gaze Inference Server')
    parser.add_argument(
        '--model',
        type=str,
        default='best_gaze_model.h5',
        help='Path to the trained model'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=5000,
        help='Port to run the server on'
    )
    parser.add_argument(
        '--host',
        type=str,
        default='0.0.0.0',
        help='Host to run the server on'
    )

    args = parser.parse_args()

    # Fail fast with a clear message if the model file is missing.
    if not os.path.exists(args.model):
        print(f"Error: Model file '{args.model}' not found!")
        # fix: sys.exit instead of the `exit` builtin, which is a `site`
        # convenience not guaranteed to exist under every launcher.
        sys.exit(1)

    # NOTE(review): this currently has no effect -- TensorFlow was already
    # imported at module load, and TF reads this env var at import time. It
    # would need to be set before the first `import tensorflow` (or in the
    # shell) to actually suppress the C++ logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    app = create_app(args.model)

    print(f"\n{'='*50}")
    print(f"Starting Gaze Inference Server")
    print(f"Model: {args.model}")
    print(f"Server: http://{args.host}:{args.port}")
    print(f"{'='*50}\n")

    app.run(
        host=args.host,
        port=args.port,
        debug=False,
        threaded=True
    )
| |
|