# gaze_test/training_deepseek.py
import os
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Initialize eye detector
eye_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_eye.xml')
# Global variable to store last valid eye region
last_valid_eye_region = None
def extract_eye_regions(image):
"""
Extract and validate eye regions from an image using Haar cascades.
Returns a cropped image focusing on the eyes; always outputs exactly (60,80,3).
Validates that:
1. Exactly 2 eyes are detected
2. Eyes are at approximately the same height (vertical position)
3. Eye regions are of reasonable size
Remembers and reuses the last valid eye region if current detection fails.
"""
global last_valid_eye_region
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
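    # scaleFactor=1.1 (image pyramid step) and minNeighbors=4 (detection strictness)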
eyes = eye_cascade.detectMultiScale(gray, 1.1, 4)
valid_eyes = []
# Filter and validate eyes
if len(eyes) >= 2:
# Sort by size and take the two largest
eyes = sorted(eyes, key=lambda x: x[2]*x[3], reverse=True)[:2]
# Check if eyes are at similar height (within 20% of image height)
height_threshold = image.shape[0] * 0.2
if abs(eyes[0][1] - eyes[1][1]) < height_threshold:
# Check if eye sizes are reasonable (between 10x10 and half image size)
min_eye_size = 10
max_eye_width = image.shape[1] // 2
max_eye_height = image.shape[0] // 2
eye1_valid = (eyes[0][2] >= min_eye_size and eyes[0][3] >= min_eye_size and
eyes[0][2] <= max_eye_width and eyes[0][3] <= max_eye_height)
eye2_valid = (eyes[1][2] >= min_eye_size and eyes[1][3] >= min_eye_size and
eyes[1][2] <= max_eye_width and eyes[1][3] <= max_eye_height)
if eye1_valid and eye2_valid:
valid_eyes = eyes
if len(valid_eyes) == 2:
# Calculate bounding box around both eyes
x_min = min(valid_eyes[0][0], valid_eyes[1][0])
y_min = min(valid_eyes[0][1], valid_eyes[1][1])
x_max = max(valid_eyes[0][0]+valid_eyes[0][2], valid_eyes[1][0]+valid_eyes[1][2])
y_max = max(valid_eyes[0][1]+valid_eyes[0][3], valid_eyes[1][1]+valid_eyes[1][3])
        # Add padding (20% of the average eye width)
        padding = (valid_eyes[0][2] + valid_eyes[1][2]) / 2 * 0.2
x_min = max(0, int(x_min - padding))
y_min = max(0, int(y_min - padding))
x_max = min(image.shape[1], int(x_max + padding))
y_max = min(image.shape[0], int(y_max + padding))
eye_region = image[y_min:y_max, x_min:x_max]
# Always resize to (80, 60)
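        # cv2.resize takes (width, height), so the resulting array shape is (60, 80, 3)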
eye_region = cv2.resize(eye_region, (80, 60))
last_valid_eye_region = eye_region
return eye_region
# Fallback to last valid eye region if available
if last_valid_eye_region is not None:
return last_valid_eye_region
# Final fallback if no eyes detected and no last valid region
return cv2.resize(image, (80, 60))
def load_data_batch(data_dir, batch_size=100):
"""Load data with eye region extraction."""
metadata_path = os.path.join(data_dir, 'metadata.json')
images_dir = os.path.join(data_dir, 'images')
# Load metadata
with open(metadata_path, 'r') as f:
metadata = json.load(f)
X = []
y = []
for i, data_point in enumerate(metadata['data_points']):
if i >= batch_size and batch_size > 0:
break
# Load image
img_path = os.path.join(images_dir, data_point['image'])
img = cv2.imread(img_path)
if img is None:
continue
# Convert to RGB
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Extract eye region
eye_img = extract_eye_regions(img)
# Normalize
eye_img = eye_img.astype('float32') / 255.0
X.append(eye_img)
y.append([data_point['screen_x'] / metadata['screen_width'],
data_point['screen_y'] / metadata['screen_height']])
return np.array(X), np.array(y)
def create_eye_centric_model(input_shape=(60, 80, 3)):
"""Create a CNN model optimized for eye regions."""
inputs = keras.Input(shape=input_shape)
# Enhanced architecture for eye regions
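    # Stacked Conv2D blocks with pooling and batch normalization; global average
    # pooling replaces Flatten to keep the dense head small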
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.4)(x)
x = layers.Dense(64, activation='relu')(x)
x = layers.Dropout(0.3)(x)
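    # Two sigmoid outputs: gaze target as normalized (x, y) screen coordinates in [0, 1]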
outputs = layers.Dense(2, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss='mse',
metrics=['mae'])
return model
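
# --- Illustrative sketch (not part of the original training script) ---
# A hypothetical helper showing how the saved model could be applied to a single
# camera frame, assuming the same eye-region preprocessing used during training.
# The name predict_gaze and its signature are assumptions, not project API.
def predict_gaze(model, frame_bgr, screen_width, screen_height):
    """Predict an (x, y) screen coordinate in pixels from one BGR frame."""
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    eye_img = extract_eye_regions(rgb).astype('float32') / 255.0
    # Model expects a batch dimension and returns normalized [0, 1] coordinates
    pred_x, pred_y = model.predict(np.expand_dims(eye_img, axis=0), verbose=0)[0]
    return pred_x * screen_width, pred_y * screen_height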
def main():
# Configuration
data_dir = "gaze_data_20250602_140239" # Update with your directory
batch_size = -1 # -1 for all data, or specify a number
epochs = 50
model_save_path = "gaze_model_eye_centric.h5"
# Load data with eye region extraction
print("Loading and processing data with eye region extraction...")
X, y = load_data_batch(data_dir, batch_size)
# Split data
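    # 20% held out for test, then 20% of the remainder for validation (64/16/20 overall)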
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
print(f"\nData shapes:")
print(f"Training: {X_train.shape} {y_train.shape}")
print(f"Validation: {X_val.shape} {y_val.shape}")
print(f"Test: {X_test.shape} {y_test.shape}")
# Create enhanced model
model = create_eye_centric_model()
model.summary()
# Callbacks
callbacks = [
keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
keras.callbacks.ModelCheckpoint(model_save_path, save_best_only=True),
keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5)
]
# Train model
print("\nTraining model...")
history = model.fit(
X_train, y_train,
validation_data=(X_val, y_val),
epochs=epochs,
batch_size=32,
        callbacks=callbacks,
        shuffle=True
)
# Evaluate
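    # MAE is reported in normalized screen coordinates (fraction of screen width/height)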
print("\nEvaluating on test set...")
test_loss, test_mae = model.evaluate(X_test, y_test)
print(f"Test MAE: {test_mae:.4f}")
# Save final model
model.save(model_save_path)
print(f"\nModel saved to {model_save_path}")
if __name__ == "__main__":
    main()