File size: 5,932 Bytes
7ea1851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#!/usr/bin/env python3
"""
Re-evaluate all stored face detections against the current person reference set.

This is the backfill step that makes newly curated references pay off across
already processed videos without rescanning the media files.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any

import numpy as np


# ROOT_DIR is the directory two levels above this file, i.e. the parent of the
# directory that contains this script.
ROOT_DIR = Path(__file__).resolve().parents[1]
# BACKEND_DIR is the "backend" directory directly under ROOT_DIR.
BACKEND_DIR = ROOT_DIR / "backend"
# Prepend BACKEND_DIR to the import search path (once) before the project
# imports that follow; presumably face_search and utils live there -- confirm.
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

from face_search import MATCH_THRESHOLD, compare_embeddings, get_face_search_instance  # noqa: E402
from utils import log_message  # noqa: E402


def parse_embedding(raw_value: Any) -> np.ndarray | None:
    """Convert a stored embedding value into a NumPy array.

    Supported inputs:
        * ``None`` -> ``None``
        * ``np.ndarray`` -> returned unchanged (same object)
        * ``bytes`` / ``bytearray`` / ``memoryview`` -> interpreted as a raw
          float32 buffer
        * ``str`` -> decoded as JSON and converted to a float32 array
        * ``list`` -> converted to a float32 array

    Any other type yields ``None``.
    """
    # Pass-through cases: nothing to decode.
    if raw_value is None or isinstance(raw_value, np.ndarray):
        return raw_value

    # A memoryview is first materialised as bytes, then handled like bytes.
    payload = raw_value.tobytes() if isinstance(raw_value, memoryview) else raw_value

    if isinstance(payload, (bytes, bytearray)):
        return np.frombuffer(payload, dtype=np.float32)

    if isinstance(payload, str):
        decoded = json.loads(payload)
        return np.array(decoded, dtype=np.float32)

    if isinstance(payload, list):
        return np.array(payload, dtype=np.float32)

    # Unsupported type.
    return None


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    Options:
        --threshold: float; defaults to ``MATCH_THRESHOLD``.
        --dry-run: boolean flag, off by default.
        --source-type: repeatable; each value is appended to ``source_types``
            (``None`` when the option is never given).
        --report: output path; defaults to
            ``docs/reports/face-rematch-report-v1.json`` under ``ROOT_DIR``.
    """
    default_report = ROOT_DIR / "docs" / "reports" / "face-rematch-report-v1.json"

    cli = argparse.ArgumentParser(
        description="Rematch stored face detections against current person references."
    )
    cli.add_argument("--threshold", default=MATCH_THRESHOLD, type=float)
    cli.add_argument("--dry-run", action="store_true")
    cli.add_argument(
        "--source-type",
        action="append",
        dest="source_types",
        help="Limit person reference averages to these source types. Repeat to allow multiple.",
    )
    cli.add_argument("--report", default=default_report, type=Path)
    return cli


def get_filtered_person_average_embedding(face_search, person_id: int, source_types: list[str] | None) -> np.ndarray | None:
    """Compute the normalised mean reference embedding for one person.

    Reads the ``embedding_json`` column of ``person_embeddings`` for
    ``person_id`` through a connection obtained from
    ``face_search._get_db_connection()``. When ``source_types`` is non-empty,
    only rows whose ``source_type`` is in that list are used.

    Returns:
        The mean of all parseable embeddings divided by its L2 norm, the
        un-normalised mean if that norm is zero, or ``None`` when no usable
        embedding exists.
    """
    db = face_search._get_db_connection()
    try:
        if source_types:
            # One "?" per requested source type keeps the query parameterised.
            marks = ",".join("?" for _ in source_types)
            cursor = db.execute(
                f"""
                SELECT embedding_json
                FROM person_embeddings
                WHERE person_id = ? AND source_type IN ({marks})
                """,
                [person_id, *source_types],
            )
        else:
            cursor = db.execute(
                """
                SELECT embedding_json
                FROM person_embeddings
                WHERE person_id = ?
                """,
                (person_id,),
            )
        records = cursor.fetchall()
    finally:
        # Release the connection whether or not the query succeeded.
        db.close()

    # Keep only the rows whose stored value could be decoded.
    vectors = []
    for record in records:
        vector = parse_embedding(record[0])
        if vector is not None:
            vectors.append(vector)

    if not vectors:
        return None

    mean_vector = np.mean(vectors, axis=0)
    length = np.linalg.norm(mean_vector)
    # Avoid dividing by zero for an all-zero mean.
    return mean_vector if length == 0 else mean_vector / length


