Olof Astrand committed on
Commit 42c48e6
Parent(s): 6bd5bc0
First version

Files changed:
- README.md +29 -0
- collector.html +489 -0
- collector.py +185 -0
- dataset_converter.py +417 -0
- inference.py +403 -0
- readme.txt +272 -0
- requirements.txt +6 -0
README.md
CHANGED
@@ -1,3 +1,32 @@
 ---
 license: apache-2.0
 ---
+
+This is a first test on Hugging Face to test and learn.
+The project is mainly created with Claude 4, ChatGPT 4 and DeepSeek.
+
+Create a virtual Python env (gaze_env) or use Conda.
+
+Files:
+
+Creating a dataset
+==================
+collector.py
+collector.html
+When creating a dataset in the browser you will have to convert it with
+dataset_converter.py
+
+Training from web based dataset
+===============================
+training.py
+
+Training from OpenCV created dataset
+====================================
+training_deepseek.py
+
+Inference
+=========
+inference.py
+This does not work in a WSL environment.
collector.html
ADDED
@@ -0,0 +1,489 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Gaze Data Collector</title>
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        body {
            font-family: Arial, sans-serif;
            background: #000;
            color: white;
            overflow: hidden;
            height: 100vh;
        }

        #gameContainer {
            position: relative;
            width: 100vw;
            height: 100vh;
            background: black;
            cursor: none;
        }

        #cross {
            position: absolute;
            width: 60px;
            height: 60px;
            pointer-events: none;
            z-index: 10;
        }

        .cross-line {
            position: absolute;
            background: #00ff00;
        }

        .cross-horizontal {
            width: 60px;
            height: 3px;
            top: 50%;
            left: 0;
            transform: translateY(-50%);
        }

        .cross-vertical {
            width: 3px;
            height: 60px;
            left: 50%;
            top: 0;
            transform: translateX(-50%);
        }

        .cross-center {
            position: absolute;
            width: 10px;
            height: 10px;
            background: #ff0000;
            border-radius: 50%;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
        }

        #videoContainer {
            position: fixed;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            width: 160px;
            height: 120px;
            border: 2px solid #00ff00;
            background: rgba(0, 0, 0, 0.8);
            z-index: 100;
            border-radius: 8px;
            overflow: hidden;
        }

        #video {
            width: 400%;
            height: 400%;
            object-fit: cover;
            transform: translate(-37.5%, -37.5%);
        }

        #controls {
            position: fixed;
            top: 20px;
            left: 20px;
            z-index: 100;
            background: rgba(0, 0, 0, 0.8);
            padding: 20px;
            border-radius: 8px;
            border: 1px solid #333;
        }

        .control-button {
            background: #007bff;
            color: white;
            border: none;
            padding: 10px 20px;
            margin: 5px;
            border-radius: 5px;
            cursor: pointer;
            font-size: 14px;
        }

        .control-button:hover {
            background: #0056b3;
        }

        .control-button:disabled {
            background: #666;
            cursor: not-allowed;
        }

        #status {
            margin-top: 10px;
            font-size: 14px;
        }

        #timer {
            font-size: 18px;
            font-weight: bold;
            color: #00ff00;
        }

        .instructions {
            position: fixed;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            text-align: center;
            font-size: 24px;
            z-index: 50;
            background: rgba(0, 0, 0, 0.8);
            padding: 30px;
            border-radius: 10px;
            border: 1px solid #333;
        }

        .hidden {
            display: none;
        }

        #downloadSection {
            margin-top: 15px;
        }

        #downloadButton {
            background: #28a745;
        }

        #downloadButton:hover {
            background: #1e7e34;
        }

        .eye-guide {
            position: absolute;
            top: 5px;
            left: 5px;
            right: 5px;
            bottom: 5px;
            border: 1px dashed #00ff00;
            border-radius: 4px;
            opacity: 0.7;
        }

        .eye-guide::before {
            content: "Center your face here";
            position: absolute;
            top: -20px;
            left: 0;
            font-size: 10px;
            color: #00ff00;
        }
    </style>
</head>
<body>
    <div id="gameContainer">
        <div id="cross">
            <div class="cross-line cross-horizontal"></div>
            <div class="cross-line cross-vertical"></div>
            <div class="cross-center"></div>
        </div>

        <div id="instructions" class="instructions">
            <h2>Gaze Data Collector</h2>
            <p>Follow the green cross with your eyes</p>
            <p>Position your face in the box at the center of the screen</p>
            <p>Press START to begin data collection</p>
            <p>Collection will run for 5 minutes</p>
        </div>
    </div>

    <div id="videoContainer">
        <video id="video" autoplay muted playsinline></video>
        <div class="eye-guide"></div>
    </div>

    <div id="controls">
        <button id="startButton" class="control-button">START COLLECTION</button>
        <button id="stopButton" class="control-button" disabled>STOP</button>
        <div id="status">
            <div>Status: <span id="statusText">Ready</span></div>
            <div>Timer: <span id="timer">00:00</span></div>
            <div>Frames: <span id="frameCount">0</span></div>
        </div>
        <div id="downloadSection" class="hidden">
            <button id="downloadButton" class="control-button">DOWNLOAD DATA</button>
        </div>
    </div>

    <script>
        class GazeDataCollector {
            constructor() {
                this.screenWidth = window.innerWidth;
                this.screenHeight = window.innerHeight;
                this.crossSize = 30;
                this.speed = 3;
                this.dataPoints = [];
                this.collecting = false;
                this.startTime = null;
                this.frameCount = 0;
                this.collectionDuration = 300; // 5 minutes in seconds

                // Cross position and movement
                this.x = this.screenWidth / 2;
                this.y = 50;
                this.directionX = 1;
                this.directionY = 1;

                this.initializeElements();
                this.initializeCamera();
                this.bindEvents();
                this.animationLoop();
            }

            initializeElements() {
                this.cross = document.getElementById('cross');
                this.video = document.getElementById('video');
                this.startButton = document.getElementById('startButton');
                this.stopButton = document.getElementById('stopButton');
                this.statusText = document.getElementById('statusText');
                this.timer = document.getElementById('timer');
                this.frameCountEl = document.getElementById('frameCount');
                this.instructions = document.getElementById('instructions');
                this.downloadSection = document.getElementById('downloadSection');
                this.downloadButton = document.getElementById('downloadButton');
            }

            async initializeCamera() {
                try {
                    const stream = await navigator.mediaDevices.getUserMedia({
                        video: {
                            width: 640,
                            height: 480,
                            facingMode: 'user'
                        }
                    });
                    this.video.srcObject = stream;
                    this.statusText.textContent = 'Camera ready';
                } catch (error) {
                    console.error('Error accessing camera:', error);
                    this.statusText.textContent = 'Camera error';
                    alert('Unable to access camera. Please ensure you have granted camera permissions.');
                }
            }

            bindEvents() {
                this.startButton.addEventListener('click', () => this.startCollection());
                this.stopButton.addEventListener('click', () => this.stopCollection());
                this.downloadButton.addEventListener('click', () => this.downloadData());

                // Handle window resize
                window.addEventListener('resize', () => {
                    this.screenWidth = window.innerWidth;
                    this.screenHeight = window.innerHeight;
                });

                // Keyboard shortcuts
                document.addEventListener('keydown', (e) => {
                    if (e.code === 'Space') {
                        e.preventDefault();
                        if (!this.collecting) {
                            this.startCollection();
                        }
                    } else if (e.code === 'Escape') {
                        this.stopCollection();
                    }
                });
            }

            updateCrossPosition() {
                // Move horizontally
                this.x += this.directionX * this.speed;

                // Check horizontal bounds and move down when reaching edge
                if (this.x <= this.crossSize || this.x >= this.screenWidth - this.crossSize) {
                    this.directionX *= -1; // Reverse horizontal direction
                    this.y += 50; // Move down by 50 pixels
                }

                // Check if we've reached the bottom
                if (this.y >= this.screenHeight - this.crossSize) {
                    this.y = this.crossSize; // Reset to top
                }

                // Ensure cross stays within bounds
                this.x = Math.max(this.crossSize, Math.min(this.x, this.screenWidth - this.crossSize));
                this.y = Math.max(this.crossSize, Math.min(this.y, this.screenHeight - this.crossSize));
            }

            drawCross() {
                this.cross.style.left = (this.x - 30) + 'px';
                this.cross.style.top = (this.y - 30) + 'px';
            }

            captureFrame() {
                // Create a canvas to capture the cropped video frame
                const canvas = document.createElement('canvas');
                const ctx = canvas.getContext('2d');

                // Set final output size to 60x80 pixels
                canvas.width = 60;
                canvas.height = 80;

                // Get video dimensions
                const videoWidth = this.video.videoWidth || 640;
                const videoHeight = this.video.videoHeight || 480;

                // Calculate crop area for face region (center area)
                // We want to crop approximately the center 25% of the video for a tight face crop
                const cropWidth = videoWidth * 0.25;
                const cropHeight = videoHeight * 0.25;
                const cropX = (videoWidth - cropWidth) / 2;
                const cropY = (videoHeight - cropHeight) / 2;

                // Draw the cropped area scaled to 60x80
                ctx.drawImage(
                    this.video,
                    cropX, cropY, cropWidth, cropHeight, // Source crop area
                    0, 0, 60, 80 // Destination size
                );

                // Convert to base64 image data with higher compression
                const imageData = canvas.toDataURL('image/jpeg', 0.6);

                return imageData;
            }

            collectDataPoint() {
                if (!this.collecting) return;

                const currentTime = Date.now();
                const timestamp = (currentTime - this.startTime) / 1000; // Convert to seconds

                // Capture video frame
                const imageData = this.captureFrame();

                // Store data point
                const dataPoint = {
                    frame: this.frameCount,
                    timestamp: timestamp,
                    screen_x: this.x,
                    screen_y: this.y,
                    screen_width: this.screenWidth,
                    screen_height: this.screenHeight,
                    image_data: imageData
                };

                this.dataPoints.push(dataPoint);
                this.frameCount++;

                // Update UI
                this.frameCountEl.textContent = this.frameCount;

                // Check if collection time is up
                if (timestamp >= this.collectionDuration) {
                    this.stopCollection();
                    alert('5 minutes completed! Data collection finished.');
                }
            }

            startCollection() {
                this.collecting = true;
                this.startTime = Date.now();
                this.frameCount = 0;
                this.dataPoints = [];

                this.startButton.disabled = true;
                this.stopButton.disabled = false;
                this.statusText.textContent = 'Collecting...';
                this.instructions.classList.add('hidden');

                console.log('Data collection started!');
            }

            stopCollection() {
                this.collecting = false;

                this.startButton.disabled = false;
                this.stopButton.disabled = true;
                this.statusText.textContent = 'Collection stopped';
                this.downloadSection.classList.remove('hidden');

                console.log('Data collection stopped!');
                console.log(`Total frames collected: ${this.frameCount}`);
            }

            updateTimer() {
                if (!this.collecting) return;

                const elapsed = (Date.now() - this.startTime) / 1000;
                const remaining = Math.max(0, this.collectionDuration - elapsed);

                const minutes = Math.floor(elapsed / 60);
                const seconds = Math.floor(elapsed % 60);

                this.timer.textContent = `${minutes.toString().padStart(2, '0')}:${seconds.toString().padStart(2, '0')}`;
            }

            downloadData() {
                if (this.dataPoints.length === 0) {
                    alert('No data to download!');
                    return;
                }

                // Create metadata
                const metadata = {
                    screen_width: this.screenWidth,
                    screen_height: this.screenHeight,
                    cross_size: this.crossSize,
                    speed: this.speed,
                    total_frames: this.dataPoints.length,
                    collection_duration: this.collectionDuration,
                    timestamp: new Date().toISOString(),
                    data_points: this.dataPoints
                };

                // Convert to JSON
                const jsonData = JSON.stringify(metadata, null, 2);

                // Create download link
                const blob = new Blob([jsonData], { type: 'application/json' });
                const url = URL.createObjectURL(blob);

                const a = document.createElement('a');
                a.href = url;
                a.download = `gaze_data_${new Date().toISOString().replace(/[:.]/g, '-')}.json`;
                document.body.appendChild(a);
                a.click();
                document.body.removeChild(a);

                URL.revokeObjectURL(url);

                console.log('Data downloaded successfully!');
            }

            animationLoop() {
                // Update cross position if collecting
                if (this.collecting) {
                    this.updateCrossPosition();
                    this.collectDataPoint();
                }

                // Always draw the cross
                this.drawCross();

                // Update timer
                this.updateTimer();

                // Continue animation loop
                requestAnimationFrame(() => this.animationLoop());
            }
        }

        // Initialize the collector when page loads
        window.addEventListener('load', () => {
            new GazeDataCollector();
        });
    </script>
</body>
</html>
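The browser collector above stores each frame as a base64-encoded `data:image/jpeg` URL in the `image_data` field of the downloaded JSON. A minimal sketch of recovering those frames as JPEG files (key names follow the `dataPoint` object built in `collectDataPoint()`; the output directory name is illustrative):

```python
import base64
import json
import os


def decode_data_url(data_url: str) -> bytes:
    """Split a data:image/jpeg;base64,... URL and return the raw JPEG bytes."""
    header, _, payload = data_url.partition(",")
    if "base64" not in header:
        raise ValueError("expected a base64-encoded data URL")
    return base64.b64decode(payload)


def extract_frames(json_path: str, out_dir: str) -> int:
    """Write every captured frame in a collector JSON file to out_dir; return the count."""
    os.makedirs(out_dir, exist_ok=True)
    with open(json_path) as f:
        data = json.load(f)
    for point in data["data_points"]:
        jpeg = decode_data_url(point["image_data"])
        with open(os.path.join(out_dir, f"frame_{point['frame']:06d}.jpg"), "wb") as img:
            img.write(jpeg)
    return len(data["data_points"])
```

This mirrors what dataset_converter.py has to do internally before the frames can be fed to a training pipeline.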
collector.py
ADDED
@@ -0,0 +1,185 @@
import cv2
import numpy as np
import os
import json
import time
from datetime import datetime

class GazeDataCollector:
    def __init__(self, screen_width=1920, screen_height=1080, cross_size=30, speed=5):
        """
        Initialize the gaze data collector.

        Args:
            screen_width: Width of the screen
            screen_height: Height of the screen
            cross_size: Size of the cross marker
            speed: Speed of cross movement (pixels per frame)
        """
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.cross_size = cross_size
        self.speed = speed
        self.data_points = []
        self.collecting = False

        # Create output directory
        self.output_dir = f"gaze_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(os.path.join(self.output_dir, "images"), exist_ok=True)

        # Initialize camera
        self.cap = cv2.VideoCapture(0)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        # Cross position
        self.x = screen_width // 2
        self.y = 0
        self.direction_x = 1
        self.direction_y = 1

        # Create fullscreen window
        cv2.namedWindow('Gaze Target', cv2.WINDOW_NORMAL)
        cv2.setWindowProperty('Gaze Target', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)

    def draw_cross(self, img, x, y):
        """Draw a cross at the specified position."""
        # Ensure cross stays within screen bounds
        x = max(self.cross_size, min(x, self.screen_width - self.cross_size))
        y = max(self.cross_size, min(y, self.screen_height - self.cross_size))

        # Draw cross
        cv2.line(img, (x - self.cross_size, y), (x + self.cross_size, y), (0, 255, 0), 3)
        cv2.line(img, (x, y - self.cross_size), (x, y + self.cross_size), (0, 255, 0), 3)

        # Draw center dot
        cv2.circle(img, (x, y), 5, (0, 0, 255), -1)

        return x, y

    def update_position(self):
        """Update cross position with continuous movement pattern."""
        # Move horizontally
        self.x += self.direction_x * self.speed

        # Check horizontal bounds and move down when reaching edge
        if self.x <= self.cross_size or self.x >= self.screen_width - self.cross_size:
            self.direction_x *= -1  # Reverse horizontal direction
            self.y += 50  # Move down by 50 pixels

        # Check if we've reached the bottom
        if self.y >= self.screen_height - self.cross_size:
            self.y = self.cross_size  # Reset to top

    def collect_data(self):
        """Main data collection loop."""
        print("Press SPACE to start data collection")
        print("Press ESC to stop and save data")
        print("The cross will move across the screen in a pattern")
        print("Please follow the cross with your eyes")

        frame_count = 0
        start_time = None

        while True:
            # Create black background
            screen = np.zeros((self.screen_height, self.screen_width, 3), dtype=np.uint8)

            # Update and draw cross
            if self.collecting:
                self.update_position()

            self.x, self.y = self.draw_cross(screen, self.x, self.y)

            # Capture webcam frame
            ret, frame = self.cap.read()
            if not ret:
                print("Failed to capture frame")
                continue

            # Show instructions when not collecting
            if not self.collecting:
                cv2.putText(screen, "Press SPACE to start", (self.screen_width//2 - 200, self.screen_height//2),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            else:
                # Save data point
                timestamp = time.time() - start_time
                image_filename = f"frame_{frame_count:06d}.jpg"
                image_path = os.path.join(self.output_dir, "images", image_filename)

                # Save image
                cv2.imwrite(image_path, frame)

                # Store data point
                self.data_points.append({
                    "frame": frame_count,
                    "timestamp": timestamp,
                    "screen_x": self.x,
                    "screen_y": self.y,
                    "image": image_filename
                })

                frame_count += 1

                # Show progress
                elapsed = timestamp
                remaining = max(0, 300 - elapsed)  # 5 minutes = 300 seconds
                cv2.putText(screen, f"Time: {elapsed:.1f}s / 300s", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

                # Stop after 5 minutes
                if elapsed >= 300:
                    print("5 minutes reached. Stopping collection.")
                    break

            # Display the screen
            cv2.imshow('Gaze Target', screen)

            # Handle key presses
            key = cv2.waitKey(1) & 0xFF
            if key == 27:  # ESC
                break
            elif key == 32:  # SPACE
                if not self.collecting:
                    self.collecting = True
                    start_time = time.time()
                    print("Data collection started!")

        # Cleanup
        self.cap.release()
        cv2.destroyAllWindows()

        # Save metadata
        self.save_metadata()

    def save_metadata(self):
        """Save collected data points to JSON file."""
        metadata = {
            "screen_width": self.screen_width,
            "screen_height": self.screen_height,
            "cross_size": self.cross_size,
            "speed": self.speed,
            "total_frames": len(self.data_points),
            "data_points": self.data_points
        }

        with open(os.path.join(self.output_dir, "metadata.json"), 'w') as f:
            json.dump(metadata, f, indent=2)

        print("\nData collection complete!")
        print(f"Total frames collected: {len(self.data_points)}")
        print(f"Data saved to: {self.output_dir}")

if __name__ == "__main__":
    # Create data collector with default screen resolution
    # Adjust these values to match your screen
    collector = GazeDataCollector(
        screen_width=1920,
        screen_height=1080,
        cross_size=30,
        speed=5
    )

    # Start data collection
    collector.collect_data()
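For training, the pixel targets stored in metadata.json are usually normalized by the screen dimensions so the model predicts values in [0, 1] regardless of display resolution. A minimal sketch under that assumption (key names match the metadata saved by `save_metadata()` above; the helper name is illustrative):

```python
import json


def load_normalized_targets(metadata_path):
    """Return (image_filename, x, y) tuples with gaze targets scaled to [0, 1]."""
    with open(metadata_path) as f:
        meta = json.load(f)
    w, h = meta["screen_width"], meta["screen_height"]
    return [
        (p["image"], p["screen_x"] / w, p["screen_y"] / h)
        for p in meta["data_points"]
    ]
```

The normalized pairs can then be joined with the saved images/ frames to build a regression dataset.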
dataset_converter.py
ADDED
@@ -0,0 +1,417 @@
#!/usr/bin/env python3
"""
Gaze Dataset Converter for TensorFlow Training

This script converts the collected gaze data from the web collector
into TensorFlow-compatible datasets for training neural networks.

Usage:
    python dataset_converter.py --input data_folder --output processed_dataset
    python dataset_converter.py --input gaze_data.json --output my_dataset
"""

import os
import json
import base64
import argparse
import numpy as np
from PIL import Image
import io
import tensorflow as tf
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from pathlib import Path
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class GazeDatasetConverter:
    def __init__(self, output_dir="processed_dataset", test_size=0.2, val_size=0.1):
        """
        Initialize the dataset converter.

        Args:
            output_dir: Directory to save the processed dataset
            test_size: Proportion of data for testing (0.2 = 20%)
            val_size: Proportion of data for validation (0.1 = 10%)
        """
        self.output_dir = Path(output_dir)
        self.test_size = test_size
        self.val_size = val_size
        self.image_size = (60, 80)  # (height, width) — matches the model's (60, 80, 3) input

        # Create output directories
        self.output_dir.mkdir(exist_ok=True)
        (self.output_dir / "images" / "train").mkdir(parents=True, exist_ok=True)
        (self.output_dir / "images" / "val").mkdir(parents=True, exist_ok=True)
        (self.output_dir / "images" / "test").mkdir(parents=True, exist_ok=True)
        (self.output_dir / "arrays").mkdir(exist_ok=True)

        logger.info(f"Output directory: {self.output_dir}")

    def load_json_data(self, json_path):
        """Load gaze data from a JSON file."""
        logger.info(f"Loading data from {json_path}")

        with open(json_path, 'r') as f:
            data = json.load(f)

        logger.info(f"Loaded {data.get('total_frames', 0)} frames")
        logger.info(f"Screen dimensions: {data.get('screen_width')}x{data.get('screen_height')}")

        return data

    def process_multiple_files(self, data_folder):
        """Process multiple JSON files from a folder."""
        data_folder = Path(data_folder)
        json_files = list(data_folder.glob("*.json"))

        if not json_files:
            raise ValueError(f"No JSON files found in {data_folder}")

        logger.info(f"Found {len(json_files)} JSON files")

        all_data_points = []
        metadata = None

        for json_file in json_files:
            logger.info(f"Processing {json_file.name}")
            data = self.load_json_data(json_file)

            if metadata is None:
                metadata = {
                    'screen_width': data.get('screen_width'),
                    'screen_height': data.get('screen_height'),
                    'image_dimensions': '60x80 pixels',
                    'total_files': len(json_files)
                }

            all_data_points.extend(data.get('data_points', []))

        return {'data_points': all_data_points, **metadata}

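A minimal sketch of the recording format these loaders expect, assembled from the fields the code actually reads (`total_frames`, `screen_width`, `screen_height`, `data_points`, and per-point `timestamp`, `screen_x`, `screen_y`, `image_data`). The base64 payload is shortened to a placeholder here, not a real JPEG:

```python
import json

# Hypothetical example of one collector recording, matching the get() calls above.
example = {
    "total_frames": 1,
    "screen_width": 1920,
    "screen_height": 1080,
    "data_points": [
        {
            "timestamp": 1718000000.0,
            "screen_x": 960,
            "screen_y": 540,
            "image_data": "data:image/jpeg;base64,...."  # truncated placeholder
        }
    ]
}

# Round-trip through JSON, as load_json_data() would read it from disk
data = json.loads(json.dumps(example))
print(len(data["data_points"]))  # → 1
```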
    def decode_and_process_image(self, base64_data, index):
        """
        Decode a base64 image and return a numpy array (nothing is saved here).

        Args:
            base64_data: Base64 encoded image string
            index: Frame index for logging

        Returns:
            numpy array of the image, or None if decoding failed
        """
        try:
            # Remove data URL prefix if present
            if ',' in base64_data:
                base64_data = base64_data.split(',')[1]

            # Decode base64
            image_bytes = base64.b64decode(base64_data)

            # Convert to PIL Image
            image = Image.open(io.BytesIO(image_bytes))

            # Convert to RGB if needed
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Resize to ensure consistent dimensions
            image = image.resize((self.image_size[1], self.image_size[0]))  # PIL uses (width, height)

            # Convert to numpy array and normalize
            image_array = np.array(image, dtype=np.float32) / 255.0

            return image_array

        except Exception as e:
            logger.error(f"Error processing image {index}: {e}")
            return None

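The data-URL handling above can be exercised without a real camera frame: strip the `data:image/jpeg;base64,` style prefix, then `b64decode` the remainder. The payload below is an arbitrary stand-in, not a real JPEG:

```python
import base64

def strip_data_url(b64_string):
    """Drop a 'data:image/jpeg;base64,' style prefix if present."""
    if ',' in b64_string:
        b64_string = b64_string.split(',')[1]
    return b64_string

payload = base64.b64encode(b"not-a-real-jpeg").decode('ascii')
data_url = "data:image/jpeg;base64," + payload

decoded = base64.b64decode(strip_data_url(data_url))
print(decoded)  # → b'not-a-real-jpeg'
```

Plain base64 strings without a prefix pass through unchanged, which is why the converter works on both raw and data-URL encoded frames.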
    def normalize_coordinates(self, x, y, screen_width, screen_height):
        """Normalize screen coordinates to the [0, 1] range."""
        norm_x = x / screen_width
        norm_y = y / screen_height

        # Clamp to [0, 1] range
        norm_x = np.clip(norm_x, 0.0, 1.0)
        norm_y = np.clip(norm_y, 0.0, 1.0)

        return norm_x, norm_y

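The same mapping in plain Python, with illustrative values for a 1920x1080 screen (divide by the screen dimension, then clamp anything that falls off-screen):

```python
def normalize(x, y, screen_w, screen_h):
    """Map pixel coordinates to [0, 1], clamping out-of-range values."""
    nx = min(max(x / screen_w, 0.0), 1.0)
    ny = min(max(y / screen_h, 0.0), 1.0)
    return nx, ny

print(normalize(960, 540, 1920, 1080))   # → (0.5, 0.5), screen centre
print(normalize(2000, -10, 1920, 1080))  # → (1.0, 0.0), clamped
```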
    def convert_dataset(self, data):
        """Convert the gaze data to training format."""
        logger.info("Converting dataset...")

        images = []
        gaze_points = []
        timestamps = []

        screen_width = data.get('screen_width', 1920)
        screen_height = data.get('screen_height', 1080)

        data_points = data.get('data_points', [])

        for i, point in enumerate(data_points):
            if i % 100 == 0:
                logger.info(f"Processing frame {i}/{len(data_points)}")

            # Process image
            image_array = self.decode_and_process_image(
                point.get('image_data', ''), i
            )

            if image_array is not None:
                # Normalize gaze coordinates
                norm_x, norm_y = self.normalize_coordinates(
                    point.get('screen_x', 0),
                    point.get('screen_y', 0),
                    screen_width,
                    screen_height
                )

                images.append(image_array)
                gaze_points.append([norm_x, norm_y])
                timestamps.append(point.get('timestamp', 0))

        # Convert to numpy arrays
        images = np.array(images, dtype=np.float32)
        gaze_points = np.array(gaze_points, dtype=np.float32)
        timestamps = np.array(timestamps, dtype=np.float32)

        logger.info(f"Processed {len(images)} valid frames")
        logger.info(f"Image shape: {images.shape}")
        logger.info(f"Gaze points shape: {gaze_points.shape}")

        return images, gaze_points, timestamps

    def split_dataset(self, images, gaze_points, timestamps):
        """Split the dataset into train/validation/test sets."""
        logger.info("Splitting dataset...")

        # First split: separate the test set
        X_temp, X_test, y_temp, y_test, t_temp, t_test = train_test_split(
            images, gaze_points, timestamps,
            test_size=self.test_size,
            random_state=42
        )

        # Second split: separate train and validation from the remaining data
        val_size_adjusted = self.val_size / (1 - self.test_size)
        X_train, X_val, y_train, y_val, t_train, t_val = train_test_split(
            X_temp, y_temp, t_temp,
            test_size=val_size_adjusted,
            random_state=42
        )

        logger.info(f"Train set: {len(X_train)} samples")
        logger.info(f"Validation set: {len(X_val)} samples")
        logger.info(f"Test set: {len(X_test)} samples")

        return (X_train, y_train, t_train), \
               (X_val, y_val, t_val), \
               (X_test, y_test, t_test)

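The second split uses val_size / (1 - test_size) so that the validation share is measured against the full dataset rather than the post-test remainder. A worked example with the defaults and a hypothetical 1000 samples:

```python
test_size = 0.2
val_size = 0.1
total = 1000

n_test = int(total * test_size)               # 20% of the full set
val_adjusted = val_size / (1 - test_size)     # 0.1 / 0.8 = 0.125
n_val = int((total - n_test) * val_adjusted)  # 12.5% of the remaining 800
n_train = total - n_test - n_val

print(n_train, n_val, n_test)  # → 700 100 200
```

So the final split is 70/10/20, exactly the advertised proportions.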
    def save_images_by_split(self, train_data, val_data, test_data):
        """Save images organized by split."""
        logger.info("Saving images by split...")

        splits = [
            (train_data, 'train'),
            (val_data, 'val'),
            (test_data, 'test')
        ]

        for (X, y, t), split_name in splits:
            split_dir = self.output_dir / "images" / split_name

            for i, image_array in enumerate(X):
                # Convert back to PIL Image (denormalize)
                image_array_uint8 = (image_array * 255).astype(np.uint8)
                image = Image.fromarray(image_array_uint8)

                # Save image
                image_filename = f"face_{i:06d}.jpg"
                image_path = split_dir / image_filename
                image.save(image_path, quality=85)

            logger.info(f"Saved {len(X)} images to {split_name} set")

    def save_numpy_arrays(self, train_data, val_data, test_data, metadata):
        """Save the dataset as numpy arrays for fast loading."""
        logger.info("Saving numpy arrays...")

        arrays_dir = self.output_dir / "arrays"

        # Save each split
        splits = [
            (train_data, 'train'),
            (val_data, 'val'),
            (test_data, 'test')
        ]

        for (X, y, t), split_name in splits:
            np.save(arrays_dir / f"{split_name}_images.npy", X)
            np.save(arrays_dir / f"{split_name}_gaze.npy", y)
            np.save(arrays_dir / f"{split_name}_timestamps.npy", t)

            logger.info(f"Saved {split_name} set: {X.shape[0]} samples")

        # Save metadata, dropping the raw data_points so the JSON stays small
        metadata = {k: v for k, v in metadata.items() if k != 'data_points'}
        with open(arrays_dir / "metadata.json", 'w') as f:
            json.dump(metadata, f, indent=2)

    def create_tensorflow_datasets(self, train_data, val_data, test_data, batch_size=32):
        """Create TensorFlow datasets."""
        logger.info("Creating TensorFlow datasets...")

        def create_tf_dataset(X, y, batch_size, shuffle=True):
            dataset = tf.data.Dataset.from_tensor_slices((X, y))
            if shuffle:
                dataset = dataset.shuffle(buffer_size=1000)
            dataset = dataset.batch(batch_size)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)
            return dataset

        train_dataset = create_tf_dataset(train_data[0], train_data[1], batch_size, shuffle=True)
        val_dataset = create_tf_dataset(val_data[0], val_data[1], batch_size, shuffle=False)
        test_dataset = create_tf_dataset(test_data[0], test_data[1], batch_size, shuffle=False)

        # Save datasets
        train_dataset.save(str(self.output_dir / "tf_datasets" / "train"))
        val_dataset.save(str(self.output_dir / "tf_datasets" / "val"))
        test_dataset.save(str(self.output_dir / "tf_datasets" / "test"))

        logger.info("TensorFlow datasets saved")

        return train_dataset, val_dataset, test_dataset

    def visualize_samples(self, train_data, num_samples=9):
        """Create a visualization of sample data."""
        logger.info("Creating sample visualization...")

        X_train, y_train = train_data[0], train_data[1]

        fig, axes = plt.subplots(3, 3, figsize=(12, 12))
        fig.suptitle('Sample Training Data', fontsize=16)

        # Don't ask for more samples than the training set holds
        num_samples = min(num_samples, len(X_train))
        indices = np.random.choice(len(X_train), num_samples, replace=False)

        for i, idx in enumerate(indices):
            row, col = i // 3, i % 3
            ax = axes[row, col]

            # Show image
            ax.imshow(X_train[idx])
            ax.set_title(f'Gaze: ({y_train[idx][0]:.3f}, {y_train[idx][1]:.3f})')
            ax.axis('off')

            # Mark the gaze point, scaled into image coordinates
            gaze_x = y_train[idx][0] * X_train[idx].shape[1]
            gaze_y = y_train[idx][1] * X_train[idx].shape[0]
            ax.plot(gaze_x, gaze_y, 'r+', markersize=10, markeredgewidth=2)

        plt.tight_layout()
        plt.savefig(self.output_dir / "sample_visualization.png", dpi=150, bbox_inches='tight')
        plt.close()

        logger.info("Sample visualization saved")

    def generate_report(self, metadata, train_data, val_data, test_data):
        """Generate a summary report."""
        X_train, y_train = train_data[0], train_data[1]
        X_val, y_val = val_data[0], val_data[1]
        X_test, y_test = test_data[0], test_data[1]

        report = {
            'dataset_info': {
                'total_samples': len(X_train) + len(X_val) + len(X_test),
                'train_samples': len(X_train),
                'val_samples': len(X_val),
                'test_samples': len(X_test),
                'image_shape': X_train.shape[1:],
                'gaze_range': {
                    'x_min': float(np.min(y_train[:, 0])),
                    'x_max': float(np.max(y_train[:, 0])),
                    'y_min': float(np.min(y_train[:, 1])),
                    'y_max': float(np.max(y_train[:, 1]))
                }
            },
            # Drop the raw data_points so the report stays small
            'original_metadata': {k: v for k, v in metadata.items() if k != 'data_points'},
            'file_structure': {
                'arrays/': 'Numpy arrays for fast loading',
                'images/train/': 'Training images',
                'images/val/': 'Validation images',
                'images/test/': 'Test images',
                'tf_datasets/': 'TensorFlow datasets'
            }
        }

        with open(self.output_dir / "dataset_report.json", 'w') as f:
            json.dump(report, f, indent=2)

        logger.info("Dataset report generated")
        return report

def main():
    parser = argparse.ArgumentParser(description='Convert gaze data to a TensorFlow dataset')
    parser.add_argument('--input', required=True, help='Input JSON file or folder with JSON files')
    parser.add_argument('--output', default='processed_dataset', help='Output directory')
    parser.add_argument('--test-size', type=float, default=0.2, help='Test set proportion')
    parser.add_argument('--val-size', type=float, default=0.1, help='Validation set proportion')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size for TensorFlow datasets')
    parser.add_argument('--visualize', action='store_true', help='Create sample visualizations')

    args = parser.parse_args()

    # Initialize converter
    converter = GazeDatasetConverter(
        output_dir=args.output,
        test_size=args.test_size,
        val_size=args.val_size
    )

    # Load data
    if os.path.isfile(args.input):
        data = converter.load_json_data(args.input)
    elif os.path.isdir(args.input):
        data = converter.process_multiple_files(args.input)
    else:
        raise ValueError(f"Input path {args.input} is neither a file nor a directory")

    # Convert dataset
    images, gaze_points, timestamps = converter.convert_dataset(data)

    if len(images) == 0:
        logger.error("No valid data found!")
        return

    # Split dataset
    train_data, val_data, test_data = converter.split_dataset(
        images, gaze_points, timestamps
    )

    # Save everything
    converter.save_images_by_split(train_data, val_data, test_data)
    converter.save_numpy_arrays(train_data, val_data, test_data, data)
    converter.create_tensorflow_datasets(train_data, val_data, test_data, args.batch_size)

    # Optional visualization
    if args.visualize:
        converter.visualize_samples(train_data)

    # Generate report
    report = converter.generate_report(data, train_data, val_data, test_data)

    logger.info("Dataset conversion complete!")
    logger.info(f"Output directory: {converter.output_dir}")
    logger.info(f"Total samples: {report['dataset_info']['total_samples']}")

if __name__ == "__main__":
    main()
inference.py
ADDED
@@ -0,0 +1,403 @@
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
import time
import json
from pathlib import Path

class GazeEstimator:
    def __init__(self, model_path='best_gaze_model.keras', config_path='model_config.json',
                 input_shape=None, screen_width=None, screen_height=None):
        """
        Initialize the real-time gaze estimator.

        Args:
            model_path: Path to the trained model (.keras file)
            config_path: Path to the model config JSON file
            input_shape: Input shape expected by the model (overrides config)
            screen_width: Width of the screen (overrides config)
            screen_height: Height of the screen (overrides config)
        """
        # Load config if available
        if Path(config_path).exists():
            with open(config_path, 'r') as f:
                config = json.load(f)
            self.input_shape = tuple(config.get('input_shape', [60, 80, 3]))
            self.screen_width = config.get('screen_width', 1920)
            self.screen_height = config.get('screen_height', 1080)
            print(f"Loaded config: {config}")
        else:
            self.input_shape = (60, 80, 3)
            self.screen_width = 1920
            self.screen_height = 1080

        # Override with provided values if any
        if input_shape is not None:
            self.input_shape = input_shape
        if screen_width is not None:
            self.screen_width = screen_width
        if screen_height is not None:
            self.screen_height = screen_height

        print(f"Using input shape: {self.input_shape}")
        print(f"Using screen dimensions: {self.screen_width}x{self.screen_height}")

        # Load model
        print(f"Loading model from {model_path}...")
        self.model = keras.models.load_model(model_path)
        print("Model loaded successfully!")

        # Initialize camera
        self.cap = cv2.VideoCapture(0)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        # For FPS calculation
        self.prev_time = 0

        # Smoothing state (use self.screen_*: the screen_width/height
        # arguments default to None, so they cannot be used directly here)
        self.smooth_x = self.screen_width // 2
        self.smooth_y = self.screen_height // 2
        self.smoothing_factor = 0.3

    def preprocess_frame(self, frame):
        """Preprocess a frame for model input."""
        # Resize to model input size (cv2 expects (width, height))
        resized = cv2.resize(frame, (self.input_shape[1], self.input_shape[0]))

        # Convert to RGB
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

        # Normalize to [0, 1]
        normalized = rgb.astype('float32') / 255.0

        # Add batch dimension
        batch = np.expand_dims(normalized, axis=0)

        # Return the RGB copy for visualization; create_visualization
        # converts it back to BGR before drawing
        return batch, rgb

    def smooth_prediction(self, pred_x, pred_y):
        """Apply exponential smoothing to predictions."""
        self.smooth_x = self.smoothing_factor * pred_x + (1 - self.smoothing_factor) * self.smooth_x
        self.smooth_y = self.smoothing_factor * pred_y + (1 - self.smoothing_factor) * self.smooth_y

        return int(self.smooth_x), int(self.smooth_y)

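The exponential smoothing above is a simple low-pass filter: each new reading moves the stored state a fixed fraction (0.3 here) toward the raw prediction, so jitter is damped while the cursor still converges on a steady gaze. A minimal standalone version:

```python
def smooth(state, new_value, factor=0.3):
    """One step of exponential smoothing: blend the new value into the state."""
    return factor * new_value + (1 - factor) * state

# Starting from 0, three consecutive readings of 100 pull the state
# 30%, then 51%, then 65.7% of the way to the target.
x = 0.0
for _ in range(3):
    x = smooth(x, 100.0)
print(round(x, 2))  # → 65.7
```

A factor of 1.0 disables smoothing entirely (the state jumps straight to each reading), which is why the '+'/'-' keys in `run_inference` cap the factor at 0.9.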
    def run_inference(self):
        """Run real-time gaze estimation."""
        # Create visualization window
        cv2.namedWindow('Gaze Estimation', cv2.WINDOW_NORMAL)
        cv2.resizeWindow('Gaze Estimation', 1400, 800)

        print("\nStarting real-time gaze estimation...")
        print("Controls:")
        print("  'q' - Quit")
        print("  's' - Toggle smoothing")
        print("  'r' - Reset smoothing")
        print("  '+' - Increase smoothing factor")
        print("  '-' - Decrease smoothing factor")

        use_smoothing = True

        while True:
            # Capture frame
            ret, frame = self.cap.read()
            if not ret:
                continue

            # Preprocess frame
            input_batch, preprocessed_img = self.preprocess_frame(frame)

            # Run inference
            start_inference = time.time()
            predictions = self.model.predict(input_batch, verbose=0)
            inference_time = (time.time() - start_inference) * 1000  # ms

            # Extract coordinates
            norm_x, norm_y = predictions[0]
            pred_x = int(norm_x * self.screen_width)
            pred_y = int(norm_y * self.screen_height)

            # Clamp predictions to screen bounds
            pred_x = max(0, min(pred_x, self.screen_width - 1))
            pred_y = max(0, min(pred_y, self.screen_height - 1))

            # Apply smoothing if enabled
            if use_smoothing:
                gaze_x, gaze_y = self.smooth_prediction(pred_x, pred_y)
            else:
                gaze_x, gaze_y = pred_x, pred_y

            # Create visualization
            vis_frame = self.create_visualization(frame, preprocessed_img, gaze_x, gaze_y,
                                                  inference_time, use_smoothing)

            # Calculate FPS
            current_time = time.time()
            fps = 1 / (current_time - self.prev_time) if self.prev_time > 0 else 0
            self.prev_time = current_time

            # Add FPS to visualization
            cv2.putText(vis_frame, f"FPS: {fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Show frame
            cv2.imshow('Gaze Estimation', vis_frame)

            # Handle key presses
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('s'):
                use_smoothing = not use_smoothing
                print(f"Smoothing: {'ON' if use_smoothing else 'OFF'}")
            elif key == ord('r'):
                self.smooth_x = self.screen_width // 2
                self.smooth_y = self.screen_height // 2
                print("Smoothing reset")
            elif key == ord('+'):
                self.smoothing_factor = min(0.9, self.smoothing_factor + 0.1)
                print(f"Smoothing factor: {self.smoothing_factor:.1f}")
            elif key == ord('-'):
                self.smoothing_factor = max(0.1, self.smoothing_factor - 0.1)
                print(f"Smoothing factor: {self.smoothing_factor:.1f}")

        # Cleanup
        self.cap.release()
        cv2.destroyAllWindows()

    def create_visualization(self, frame, preprocessed_img, gaze_x, gaze_y,
                             inference_time, use_smoothing):
        """Create a visualization frame with gaze overlay and preprocessed image."""
        # Create a larger canvas
        canvas_height = 800
        canvas_width = 1400
        canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)

        # 1. Original webcam feed (top left)
        cam_height = 360
        cam_width = 480
        resized_frame = cv2.resize(frame, (cam_width, cam_height))
        canvas[20:20+cam_height, 20:20+cam_width] = resized_frame
        cv2.putText(canvas, "Original Webcam", (20, 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

        # 2. Preprocessed image sent to the network (bottom left)
        # Scale up the preprocessed image for better visibility
        prep_scale = 3
        prep_height = self.input_shape[0] * prep_scale
        prep_width = self.input_shape[1] * prep_scale

        # Convert back to BGR for display
        prep_display = cv2.cvtColor(preprocessed_img, cv2.COLOR_RGB2BGR)
        prep_display = cv2.resize(prep_display, (prep_width, prep_height),
                                  interpolation=cv2.INTER_NEAREST)

        prep_y = cam_height + 60
        canvas[prep_y:prep_y+prep_height, 20:20+prep_width] = prep_display
        cv2.putText(canvas, f"Network Input ({self.input_shape[0]}x{self.input_shape[1]})",
                    (20, prep_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

        # Draw border around preprocessed image
        cv2.rectangle(canvas, (20, prep_y), (20+prep_width, prep_y+prep_height),
                      (100, 100, 100), 2)

        # 3. Screen representation (right side)
        screen_start_x = cam_width + 60
        screen_start_y = 20
        screen_vis_width = 800
        screen_vis_height = 600

        # Draw screen border
        cv2.rectangle(canvas,
                      (screen_start_x, screen_start_y),
                      (screen_start_x + screen_vis_width, screen_start_y + screen_vis_height),
                      (255, 255, 255), 2)

        # Draw grid on screen for reference
        grid_color = (50, 50, 50)
        for i in range(1, 4):
            # Vertical lines
            x = screen_start_x + (screen_vis_width * i) // 4
            cv2.line(canvas, (x, screen_start_y), (x, screen_start_y + screen_vis_height),
                     grid_color, 1)
            # Horizontal lines
            y = screen_start_y + (screen_vis_height * i) // 4
            cv2.line(canvas, (screen_start_x, y), (screen_start_x + screen_vis_width, y),
                     grid_color, 1)

        # Scale gaze coordinates to the visualization
        vis_gaze_x = screen_start_x + int((gaze_x / self.screen_width) * screen_vis_width)
        vis_gaze_y = screen_start_y + int((gaze_y / self.screen_height) * screen_vis_height)

        # Draw gaze trail (if smoothing is on)
        if use_smoothing:
            # Draw a fading trail
            trail_color = (0, 100, 0)
            cv2.circle(canvas, (vis_gaze_x, vis_gaze_y), 20, trail_color, 1)

        # Draw gaze point
        cv2.circle(canvas, (vis_gaze_x, vis_gaze_y), 8, (0, 255, 0), -1)
        cv2.circle(canvas, (vis_gaze_x, vis_gaze_y), 12, (0, 255, 0), 2)

        # Draw crosshair
        cv2.line(canvas, (vis_gaze_x - 20, vis_gaze_y), (vis_gaze_x + 20, vis_gaze_y),
                 (0, 255, 0), 1)
        cv2.line(canvas, (vis_gaze_x, vis_gaze_y - 20), (vis_gaze_x, vis_gaze_y + 20),
                 (0, 255, 0), 1)

        # Add title
        cv2.putText(canvas, "Real-time Gaze Estimation",
                    (screen_start_x + 250, screen_start_y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # 4. Information panel (bottom right)
        info_y = screen_start_y + screen_vis_height + 30
        info_x = screen_start_x

        # Background for info panel
        cv2.rectangle(canvas, (info_x - 10, info_y - 10),
                      (info_x + 400, info_y + 120), (30, 30, 30), -1)

        # Info text
        cv2.putText(canvas, f"Gaze Position: ({gaze_x}, {gaze_y})",
                    (info_x, info_y + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        cv2.putText(canvas, f"Normalized: ({gaze_x/self.screen_width:.3f}, {gaze_y/self.screen_height:.3f})",
                    (info_x, info_y + 45), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        cv2.putText(canvas, f"Inference Time: {inference_time:.1f} ms",
                    (info_x, info_y + 70), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        cv2.putText(canvas, f"Smoothing: {'ON' if use_smoothing else 'OFF'} (factor: {self.smoothing_factor:.1f})",
                    (info_x, info_y + 95), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                    (0, 255, 0) if use_smoothing else (100, 100, 100), 1)

        return canvas

| 276 |
+
class TFLiteGazeEstimator(GazeEstimator):
|
| 277 |
+
def __init__(self, tflite_path='gaze_model_efficient.tflite', config_path='model_config.json',
|
| 278 |
+
input_shape=None, screen_width=None, screen_height=None):
|
| 279 |
+
"""
|
| 280 |
+
TFLite version for even faster inference.
|
| 281 |
+
"""
|
| 282 |
+
# Load config if available
|
| 283 |
+
if Path(config_path).exists():
|
| 284 |
+
with open(config_path, 'r') as f:
|
| 285 |
+
config = json.load(f)
|
| 286 |
+
self.input_shape = tuple(config.get('input_shape', [60, 80, 3]))
|
| 287 |
+
self.screen_width = config.get('screen_width', 1920)
|
| 288 |
+
self.screen_height = config.get('screen_height', 1080)
|
| 289 |
+
print(f"Loaded config: {config}")
|
| 290 |
+
else:
|
| 291 |
+
self.input_shape = (60, 80, 3)
|
| 292 |
+
self.screen_width = 1920
|
| 293 |
+
self.screen_height = 1080
|
| 294 |
+
|
| 295 |
+
# Override with provided values if any
|
| 296 |
+
if input_shape is not None:
|
| 297 |
+
self.input_shape = input_shape
|
| 298 |
+
if screen_width is not None:
|
| 299 |
+
self.screen_width = screen_width
|
| 300 |
+
if screen_height is not None:
|
| 301 |
+
self.screen_height = screen_height
|
| 302 |
+
|
| 303 |
+
print(f"Using input shape: {self.input_shape}")
|
| 304 |
+
print(f"Using screen dimensions: {self.screen_width}x{self.screen_height}")
|
| 305 |
+
|
| 306 |
+
# Load TFLite model
|
| 307 |
+
print(f"Loading TFLite model from {tflite_path}...")
|
| 308 |
+
self.interpreter = tf.lite.Interpreter(model_path=tflite_path)
|
| 309 |
+
self.interpreter.allocate_tensors()
|
| 310 |
+
|
| 311 |
+
# Get input and output details
|
| 312 |
+
self.input_details = self.interpreter.get_input_details()
|
| 313 |
+
self.output_details = self.interpreter.get_output_details()
|
| 314 |
+
|
| 315 |
+
print(f"Input details: {self.input_details}")
|
| 316 |
+
print(f"Output details: {self.output_details}")
|
| 317 |
+
|
| 318 |
+
# Initialize camera
|
| 319 |
+
self.cap = cv2.VideoCapture(0)
|
| 320 |
+
self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
|
| 321 |
+
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
|
| 322 |
+
|
| 323 |
+
# For FPS calculation
|
| 324 |
+
self.prev_time = 0
|
| 325 |
+
|
| 326 |
+
# Smoothing parameters
|
| 327 |
+
self.smooth_x = self.screen_width // 2  # use the resolved value; the screen_width arg may be None
|
| 328 |
+
self.smooth_y = self.screen_height // 2
|
| 329 |
+
self.smoothing_factor = 0.3
|
| 330 |
+
|
| 331 |
+
# Parent class methods
|
| 332 |
+
self.model = None # Dummy for compatibility
|
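The smoothing attributes above drive a simple exponential moving average applied to each new prediction in the inference loop. A minimal sketch of that update (the helper name `smooth_gaze` is hypothetical; the default factor matches `self.smoothing_factor`):

```python
def smooth_gaze(smooth_x, smooth_y, raw_x, raw_y, factor=0.3):
    """Exponential moving average: a higher factor tracks faster but jitters more."""
    smooth_x = (1.0 - factor) * smooth_x + factor * raw_x
    smooth_y = (1.0 - factor) * smooth_y + factor * raw_y
    return smooth_x, smooth_y

# A raw prediction 200 px to the right moves the smoothed point only ~30% of the way
x, y = smooth_gaze(960.0, 540.0, 1160.0, 640.0)
print(round(x), round(y))  # -> 1020 570
```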
| 333 |
+
|
| 334 |
+
def predict(self, input_batch, verbose=0):
|
| 335 |
+
"""Run TFLite inference."""
|
| 336 |
+
# Cast to the dtype the interpreter declares (usually float32) before feeding it
input_batch = input_batch.astype(self.input_details[0]['dtype'])
self.interpreter.set_tensor(self.input_details[0]['index'], input_batch)
|
| 337 |
+
self.interpreter.invoke()
|
| 338 |
+
output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
|
| 339 |
+
return output_data
|
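`predict` returns the raw interpreter output. Assuming the model emits a normalized `(x, y)` pair in `[0, 1]` (consistent with the `Normalized:` readout drawn on the overlay earlier), a hypothetical helper for mapping a prediction back to screen pixels could look like this:

```python
def to_pixels(norm_x, norm_y, screen_w=1920, screen_h=1080):
    """Clamp a normalized (0..1) gaze output and scale it to screen coordinates."""
    nx = min(max(norm_x, 0.0), 1.0)
    ny = min(max(norm_y, 0.0), 1.0)
    return int(round(nx * (screen_w - 1))), int(round(ny * (screen_h - 1)))

print(to_pixels(0.5, 0.5))   # center of a 1920x1080 screen -> (960, 540)
print(to_pixels(1.2, -0.1))  # out-of-range predictions are clamped -> (1919, 0)
```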
| 340 |
+
|
| 341 |
+
def test_camera():
|
| 342 |
+
"""Test if camera is working."""
|
| 343 |
+
print("Testing camera...")
|
| 344 |
+
cap = cv2.VideoCapture(0)
|
| 345 |
+
|
| 346 |
+
if not cap.isOpened():
|
| 347 |
+
print("Error: Could not open camera")
|
| 348 |
+
return False
|
| 349 |
+
|
| 350 |
+
ret, frame = cap.read()
|
| 351 |
+
if not ret:
|
| 352 |
+
print("Error: Could not read from camera")
|
| 353 |
+
cap.release()
|
| 354 |
+
return False
|
| 355 |
+
|
| 356 |
+
print(f"Camera working! Frame shape: {frame.shape}")
|
| 357 |
+
cap.release()
|
| 358 |
+
return True
|
| 359 |
+
|
| 360 |
+
def main():
|
| 361 |
+
# Test camera first
|
| 362 |
+
if not test_camera():
|
| 363 |
+
return
|
| 364 |
+
|
| 365 |
+
# Check which model files are available
|
| 366 |
+
keras_model = Path('best_gaze_model.keras')
|
| 367 |
+
tflite_model = Path('gaze_model_efficient.tflite')
|
| 368 |
+
|
| 369 |
+
use_tflite = False
|
| 370 |
+
|
| 371 |
+
if tflite_model.exists() and input("Use TFLite model for faster inference? (y/n): ").lower() == 'y':
|
| 372 |
+
use_tflite = True
|
| 373 |
+
elif not keras_model.exists():
|
| 374 |
+
print(f"Error: Model file {keras_model} not found!")
|
| 375 |
+
return
|
| 376 |
+
|
| 377 |
+
try:
|
| 378 |
+
if use_tflite:
|
| 379 |
+
print("\nUsing TFLite model...")
|
| 380 |
+
estimator = TFLiteGazeEstimator(
|
| 381 |
+
tflite_path=str(tflite_model),
|
| 382 |
+
config_path='model_config.json'
|
| 383 |
+
)
|
| 384 |
+
# Override predict method for compatibility
|
| 385 |
+
estimator.model = estimator # Dummy reference
|
| 386 |
+
estimator.model.predict = estimator.predict
|
| 387 |
+
else:
|
| 388 |
+
print("\nUsing Keras model...")
|
| 389 |
+
estimator = GazeEstimator(
|
| 390 |
+
model_path=str(keras_model),
|
| 391 |
+
config_path='model_config.json'
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# Run real-time inference
|
| 395 |
+
estimator.run_inference()
|
| 396 |
+
|
| 397 |
+
except Exception as e:
|
| 398 |
+
print(f"Error: {e}")
|
| 399 |
+
import traceback
|
| 400 |
+
traceback.print_exc()
|
| 401 |
+
|
| 402 |
+
if __name__ == "__main__":
|
| 403 |
+
main()
|
readme.txt
ADDED
|
@@ -0,0 +1,272 @@
|
| 1 |
+
# Gaze Estimation Setup Instructions for Ubuntu 22.04
|
| 2 |
+
|
| 3 |
+
## Prerequisites
|
| 4 |
+
- Conda installed on Ubuntu 22.04
|
| 5 |
+
- Webcam connected to your system
|
| 6 |
+
- NVIDIA GPU (optional, for faster training)
|
| 7 |
+
|
| 8 |
+
## Alternative to Conda: Plain Python Virtual Environment

If you prefer not to use Conda, skip Step 1 and create a virtual environment named `gaze_env` instead; the pip commands in the following steps work the same way inside it.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
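As an alternative to the Conda environment created in Step 1, a plain virtual environment can be set up with Python's built-in `venv` module (a sketch; assumes `python3` 3.9+ is on your PATH):

```shell
# Create and activate a virtual environment named gaze_env
python3 -m venv gaze_env
source gaze_env/bin/activate
```

With the environment active, `pip install -r requirements.txt` pulls in the pinned packages.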
| 12 |
+
## Step 1: Create Conda Environment
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
# Create a new conda environment with Python 3.9
|
| 16 |
+
conda create -n gaze_estimation python=3.9 -y
|
| 17 |
+
|
| 18 |
+
# Activate the environment
|
| 19 |
+
conda activate gaze_estimation
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
## Step 2: Install Core Dependencies
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
# Install TensorFlow (CPU version)
|
| 26 |
+
pip install tensorflow==2.13.0
|
| 27 |
+
|
| 28 |
+
# OR for GPU support (if you have NVIDIA GPU with CUDA)
|
| 29 |
+
# pip install tensorflow[and-cuda]==2.13.0
|
| 30 |
+
|
| 31 |
+
# Install OpenCV
|
| 32 |
+
pip install opencv-python==4.8.1.78
|
| 33 |
+
|
| 34 |
+
# Install additional required packages
|
| 35 |
+
pip install numpy==1.24.3
|
| 36 |
+
pip install matplotlib==3.7.1
|
| 37 |
+
pip install scikit-learn==1.3.0
|
| 38 |
+
pip install pillow==10.0.0
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Step 3: Install System Dependencies for OpenCV
|
| 42 |
+
|
| 43 |
+
OpenCV might need some system libraries on Ubuntu 22.04:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# Update package list
|
| 47 |
+
sudo apt update
|
| 48 |
+
|
| 49 |
+
# Install required system libraries
|
| 50 |
+
sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1
|
| 51 |
+
|
| 52 |
+
# Install video codecs and camera support
|
| 53 |
+
sudo apt install -y libgstreamer1.0-0 libgstreamer-plugins-base1.0-0 v4l-utils
|
| 54 |
+
|
| 55 |
+
# Install additional GUI libraries for OpenCV windows
|
| 56 |
+
sudo apt install -y libgtk-3-0 libgtk-3-dev
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Step 4: Verify Installation
|
| 60 |
+
|
| 61 |
+
Create a test script `test_installation.py`:
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
import cv2
|
| 65 |
+
import tensorflow as tf
|
| 66 |
+
import numpy as np
|
| 67 |
+
import matplotlib.pyplot as plt
|
| 68 |
+
from sklearn.model_selection import train_test_split
|
| 69 |
+
|
| 70 |
+
print("OpenCV version:", cv2.__version__)
|
| 71 |
+
print("TensorFlow version:", tf.__version__)
|
| 72 |
+
print("NumPy version:", np.__version__)
|
| 73 |
+
|
| 74 |
+
# Test camera
|
| 75 |
+
cap = cv2.VideoCapture(0)
|
| 76 |
+
if cap.isOpened():
|
| 77 |
+
print("Camera is accessible")
|
| 78 |
+
ret, frame = cap.read()
|
| 79 |
+
if ret:
|
| 80 |
+
print("Camera capture successful")
|
| 81 |
+
cap.release()
|
| 82 |
+
else:
|
| 83 |
+
print("Camera not found")
|
| 84 |
+
|
| 85 |
+
# Test TensorFlow
|
| 86 |
+
print("TensorFlow GPU available:", tf.config.list_physical_devices('GPU'))
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
Run the test:
|
| 90 |
+
```bash
|
| 91 |
+
python test_installation.py
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Step 5: Camera Permissions
|
| 95 |
+
|
| 96 |
+
If you have camera permission issues:
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
# Add your user to the video group
|
| 100 |
+
sudo usermod -a -G video $USER
|
| 101 |
+
|
| 102 |
+
# Log out and log back in, or run:
|
| 103 |
+
newgrp video
|
| 104 |
+
|
| 105 |
+
# Check camera devices
|
| 106 |
+
ls -la /dev/video*
|
| 107 |
+
|
| 108 |
+
# Test camera with v4l2
|
| 109 |
+
v4l2-ctl --list-devices
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## Step 6: Download the Scripts
|
| 113 |
+
|
| 114 |
+
Save the three Python scripts:
|
| 115 |
+
1. `gaze_data_collection.py` - For collecting training data
|
| 116 |
+
2. `gaze_training.py` - For training the model
|
| 117 |
+
3. `gaze_inference.py` - For real-time inference
|
| 118 |
+
|
| 119 |
+
## Step 7: Configure Display Settings
|
| 120 |
+
|
| 121 |
+
For the data collection script to work properly in fullscreen:
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
# Check your display resolution
|
| 125 |
+
xrandr | grep current
|
| 126 |
+
|
| 127 |
+
# You might need to allow OpenCV to create fullscreen windows
|
| 128 |
+
# If using Wayland, you may need to switch to X11:
|
| 129 |
+
echo $XDG_SESSION_TYPE # Check if using Wayland or X11
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
## Step 8: Environment Variables (Optional)
|
| 133 |
+
|
| 134 |
+
Create a `.env` file or export these for better performance:
|
| 135 |
+
|
| 136 |
+
```bash
|
| 137 |
+
# Limit TensorFlow GPU memory growth
|
| 138 |
+
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
| 139 |
+
|
| 140 |
+
# Set number of threads for better CPU performance
|
| 141 |
+
export OMP_NUM_THREADS=4
|
| 142 |
+
export TF_NUM_INTEROP_THREADS=4
|
| 143 |
+
export TF_NUM_INTRAOP_THREADS=4
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
## Complete Installation Script
|
| 147 |
+
|
| 148 |
+
Here's a complete script to set up everything:
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
#!/bin/bash
|
| 152 |
+
|
| 153 |
+
# Create and activate conda environment
|
| 154 |
+
conda create -n gaze_estimation python=3.9 -y
|
| 155 |
+
source $(conda info --base)/etc/profile.d/conda.sh
|
| 156 |
+
conda activate gaze_estimation
|
| 157 |
+
|
| 158 |
+
# Install Python packages
|
| 159 |
+
pip install tensorflow==2.13.0
|
| 160 |
+
pip install opencv-python==4.8.1.78
|
| 161 |
+
pip install numpy==1.24.3
|
| 162 |
+
pip install matplotlib==3.7.1
|
| 163 |
+
pip install scikit-learn==1.3.0
|
| 164 |
+
pip install pillow==10.0.0
|
| 165 |
+
|
| 166 |
+
# Install system dependencies
|
| 167 |
+
sudo apt update
|
| 168 |
+
sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1
|
| 169 |
+
sudo apt install -y libgstreamer1.0-0 libgstreamer-plugins-base1.0-0 v4l-utils
|
| 170 |
+
sudo apt install -y libgtk-3-0 libgtk-3-dev
|
| 171 |
+
|
| 172 |
+
# Add user to video group
|
| 173 |
+
sudo usermod -a -G video $USER
|
| 174 |
+
|
| 175 |
+
echo "Setup complete! Please log out and log back in for video group changes to take effect."
|
| 176 |
+
echo "Then activate the environment with: conda activate gaze_estimation"
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
Save this as `setup_gaze_env.sh` and run:
|
| 180 |
+
```bash
|
| 181 |
+
chmod +x setup_gaze_env.sh
|
| 182 |
+
./setup_gaze_env.sh
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## Troubleshooting
|
| 186 |
+
|
| 187 |
+
### Camera Issues
|
| 188 |
+
If the camera isn't detected:
|
| 189 |
+
```bash
|
| 190 |
+
# Check if camera is detected by system
|
| 191 |
+
ls -la /dev/video*
|
| 192 |
+
v4l2-ctl --list-devices
|
| 193 |
+
|
| 194 |
+
# Test with simple capture
|
| 195 |
+
python -c "import cv2; cap = cv2.VideoCapture(0); print('Camera opened:', cap.isOpened()); cap.release()"
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
### OpenCV Window Issues
|
| 199 |
+
If OpenCV windows don't appear:
|
| 200 |
+
```bash
|
| 201 |
+
# Install additional backends
|
| 202 |
+
sudo apt install -y python3-opencv libopencv-dev
|
| 203 |
+
|
| 204 |
+
# For Wayland compatibility issues, force X11
|
| 205 |
+
export GDK_BACKEND=x11
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
### TensorFlow Issues
|
| 209 |
+
If TensorFlow has compatibility issues:
|
| 210 |
+
```bash
|
| 211 |
+
# Check CUDA compatibility (for GPU)
|
| 212 |
+
nvidia-smi
|
| 213 |
+
|
| 214 |
+
# Install specific CUDA version if needed
|
| 215 |
+
conda install -c conda-forge cudatoolkit=11.8 cudnn=8.6
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
### Permission Denied Errors
|
| 219 |
+
```bash
|
| 220 |
+
# For camera access
|
| 221 |
+
sudo chmod 666 /dev/video0
|
| 222 |
+
|
| 223 |
+
# For display access
|
| 224 |
+
xhost +local:
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
## Running the System
|
| 228 |
+
|
| 229 |
+
Once everything is installed:
|
| 230 |
+
|
| 231 |
+
1. **Collect data:**
|
| 232 |
+
```bash
|
| 233 |
+
conda activate gaze_estimation
|
| 234 |
+
python gaze_data_collection.py
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
2. **Train model:**
|
| 238 |
+
```bash
|
| 239 |
+
python gaze_training.py
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
3. **Run inference:**
|
| 243 |
+
```bash
|
| 244 |
+
python gaze_inference.py
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
## Optional: Jupyter Notebook Support
|
| 248 |
+
|
| 249 |
+
If you want to experiment with Jupyter notebooks:
|
| 250 |
+
```bash
|
| 251 |
+
pip install jupyter ipykernel
|
| 252 |
+
python -m ipykernel install --user --name gaze_estimation --display-name "Gaze Estimation"
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
## Clean Environment File
|
| 256 |
+
|
| 257 |
+
Create `requirements.txt` for easy reproduction:
|
| 258 |
+
```
|
| 259 |
+
tensorflow==2.13.0
|
| 260 |
+
opencv-python==4.8.1.78
|
| 261 |
+
numpy==1.24.3
|
| 262 |
+
matplotlib==3.7.1
|
| 263 |
+
scikit-learn==1.3.0
|
| 264 |
+
pillow==10.0.0
|
| 265 |
+
```
|
| 266 |
+
|
| 267 |
+
Then others can install with:
|
| 268 |
+
```bash
|
| 269 |
+
conda create -n gaze_estimation python=3.9 -y
|
| 270 |
+
conda activate gaze_estimation
|
| 271 |
+
pip install -r requirements.txt
|
| 272 |
+
```
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
tensorflow==2.18.1
|
| 2 |
+
opencv-python==4.10.0.84
|
| 3 |
+
numpy==1.26.4
|
| 4 |
+
matplotlib==3.9.2
|
| 5 |
+
scikit-learn==1.5.2
|
| 6 |
+
pillow==10.4.0
|