File size: 2,132 Bytes
b308b74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env bash
# Prelude for the memory-store build: strict shell options, project paths, and
# index parameters that callers may override through the environment.
set -o errexit -o nounset -o pipefail

# ROOT_DIR is the parent of the directory that contains this script.
SCRIPT_DIR="$(dirname "${BASH_SOURCE[0]}")"
ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"

# Fixed locations underneath ROOT_DIR.
STORE_DIR="${ROOT_DIR}/memory_store"
CKPT_DIR="${ROOT_DIR}/ckpts/ssl_pred_moe_lora"
DATA_DIR="${ROOT_DIR}/data/processed"

# Defaults apply only when the variable is unset or empty.
: "${DIM:=768}"
: "${NLIST:=4096}"
: "${MAX_TOKS:=512}"

# Make sure the output directory exists (no error if it already does).
mkdir -p "${STORE_DIR}"

# Embed every *.txt file under DATA_DIR with the checkpoint in CKPT_DIR, then
# write a FAISS IVF index and a parquet metadata table into STORE_DIR.
# The Python program is read from stdin; the quoted 'PY' delimiter prevents any
# shell expansion inside it, so every input is handed over through argv.
python - "$CKPT_DIR" "$DATA_DIR" "$STORE_DIR" "$DIM" "$NLIST" "$MAX_TOKS" <<'PY'
"""Build the episodic memory store.

argv: ckpt data_dir store_dir dim nlist max_toks

For each non-blank ``*.txt`` file in ``data_dir`` the text is tokenized
(truncated to ``max_toks`` tokens), run through the model, and the hidden
state at sequence position 0 is L2-normalised and collected. The vectors go
into an inner-product IVF index (``episodic.faiss``); one metadata row per
vector (source file name and character length) goes into ``meta.parquet``,
in the same order the vectors are added to the index.

Exits with a message (non-zero status) when there is nothing to index, when
``dim``/``nlist`` are not positive, or when the embedding width differs
from ``dim``.
"""
import glob
import os
import sys

import faiss
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from transformers import AutoModel, AutoTokenizer

ckpt, data_dir, store_dir = sys.argv[1:4]
dim, nlist, max_toks = (int(arg) for arg in sys.argv[4:7])
if dim < 1 or nlist < 1:
    raise SystemExit(f"DIM and NLIST must be positive (got DIM={dim}, NLIST={nlist})")

# Check for input before paying the cost of loading the model.
files = sorted(glob.glob(os.path.join(data_dir, '*.txt')))
if not files:
    raise SystemExit(f"No .txt files in {data_dir}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModel.from_pretrained(ckpt).eval().to(device)
tok = AutoTokenizer.from_pretrained(ckpt)

vecs = []
rows = []
with torch.no_grad():
    for fn in files:
        # Context manager so the handle is closed even when many files are read.
        with open(fn, 'r', encoding='utf-8', errors='ignore') as fh:
            text = fh.read()
        if not text.strip():
            continue
        enc = tok(text, return_tensors='pt', truncation=True, max_length=max_toks)
        enc = {k: v.to(device) for k, v in enc.items()}
        first_tok = model(**enc).last_hidden_state[:, 0, :]
        # FAISS works on contiguous float32 arrays; normalize_L2 edits in place.
        h = np.ascontiguousarray(first_tok.detach().float().cpu().numpy(), dtype=np.float32)
        if h.shape[1] != dim:
            # Fail on the first embedding instead of after embedding everything.
            raise SystemExit(
                f"Embedding width {h.shape[1]} does not match DIM={dim}; "
                f"set DIM={h.shape[1]} for this checkpoint"
            )
        faiss.normalize_L2(h)
        vecs.append(h)
        rows.extend(
            {'file': os.path.basename(fn), 'len': len(text)}
            for _ in range(h.shape[0])
        )

if not vecs:
    raise SystemExit(f"All .txt files in {data_dir} are empty; nothing to index")

X = np.concatenate(vecs, axis=0)
n_vecs = X.shape[0]
if nlist > n_vecs:
    # IVF training needs at least as many training vectors as clusters.
    print(
        f"NLIST={nlist} exceeds the {n_vecs} available vectors; using nlist={n_vecs}",
        file=sys.stderr,
    )
    nlist = n_vecs

# IVF index (built only now, because nlist depends on the vector count).
quantizer = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
index.train(X)
index.add(X)
faiss.write_index(index, os.path.join(store_dir, 'episodic.faiss'))

meta = pd.DataFrame(rows, columns=['file', 'len'])
pq.write_table(pa.Table.from_pandas(meta), os.path.join(store_dir, 'meta.parquet'))
print(f"Built FAISS index with {X.shape[0]} vectors → {store_dir}")
PY

echo "Memory built → $STORE_DIR/episodic.faiss"