def _find_best_match(
    embedding: np.ndarray,
    person_averages: dict[int, np.ndarray],
    threshold: float,
) -> tuple[int, float]:
    """Return ``(person_id, similarity)`` of the best match for one face.

    Returns ``(0, 0.0)`` when no person qualifies. A candidate must reach
    ``threshold`` and also beat the running best, which starts at 0.0, so a
    similarity that is not strictly positive never counts as a match.
    """
    best_person_id = 0
    best_similarity = 0.0
    for person_id, avg in person_averages.items():
        similarity = compare_embeddings(embedding, avg)
        if similarity > best_similarity and similarity >= threshold:
            best_similarity = similarity
            best_person_id = person_id
    return best_person_id, best_similarity


def main() -> int:
    """Rematch every stored face embedding and write a JSON report.

    Steps:
        1. Parse the command line and build one average embedding per person
           (optionally restricted to the given source types).
        2. Load all rows of ``face_embeddings``.
        3. For each face, find the best-matching person at or above the
           threshold; person id 0 stands for "no match".
        4. When the result differs from the stored person id, record an
           update and, unless ``--dry-run`` is set, call
           ``face_search.reassign_face``.
        5. Write the report to ``args.report`` and log a summary.

    Returns:
        0, used as the process exit status.
    """
    args = build_parser().parse_args()
    face_search = get_face_search_instance()
    people = face_search.list_persons()

    # Persons without any usable reference embedding are left out entirely.
    person_averages: dict[int, np.ndarray] = {}
    for person in people:
        avg = get_filtered_person_average_embedding(face_search, person["id"], args.source_types)
        if avg is not None:
            person_averages[person["id"]] = avg

    query = """
        SELECT natural_key, frame_number, face_index, person_id, embedding
        FROM face_embeddings
        ORDER BY natural_key, frame_number, face_index
        """
    conn = face_search._get_db_connection()
    try:
        rows = conn.execute(query).fetchall()
    finally:
        # Close the connection even if the query fails, matching
        # get_filtered_person_average_embedding.
        conn.close()

    updates: list[dict[str, Any]] = []
    unchanged = 0

    for natural_key, frame_number, face_index, current_person_id, embedding_raw in rows:
        embedding = parse_embedding(embedding_raw)
        if embedding is None:
            # Faces without a decodable embedding cannot be rematched.
            unchanged += 1
            continue

        best_person_id, best_similarity = _find_best_match(embedding, person_averages, args.threshold)

        # A NULL/None stored person id is treated the same as 0 (unidentified).
        current_person_value = int(current_person_id or 0)
        if best_person_id == current_person_value:
            unchanged += 1
            continue

        frame = int(frame_number)
        index = int(face_index)
        updates.append(
            {
                "natural_key": natural_key,
                "frame_number": frame,
                "face_index": index,
                "old_person_id": current_person_value,
                "new_person_id": best_person_id,
                "similarity": round(float(best_similarity), 4),
            }
        )
        if not args.dry_run:
            # best_person_id == 0 is passed on as None (no person).
            face_search.reassign_face(natural_key, frame, index, best_person_id or None)

    report = {
        "threshold": args.threshold,
        "dry_run": args.dry_run,
        "source_types": args.source_types or [],
        "persons_considered": len(person_averages),
        "total_faces": len(rows),
        "unchanged": unchanged,
        "updated": len(updates),
        "to_identified": sum(1 for item in updates if item["new_person_id"] > 0),
        "to_unidentified": sum(1 for item in updates if item["new_person_id"] == 0),
        "updates": updates,
    }
    args.report.parent.mkdir(parents=True, exist_ok=True)
    args.report.write_text(json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8")
    log_message(f"Rematch finished: {len(updates)} updates, {unchanged} unchanged")
    print(f"Wrote report to {args.report}")
    return 0


# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())