| |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
| try: |
| from dotenv import load_dotenv as _load_dotenv |
| except ImportError: |
| _load_dotenv = None |
|
|
|
|
def submission_root() -> Path:
    """Return the grandparent directory of this script's resolved path.

    NOTE(review): presumably the script lives one level below the
    submission root (e.g. a ``scripts/`` directory), so the grandparent is
    the root — confirm against the repository layout.
    """
    this_file = Path(__file__).resolve()
    return this_file.parents[1]
|
|
|
|
| def _embed_local(texts: list[str], model_name: str, batch_size: int) -> list[list[float]]: |
| try: |
| from sentence_transformers import SentenceTransformer |
| except ImportError as e: |
| raise SystemExit("Local embeddings need: pip install sentence-transformers") from e |
|
|
| st = SentenceTransformer(model_name) |
| arr = st.encode( |
| texts, |
| batch_size=batch_size, |
| convert_to_numpy=True, |
| normalize_embeddings=False, |
| show_progress_bar=len(texts) > batch_size, |
| ) |
| return [row.astype(float).tolist() for row in arr] |
|
|
|
|
def _parse_args(root: Path) -> argparse.Namespace:
    """Build the CLI parser (paths default under *root*), parse argv, and validate."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=Path, default=root / "data" / "business_catalog.jsonl")
    parser.add_argument("--output", type=Path, default=root / "data" / "business_catalog_embedded.jsonl")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument(
        "--max-rows",
        type=int,
        default=None,
        help="Embed at most this many non-blank input rows (default: all).",
    )
    parser.add_argument(
        "--model",
        type=str,
        # Read at call time, i.e. after main() has loaded <root>/.env.
        default=os.environ.get("TASK_B_LOCAL_EMBEDDING_MODEL", "all-MiniLM-L6-v2"),
        help="sentence-transformers model id (match runtime TASK_B_LOCAL_EMBEDDING_MODEL).",
    )
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.max_rows is not None and args.max_rows < 0:
        parser.error("--max-rows must be >= 0")
    return args


def _read_rows(path: Path, max_rows: int | None) -> list[dict[str, Any]]:
    """Parse up to *max_rows* non-blank JSON-object lines from the JSONL file at *path*.

    Raises:
        SystemExit: If a line is not valid JSON, or is not an object with a
            ``text_for_embedding`` field. The message includes the line number.
    """
    rows: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            # Check the limit before consuming a row so that --max-rows 0
            # really yields zero rows.
            if max_rows is not None and len(rows) >= max_rows:
                break
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as e:
                raise SystemExit(f"{path}:{lineno}: invalid JSON ({e.msg})") from e
            if not isinstance(row, dict) or "text_for_embedding" not in row:
                raise SystemExit(
                    f"{path}:{lineno}: expected a JSON object with a 'text_for_embedding' field"
                )
            rows.append(row)
    return rows


def _write_rows(path: Path, rows: list[dict[str, Any]], embeddings: list[list[float]]) -> None:
    """Write each row plus its ``embedding`` to *path* as JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as fout:
        # strict=True: fail loudly if the embedder returned a different count.
        for row, emb in zip(rows, embeddings, strict=True):
            fout.write(json.dumps({**row, "embedding": emb}, ensure_ascii=False) + "\n")


def main() -> None:
    """Embed the business catalog's ``text_for_embedding`` fields and write an embedded copy.

    Reads JSON Lines from ``--input`` and embeds each row's
    ``text_for_embedding`` with a local sentence-transformers model
    (``--model``). It then writes every row, with an added ``embedding``
    list, to ``--output`` as JSON Lines. An empty input produces an empty
    output file.

    Raises:
        SystemExit: On invalid arguments, a missing input file, or a
            malformed input row.
    """
    root = submission_root()
    # Load <root>/.env before building the parser so the environment can
    # supply the --model default.
    if _load_dotenv is not None:
        _load_dotenv(root / ".env")
    args = _parse_args(root)

    if not args.input.is_file():
        raise SystemExit(f"Missing {args.input} — run scripts/build_business_catalog.py first.")

    rows = _read_rows(args.input, args.max_rows)
    texts = [r["text_for_embedding"] for r in rows]
    # Avoid loading a model at all when there is nothing to embed.
    embeddings = _embed_local(texts, args.model, args.batch_size) if texts else []

    _write_rows(args.output, rows, embeddings)

    # Guard the dimension lookup: an empty input has no first embedding.
    dim = len(embeddings[0]) if embeddings else 0
    print(f"Wrote {len(rows)} embedded rows -> {args.output} (dim={dim})")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module does not run it.
    main()
|
|