File size: 2,903 Bytes
1dc2504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations

import os
import sys
from pathlib import Path

import numpy as np
import torch

# Allow imports from project root
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data.eye_extract import extract_sequences_from_video_bytes
from src.models.lrcn_vit import LRCNViT
from src.utils import load_yaml

_BLINK_THRESHOLD = 0.2
_BLINK_MIN_FRAMES = 2


def _count_blink_rate(ear_sequences: list[np.ndarray], fps: float = 25.0) -> float:
    """Estimate blinks per second from EAR signal across all sequences."""
    if not ear_sequences:
        return 0.0
    ear = np.concatenate(ear_sequences)
    blink = ear < _BLINK_THRESHOLD
    transitions = np.diff(blink.astype(int))
    blink_starts = int((transitions == 1).sum())
    total_seconds = len(ear) / fps
    return round(blink_starts / max(total_seconds, 1e-6), 2)


def load_model(checkpoint_path: str, config_path: str, device: str = "cpu") -> LRCNViT:
    """Build an ``LRCNViT`` from a YAML config and load checkpoint weights.

    Args:
        checkpoint_path: Path handed to ``torch.load``; its result is passed
            straight to ``load_state_dict``.
        config_path: Path of the YAML file read with ``load_yaml``. The
            ``model`` section supplies the architecture settings and
            ``data.image_size`` supplies the image size.
        device: Used both as ``map_location`` when loading the checkpoint
            and as the device the model is moved to.

    Returns:
        The model, switched to eval mode and moved to ``device``.
    """
    config = load_yaml(config_path)
    model_cfg = config["model"]

    network = LRCNViT(
        backbone_name=model_cfg["backbone"],
        # Pretrained backbone weights are not requested here; the state dict
        # loaded below is what populates the model.
        backbone_pretrained=False,
        lstm_hidden=model_cfg["lstm_hidden"],
        lstm_layers=model_cfg["lstm_layers"],
        dropout=model_cfg["dropout"],
        num_classes=model_cfg["num_classes"],
        # Only this setting is optional in the config; it defaults to True.
        use_blink_head=model_cfg.get("use_blink_head", True),
        image_size=config["data"]["image_size"],
    )

    weights = torch.load(checkpoint_path, map_location=device, weights_only=True)
    network.load_state_dict(weights)
    network.eval()
    return network.to(device)


@torch.no_grad()
def predict_video(
    video_bytes: bytes,
    model: LRCNViT,
    seq_len: int = 8,
    device: str = "cpu",
    *,
    fps: float = 25.0,
) -> dict:
    """Score an uploaded video as REAL or FAKE and estimate its blink rate.

    Sequences are extracted from ``video_bytes``, each one is run through
    ``model`` separately, and the softmax probability of class index 1
    (treated as "fake") is averaged over all sequences.

    Args:
        video_bytes: Raw video content passed to the sequence extractor.
        model: Model called as ``model(frames, ear)``; the first element of
            its return value is used as the logits. It is not moved to
            ``device`` or switched to eval mode here.
        seq_len: Sequence length forwarded to the extractor.
        device: Device the input tensors are moved to.
        fps: Frame rate used to turn the EAR sample count into seconds for
            the blink rate. Defaults to 25.0, the value that was previously
            always used.

    Returns:
        A dict with these keys:
            ``label``: ``"FAKE"`` when the mean score is >= 0.5, ``"REAL"``
                otherwise, or ``"UNKNOWN"`` when no sequences were extracted.
            ``confidence``: mean score, rounded to 4 decimals.
            ``blink_rate``: blinks per second from ``_count_blink_rate``.
            ``frame_scores``: one score per extracted sequence, each rounded
                to 4 decimals.
            ``attention_map_url``: always ``None``.
    """
    sequences = extract_sequences_from_video_bytes(
        video_bytes, label=0, video_id="upload", seq_len=seq_len
    )

    result = {
        "label": "UNKNOWN",
        "confidence": 0.0,
        "blink_rate": 0.0,
        "frame_scores": [],
        "attention_map_url": None,
    }
    if not sequences:
        return result

    frame_scores: list[float] = []
    ear_arrays: list[np.ndarray] = []

    for seq in sequences:
        # Dividing by 255 suggests 8-bit pixel values -- TODO confirm. The
        # permute moves the last axis to position 1 and unsqueeze adds a
        # leading batch axis of size 1 (1,T,C,H,W per the original note).
        scaled = seq["frames"].astype(np.float32) / 255.0
        frames = torch.tensor(scaled).permute(0, 3, 1, 2).unsqueeze(0).to(device)
        ear = torch.tensor(seq["ear"], dtype=torch.float32).unsqueeze(0).to(device)

        logits, _ = model(frames, ear)
        prob_fake = float(torch.softmax(logits, dim=1)[0, 1].item())
        frame_scores.append(prob_fake)
        ear_arrays.append(seq["ear"])

    confidence = float(np.mean(frame_scores))

    result["label"] = "FAKE" if confidence >= 0.5 else "REAL"
    result["confidence"] = round(confidence, 4)
    result["blink_rate"] = _count_blink_rate(ear_arrays, fps=fps)
    result["frame_scores"] = [round(score, 4) for score in frame_scores]
    return result