# Page header captured with the file (not code) -- Spaces: Sleeping; File size: 4,690 Bytes; 2eba0cc
# 9-point gaze calibration for L2CS-Net
# Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
# Centre point is the bias reference (subtracted from all readings).
import numpy as np
from dataclasses import dataclass, field
# Normalised (x, y) screen targets on a 3x3 grid.
# The centre comes first because index 0 is used as the bias reference;
# the remaining eight points follow row by row.
DEFAULT_TARGETS = [
    (0.5, 0.5),                                  # centre (bias reference)
    (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),    # top row
    (0.15, 0.50), (0.85, 0.50),                  # middle row, centre omitted
    (0.15, 0.85), (0.50, 0.85), (0.85, 0.85),    # bottom row
]
@dataclass
class _PointSamples:
target_x: float
target_y: float
yaws: list = field(default_factory=list)
pitches: list = field(default_factory=list)
def _iqr_filter(values):
if len(values) < 4:
return values
arr = np.array(values)
q1, q3 = np.percentile(arr, [25, 75])
iqr = q3 - q1
lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
return arr[(arr >= lo) & (arr <= hi)].tolist()
class GazeCalibration:
    """Multi-point gaze calibration mapping raw angles to screen coordinates.

    Samples of (yaw, pitch) are collected for one target at a time.  ``fit``
    reduces every target to a robust median, subtracts the medians of the
    first target (the bias reference) and solves a second-order polynomial
    least-squares problem from the bias-corrected angles to the normalised
    target coordinates.  ``predict`` applies the fitted mapping and clips the
    result to [0, 1].
    """

    # Length of the vector produced by _poly_features; W must be (6, 2).
    _NUM_FEATURES = 6
    # Minimum number of usable targets required by fit().
    # NOTE(review): five points cannot uniquely determine six coefficients
    # (lstsq then returns the minimum-norm solution).  The threshold is kept
    # as it was for compatibility -- confirm whether 6 is intended.
    _MIN_FIT_POINTS = 5

    def __init__(self, targets=None):
        """Create an unfitted calibration.

        Args:
            targets: iterable of (x, y) pairs; the first one is the bias
                reference.  A falsy value selects DEFAULT_TARGETS.
        """
        # Copy so that neither the caller's list nor the module-level default
        # is shared with this instance.
        self._targets = list(targets) if targets else list(DEFAULT_TARGETS)
        self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
        self._current_idx = 0
        self._fitted = False
        self._W = None  # (6, 2) polynomial weights
        self._yaw_bias = 0.0
        self._pitch_bias = 0.0

    @property
    def num_points(self) -> int:
        """Number of calibration targets."""
        return len(self._targets)

    @property
    def current_index(self) -> int:
        """Index of the target currently being collected."""
        return self._current_idx

    @property
    def current_target(self):
        """The active target, or the last target once collection is complete."""
        if self._current_idx < len(self._targets):
            return self._targets[self._current_idx]
        return self._targets[-1]

    @property
    def is_complete(self) -> bool:
        """True once every target has been advanced past."""
        return self._current_idx >= len(self._targets)

    @property
    def is_fitted(self) -> bool:
        """True when a usable mapping is available."""
        return self._fitted

    def collect_sample(self, yaw_rad, pitch_rad):
        """Record one (yaw, pitch) reading for the current target.

        Readings arriving after the last target are ignored, as are readings
        in which either angle is NaN or infinite; the pair is dropped together
        so the yaw and pitch lists stay the same length.
        """
        if self._current_idx >= len(self._points):
            return
        yaw, pitch = float(yaw_rad), float(pitch_rad)
        if not (np.isfinite(yaw) and np.isfinite(pitch)):
            return
        pt = self._points[self._current_idx]
        pt.yaws.append(yaw)
        pt.pitches.append(pitch)

    def advance(self) -> bool:
        """Move to the next target; return True while targets remain."""
        self._current_idx += 1
        return self._current_idx < len(self._targets)

    @staticmethod
    def _poly_features(yaw, pitch):
        """Return the feature vector [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]."""
        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
                        dtype=np.float64)

    @staticmethod
    def _robust_medians(pt):
        """Return (median yaw, median pitch) of the outlier-filtered samples.

        Returns None when fewer than two yaw or pitch samples survive.
        """
        yaws = _iqr_filter(pt.yaws)
        pitches = _iqr_filter(pt.pitches)
        if len(yaws) < 2 or len(pitches) < 2:
            return None
        return float(np.median(yaws)), float(np.median(pitches))

    def fit(self) -> bool:
        """Fit the polynomial mapping from the collected samples.

        Returns:
            True on success.  False when the bias-reference target (index 0)
            has too few samples, fewer than five targets are usable, or the
            solver fails or yields non-finite weights.  On failure the
            previous weights, biases and fitted flag are left untouched, so
            an earlier successful calibration keeps working.
        """
        medians = [self._robust_medians(pt) for pt in self._points]

        # Bias comes from the first target (the centre in the default grid).
        centre = medians[0] if medians else None
        if centre is None:
            return False
        yaw_bias, pitch_bias = centre

        rows_A, rows_B = [], []
        for pt, med in zip(self._points, medians):
            if med is None:
                continue  # not enough clean samples for this target
            rows_A.append(self._poly_features(med[0] - yaw_bias,
                                              med[1] - pitch_bias))
            rows_B.append([pt.target_x, pt.target_y])
        if len(rows_A) < self._MIN_FIT_POINTS:
            return False

        A = np.array(rows_A, dtype=np.float64)
        B = np.array(rows_B, dtype=np.float64)
        try:
            W = np.linalg.lstsq(A, B, rcond=None)[0]
        except np.linalg.LinAlgError:
            return False
        if not np.all(np.isfinite(W)):
            return False

        # Commit everything together so weights and biases always match.
        self._W = W
        self._yaw_bias = yaw_bias
        self._pitch_bias = pitch_bias
        self._fitted = True
        return True

    def predict(self, yaw_rad, pitch_rad) -> tuple:
        """Map a raw gaze reading to normalised screen coordinates.

        Returns:
            (x, y) clipped to [0, 1], or (0.5, 0.5) while unfitted.
        """
        if not self._fitted or self._W is None:
            return 0.5, 0.5
        feat = self._poly_features(yaw_rad - self._yaw_bias,
                                   pitch_rad - self._pitch_bias)
        xy = feat @ self._W
        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))

    def to_dict(self) -> dict:
        """Return the calibration state as a plain dict.

        Collected samples are not included.
        """
        return {
            # Shallow copy: callers cannot mutate the internal target list.
            "targets": list(self._targets),
            "fitted": self._fitted,
            "current_index": self._current_idx,
            "W": self._W.tolist() if self._W is not None else None,
            "yaw_bias": self._yaw_bias,
            "pitch_bias": self._pitch_bias,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a calibration from a dict produced by ``to_dict``.

        Missing keys fall back to the defaults of a fresh instance.  Weights
        are accepted only when they form a finite (6, 2) matrix; otherwise
        the instance is left unfitted instead of failing later in
        ``predict``.
        """
        cal = cls(targets=d.get("targets", DEFAULT_TARGETS))
        # A negative index would silently address targets from the end.
        cal._current_idx = max(int(d.get("current_index") or 0), 0)
        cal._yaw_bias = float(d.get("yaw_bias") or 0.0)
        cal._pitch_bias = float(d.get("pitch_bias") or 0.0)

        w = d.get("W")
        if w is not None:
            try:
                W = np.array(w, dtype=np.float64)
            except (TypeError, ValueError):
                W = None  # ragged or non-numeric payload
            if (W is not None
                    and W.shape == (cls._NUM_FEATURES, 2)
                    and np.all(np.isfinite(W))):
                cal._W = W

        # "fitted" is only meaningful together with usable weights.
        cal._fitted = bool(d.get("fitted", False)) and cal._W is not None
        return cal
|