#!/usr/bin/env python3 from __future__ import annotations import argparse import json import os from pathlib import Path from typing import Any try: from dotenv import load_dotenv as _load_dotenv except ImportError: _load_dotenv = None def submission_root() -> Path: return Path(__file__).resolve().parents[1] def _embed_local(texts: list[str], model_name: str, batch_size: int) -> list[list[float]]: try: from sentence_transformers import SentenceTransformer except ImportError as e: raise SystemExit("Local embeddings need: pip install sentence-transformers") from e st = SentenceTransformer(model_name) arr = st.encode( texts, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=False, show_progress_bar=len(texts) > batch_size, ) return [row.astype(float).tolist() for row in arr] def main() -> None: root = submission_root() if _load_dotenv is not None: _load_dotenv(root / ".env") parser = argparse.ArgumentParser() parser.add_argument("--input", type=Path, default=root / "data" / "business_catalog.jsonl") parser.add_argument("--output", type=Path, default=root / "data" / "business_catalog_embedded.jsonl") parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--max-rows", type=int, default=None) parser.add_argument( "--model", type=str, default=os.environ.get("TASK_B_LOCAL_EMBEDDING_MODEL", "all-MiniLM-L6-v2"), help="sentence-transformers model id (match runtime TASK_B_LOCAL_EMBEDDING_MODEL).", ) args = parser.parse_args() if not args.input.is_file(): raise SystemExit(f"Missing {args.input} — run scripts/build_business_catalog.py first.") rows: list[dict[str, Any]] = [] with args.input.open(encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue rows.append(json.loads(line)) if args.max_rows is not None and len(rows) >= args.max_rows: break texts = [r["text_for_embedding"] for r in rows] embeddings = _embed_local(texts, args.model, args.batch_size) args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w", encoding="utf-8") as fout: for row, emb in zip(rows, embeddings, strict=True): row_out = {**row, "embedding": emb} fout.write(json.dumps(row_out, ensure_ascii=False) + "\n") print(f"Wrote {len(rows)} embedded rows -> {args.output} (dim={len(embeddings[0])})") if __name__ == "__main__": main()