Spaces:
Sleeping
Sleeping
File size: 1,524 Bytes
1dc2504 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | from __future__ import annotations
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class EyeSequenceDataset(Dataset):
    """Dataset of eye sequences stored as ``.npz`` archives.

    Rows of a metadata CSV are filtered by their ``split`` column. Each kept
    row is resolved in one of two ways:

    * if the row has a non-empty ``npz_path``, that single file is one sample;
    * otherwise (legacy layout from ``extract_eye_sequences.py``) every
      ``*.npz`` file inside the row's ``sequence_dir`` becomes a sample,
      in sorted path order.

    Every sample produced by a row carries that row's integer ``label``.
    """

    def __init__(self, metadata_csv: str, split: str) -> None:
        """Index the samples that belong to ``split``.

        Args:
            metadata_csv: Path to a CSV with ``split`` and ``label`` columns
                plus ``npz_path`` and/or ``sequence_dir``.
            split: Value of the ``split`` column to keep.

        Raises:
            ValueError: If a kept row has neither a non-empty ``npz_path``
                nor a non-empty ``sequence_dir``.
        """
        # Each entry: {"path": <str path to an .npz file>, "label": <int>}.
        self.samples: List[Dict[str, object]] = []
        df = pd.read_csv(metadata_csv)
        df = df[df["split"] == split]
        for row in df.to_dict(orient="records"):
            label = int(row["label"])
            # to_dict(orient="records") gives every row every column, so in a
            # CSV mixing both layouts a legacy row has NaN here rather than a
            # missing key; test the value, not key membership.
            npz_path = row.get("npz_path")
            if pd.notna(npz_path):
                self.samples.append({"path": str(npz_path), "label": label})
                continue
            # Legacy layout from extract_eye_sequences.py
            seq_dir = row.get("sequence_dir")
            if pd.isna(seq_dir):
                raise ValueError(
                    "metadata row has neither 'npz_path' nor 'sequence_dir': "
                    f"{row}"
                )
            for npz in sorted(Path(seq_dir).glob("*.npz")):
                self.samples.append({"path": str(npz), "label": label})

    def __len__(self) -> int:
        """Return the number of indexed ``.npz`` samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load sample ``idx`` and return it as tensors.

        Returns:
            A dict with ``frames`` (float32, channel axis moved to position 1
            and values divided by 255), ``ear`` and ``blink`` (both the same
            float32 values, read from the archive's ``ear`` array or, if
            absent, its ``blink`` array) and ``label`` (int64 scalar).
        """
        sample = self.samples[idx]
        # An NpzFile holds its archive open until closed; the context manager
        # releases the handle once the arrays have been read into memory.
        with np.load(sample["path"]) as obj:
            # presumably uint8 pixel data — dividing by 255 maps it to [0, 1]
            frames = obj["frames"].astype(np.float32) / 255.0
            # Older archives store this signal under "blink" instead of "ear".
            ear_key = "ear" if "ear" in obj.files else "blink"
            ear = obj[ear_key].astype(np.float32)
        # T,H,W,C -> T,C,H,W
        frames = np.transpose(frames, (0, 3, 1, 2))
        return {
            "frames": torch.tensor(frames),
            "ear": torch.tensor(ear),
            # Same values exposed under the legacy key for older consumers.
            "blink": torch.tensor(ear),
            "label": torch.tensor(sample["label"], dtype=torch.long),
        }
|