File size: 4,801 Bytes
850c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import glob
import sys
from pathlib import Path

from qdrant_client import models

# Project root: the parent of the directory that holds this script. Prepending
# it to sys.path lets the `app.*` imports below resolve when this file is run
# directly as a script (which is why those imports carry `noqa: E402`).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from app.config import get_settings  # noqa: E402
from app.embedder import ImageEmbedder  # noqa: E402
from app.manifest import Manifest  # noqa: E402
from app.qdrant_store import QdrantStore, point_id  # noqa: E402


def payload_base(manifest: Manifest, scenario: dict, item: dict, crop: str, memory_type: str) -> dict:
    """Build the Qdrant payload shared by region and object points.

    Args:
        manifest: Supplies the scenario's cameras and converts asset paths to URLs.
        scenario: Scenario record; only ``scenario["id"]`` is read.
        item: Region/object record. ``camera_id`` and ``asset_id`` are required;
            every other key is optional and falls back to the matching camera's
            values or to the fixed defaults below.
        crop: Asset path of the crop image, stored as ``crop_url``.
        memory_type: Stored verbatim under ``memory_type``.

    Returns:
        A new payload dict.

    Raises:
        KeyError: If no camera of the scenario has ``item["camera_id"]``, or if a
            required key (or a camera fallback that is actually needed) is missing.
    """
    scenario_id = scenario["id"]
    camera_id = item["camera_id"]
    # Give next() a default so "camera not found" surfaces as a clear KeyError
    # instead of a bare StopIteration escaping from this function.
    camera = next((cam for cam in manifest.cameras(scenario_id) if cam["id"] == camera_id), None)
    if camera is None:
        raise KeyError(f"camera {camera_id!r} not found in scenario {scenario_id!r}")

    # Camera fallbacks are read only when the item lacks its own value, so a
    # camera without e.g. "zone" or "baseline_frame" is fine when it is unused.
    zone = item["zone"] if "zone" in item else camera["zone"]
    if "frame" in item:
        frame = item["frame"]
    elif "incident_frame" in camera:
        frame = camera["incident_frame"]
    else:
        frame = camera["baseline_frame"]

    return {
        "scenario": scenario_id,
        "camera_id": camera_id,
        "zone": zone,
        # Fixed fallback when the item carries no timestamp (presumably Unix seconds).
        "timestamp": int(item.get("timestamp", 1716900000)),
        "frame_url": manifest.asset_url(frame),
        "crop_url": manifest.asset_url(crop),
        "bbox": item.get("bbox", [0, 0, 0, 0]),
        "asset_id": item["asset_id"],
        "memory_type": memory_type,
        "is_baseline": bool(item.get("is_baseline", False)),
        "is_incident": bool(item.get("is_incident", False)),
        "object_label": item.get("object_label", "region"),
        "track_id": item.get("track_id"),
        "region_id": item.get("region_id"),
        "region_label": item.get("region_label"),
        "alarm_enabled": item.get("alarm_enabled"),
    }


def region_points(manifest: Manifest, embedder: ImageEmbedder, scenario: dict):
    """Embed the baseline and incident crop of every region in ``scenario``.

    Each region yields a ``<region_id>_baseline`` and a ``<region_id>_incident``
    point. The baseline point is immediately followed by one extra point per
    crop returned by ``baseline_variants``. Those extra points reuse the
    baseline payload but carry their own ``asset_id`` (``..._v<N>``),
    ``crop_url`` and ``memory_variant``.

    Returns:
        List of ``models.PointStruct`` in region order.
    """
    sid = scenario["id"]
    collected = []
    for region in manifest.regions(sid):
        cam = next(c for c in manifest.cameras(sid) if c["id"] == region["camera_id"])
        # Both crop paths are read up front, so a missing key fails before any embedding.
        crops_by_state = {
            "baseline": region["baseline_crop"],
            "incident": region["incident_crop"],
        }
        for state, crop_path in crops_by_state.items():
            is_base = state == "baseline"
            asset_id = f"{region['region_id']}_{state}"
            item = dict(region)
            item.update(
                asset_id=asset_id,
                frame=cam[f"{state}_frame"],
                is_baseline=is_base,
                is_incident=not is_base,
                object_label=region.get(f"{state}_label", region["region_id"]),
            )
            payload = payload_base(manifest, scenario, item, crop_path, f"region_{state}")
            embedding = embedder.embed_path(manifest.asset_path(crop_path))
            collected.append(models.PointStruct(id=point_id(f"{sid}:{asset_id}"), vector=embedding, payload=payload))
            if not is_base:
                continue
            # Extra baseline crops become additional memories of the same region.
            extra_crops = baseline_variants(manifest, crop_path, region.get("variant_prefix"))
            for n, extra_crop in enumerate(extra_crops, start=1):
                vid = f"{asset_id}_v{n}"
                extra_payload = dict(payload)
                extra_payload["asset_id"] = vid
                extra_payload["crop_url"] = manifest.asset_url(extra_crop)
                extra_payload["memory_variant"] = n
                extra_embedding = embedder.embed_path(manifest.asset_path(extra_crop))
                collected.append(
                    models.PointStruct(id=point_id(f"{sid}:{vid}"), vector=extra_embedding, payload=extra_payload)
                )
    return collected


