DSN / scripts /embed_catalog.py
nexusbert's picture
Add agent workflow documentation and refactor user modeling and recommendation services
1c181b2
Raw
History Blame Contribute Delete
2.67 kB
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any
try:
from dotenv import load_dotenv as _load_dotenv
except ImportError:
_load_dotenv = None
def submission_root() -> Path:
    """Return the parent of the directory that contains this script.

    The script's own path is resolved first, so symlinks and relative
    invocation paths do not affect the result.
    """
    script_path = Path(__file__).resolve()
    return script_path.parents[1]
def _embed_local(texts: list[str], model_name: str, batch_size: int) -> list[list[float]]:
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise SystemExit("Local embeddings need: pip install sentence-transformers") from e
st = SentenceTransformer(model_name)
arr = st.encode(
texts,
batch_size=batch_size,
convert_to_numpy=True,
normalize_embeddings=False,
show_progress_bar=len(texts) > batch_size,
)
return [row.astype(float).tolist() for row in arr]
def _read_jsonl_rows(path: Path, max_rows: int | None) -> list[dict[str, Any]]:
    """Parse the non-blank lines of a UTF-8 JSONL file into a list of objects.

    Args:
        path: JSONL file to read.
        max_rows: Stop once this many objects have been collected; ``None``
            reads the whole file.

    Returns:
        The parsed objects, in file order.

    Raises:
        SystemExit: If a line is not valid JSON or is not a JSON object; the
            message names the file and the 1-based line number.
    """
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            # Check the cap before parsing so a limit of N yields exactly N rows.
            if max_rows is not None and len(rows) >= max_rows:
                break
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise SystemExit(f"{path}:{lineno}: invalid JSON: {e.msg}") from e
            if not isinstance(obj, dict):
                raise SystemExit(
                    f"{path}:{lineno}: expected a JSON object, got {type(obj).__name__}"
                )
            rows.append(obj)
    return rows


def main() -> None:
    """Embed catalog rows with a local sentence-transformers model.

    Reads JSON objects from ``--input`` (one per non-blank line), encodes each
    row's ``text_for_embedding`` field, and writes every row back to
    ``--output`` with an added ``embedding`` list of floats.

    Raises:
        SystemExit: On invalid arguments, a missing or empty input file, a
            malformed line, or a row without ``text_for_embedding``.
    """
    root = submission_root()
    if _load_dotenv is not None:
        # Loaded before the parser is built so a TASK_B_LOCAL_EMBEDDING_MODEL
        # set in .env can supply the --model default below.
        _load_dotenv(root / ".env")
    parser = argparse.ArgumentParser(
        description="Add sentence-transformers embeddings to business catalog rows."
    )
    parser.add_argument("--input", type=Path, default=root / "data" / "business_catalog.jsonl")
    parser.add_argument(
        "--output", type=Path, default=root / "data" / "business_catalog_embedded.jsonl"
    )
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--max-rows", type=int, default=None)
    parser.add_argument(
        "--model",
        type=str,
        default=os.environ.get("TASK_B_LOCAL_EMBEDDING_MODEL", "all-MiniLM-L6-v2"),
        help="sentence-transformers model id (match runtime TASK_B_LOCAL_EMBEDDING_MODEL).",
    )
    args = parser.parse_args()
    # Reject non-positive values with a usage error before any file reading or
    # model loading; neither a batch size nor a row cap below 1 is meaningful.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.max_rows is not None and args.max_rows < 1:
        parser.error("--max-rows must be a positive integer")
    if not args.input.is_file():
        raise SystemExit(f"Missing {args.input} — run scripts/build_business_catalog.py first.")

    rows = _read_jsonl_rows(args.input, args.max_rows)
    if not rows:
        # Without this guard the dim=... report below would raise IndexError.
        raise SystemExit(f"No rows found in {args.input}; nothing to embed.")

    texts: list[str] = []
    for n, row in enumerate(rows, start=1):
        if "text_for_embedding" not in row:
            raise SystemExit(f"Record {n} in {args.input} has no 'text_for_embedding' field.")
        texts.append(row["text_for_embedding"])

    embeddings = _embed_local(texts, args.model, args.batch_size)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as fout:
        # strict=True: a row/embedding count mismatch raises instead of
        # silently truncating the output.
        for row, emb in zip(rows, embeddings, strict=True):
            fout.write(json.dumps({**row, "embedding": emb}, ensure_ascii=False) + "\n")
    print(f"Wrote {len(rows)} embedded rows -> {args.output} (dim={len(embeddings[0])})")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()