File size: 2,469 Bytes
06fed2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Backfill embeddings for memories created before Phase 2b.

Walks all memories with `embedding IS NULL`, batch-encodes their
(narrative_text | triggers | relevance_tags) input, and writes the
vectors back. Safe to run multiple times — already-embedded rows are
skipped.

Usage:
    python -m scripts.backfill_embeddings
    python -m scripts.backfill_embeddings --batch-size 32
    python -m scripts.backfill_embeddings --dry-run
"""

from __future__ import annotations

import argparse
import logging
import sys

from sqlalchemy import select

from app.db import SessionLocal, init_db
from app.memory.embeddings import build_memory_embedding_input, embed_texts
from app.models.memory import Memory

logger = logging.getLogger(__name__)


def backfill(batch_size: int = 32, dry_run: bool = False) -> int:
    """Embed every memory whose ``embedding`` is NULL and store the vectors.

    The rows to process are selected once up front and then handled in
    slices of ``batch_size``. Each slice is committed on its own, so batches
    that were already committed stay written if a later batch fails, and a
    re-run only selects rows whose embedding is still NULL.

    Args:
        batch_size: Number of memories passed to one ``embed_texts`` call
            and written per commit. Must be at least 1.
        dry_run: When true, build the embedding inputs and log each batch
            that would be embedded, but do not call ``embed_texts``, assign
            any embedding, or commit.

    Returns:
        The number of memories newly embedded (0 for a dry run).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        RuntimeError: If ``embed_texts`` returns a different number of
            vectors than the number of inputs it was given.
    """
    # Reject bad sizes before doing any work: a step of 0 makes range() raise
    # an unrelated-looking error, and a negative step yields an empty range,
    # which would silently skip every row.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    init_db()

    updated = 0
    with SessionLocal() as session:
        rows = list(
            session.execute(
                select(Memory).where(Memory.embedding.is_(None))
            ).scalars()
        )
        if not rows:
            logger.info("No memories to backfill.")
            return 0

        total = len(rows)
        logger.info("Backfilling %d memories…", total)
        for i in range(0, total, batch_size):
            batch = rows[i : i + batch_size]
            inputs = [
                build_memory_embedding_input(
                    narrative_text=m.narrative_text,
                    triggers=list(m.triggers or []),
                    relevance_tags=list(m.relevance_tags or []),
                )
                for m in batch
            ]
            if dry_run:
                logger.info("[dry-run] would embed batch of %d", len(batch))
                continue

            # Materialize the result so its length can be checked; zip() would
            # otherwise drop unmatched rows without any signal.
            vectors = list(embed_texts(inputs))
            if len(vectors) != len(batch):
                raise RuntimeError(
                    f"embed_texts returned {len(vectors)} vectors "
                    f"for {len(batch)} inputs"
                )
            for m, vec in zip(batch, vectors):
                m.embedding = vec
                updated += 1

            # Commit per batch so finished work is kept if a later batch fails.
            session.commit()
            logger.info("Backfilled %d / %d", i + len(batch), total)

    logger.info("Done. Embedded %d memories.", updated)
    return updated


def _main() -> int:
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    ap = argparse.ArgumentParser()
    ap.add_argument("--batch-size", type=int, default=32)
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()
    backfill(batch_size=args.batch_size, dry_run=args.dry_run)
    return 0


if __name__ == "__main__":
    # Script entry point: run the CLI and use its return value as the
    # interpreter's exit status.
    raise SystemExit(_main())