def baseline_variants(manifest: Manifest, crop: str, variant_prefix: str | None = None) -> list[str]:
    """Return asset paths of the extra baseline crops stored next to ``crop``.

    Variants live in a ``variants`` directory beside the crop (under the
    manifest's resolved asset root). They are named ``<stem>_v<suffix>.jpg``,
    where ``<stem>`` is ``variant_prefix`` or, if that is falsy, the crop's
    file stem.

    Results are ordered by their numeric suffix, so ``_v10`` sorts after
    ``_v9``; plain string sorting would place it right after ``_v1``. Callers
    number variants by position, so this keeps those numbers aligned with the
    file names. Matches whose suffix is not an integer come last, ordered by name.

    Args:
        manifest: Provides ``settings.resolved_asset_root``.
        crop: Asset-relative path of the baseline crop.
        variant_prefix: Optional override for the variant file-name stem.

    Returns:
        Asset-relative POSIX paths ``"<crop dir>/variants/<file>"``; an empty
        list when the variants directory does not exist.
    """
    path = Path(crop)
    stem = variant_prefix or path.stem
    variant_dir = manifest.settings.resolved_asset_root / path.parent / "variants"
    if not variant_dir.is_dir():
        return []

    # Length of "<stem>_v" within a matched file's stem.
    prefix_len = len(stem) + 2

    def _order(variant: Path) -> tuple[int, int, str]:
        # Sort by the integer after "_v"; non-numeric suffixes go last, by name.
        suffix = variant.stem[prefix_len:]
        if suffix.isascii() and suffix.isdigit():
            return (0, int(suffix), variant.name)
        return (1, 0, variant.name)

    # Escape the stem so "[", "*" or "?" in a file name are matched literally.
    matches = variant_dir.glob(f"{glob.escape(stem)}_v*.jpg")
    parent = path.parent.as_posix()
    return [f"{parent}/variants/{variant.name}" for variant in sorted(matches, key=_order)]


def object_points(manifest: Manifest, embedder: ImageEmbedder, scenario: dict):
    """Embed every standalone object crop of ``scenario`` as a Qdrant point.

    Each object record supplies its own ``crop``, ``memory_type`` and
    ``asset_id``. The point id is derived from ``"<scenario id>:<asset_id>"``.
    """
    sid = scenario["id"]

    def to_point(obj: dict):
        # Payload first, then the embedding, then the point.
        crop_path = obj["crop"]
        obj_payload = payload_base(manifest, scenario, obj, crop_path, obj["memory_type"])
        embedding = embedder.embed_path(manifest.asset_path(crop_path))
        return models.PointStruct(id=point_id(f"{sid}:{obj['asset_id']}"), vector=embedding, payload=obj_payload)

    return [to_point(obj) for obj in manifest.objects(sid)]


def main() -> int:
    """Rebuild the Qdrant collection from the manifest and print a summary.

    Every point is embedded *before* the collection is recreated. If embedding
    fails (e.g. an unreadable crop), the script stops without touching the
    existing collection, instead of leaving it freshly recreated and empty.

    Returns:
        Process exit status (0 on success).
    """
    settings = get_settings()
    manifest = Manifest(settings)
    embedder = ImageEmbedder(settings)
    store = QdrantStore(settings)

    points = []
    for scenario in manifest.scenarios:
        points.extend(region_points(manifest, embedder, scenario))
        points.extend(object_points(manifest, embedder, scenario))

    # Recreate (presumably drop and re-create) the collection only once all
    # points are ready, to keep the window without data as short as possible.
    store.ensure_collection(recreate=True)
    store.upsert(points)
    print(f"Seeded {len(points)} points into {settings.collection_name} using {embedder.mode}.")
    return 0


if __name__ == "__main__":
    # Run the seeder and propagate main()'s return value as the exit status.
    sys.exit(main())