github-actions
Sync from GitHub 2025-12-17T12:18:53Z
5a3b322
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import List, Tuple
import numpy as np
import pandas as pd
from tqdm import tqdm
from data.catalog_loader import make_assessment_id
from models.embedding_model import EmbeddingModel
def generate_embeddings(catalog_path: str, model_name: str, batch_size: int = 32, output_dir: str = "data/embeddings") -> Tuple[np.ndarray, List[str]]:
    """Embed every document in a catalog and persist the vectors to disk.

    The catalog is read as JSON Lines when the path ends in ``.jsonl`` and as
    Parquet otherwise. Rows are ordered by ``assessment_id`` so that row ``i``
    of the returned/saved matrix corresponds to ``ids[i]``.

    Three files are written to ``output_dir`` (created if missing):
    ``embeddings.npy`` (float32 matrix), ``assessment_ids.json`` (the row
    order) and ``generation_log.json`` (run metadata).

    Args:
        catalog_path: Path to the catalog. It must contain a ``doc_text``
            column and either ``assessment_id`` or ``url`` (from which the id
            is derived via ``make_assessment_id``). A ``pathlib.Path`` is also
            accepted.
        model_name: Name passed to ``EmbeddingModel``.
        batch_size: Number of documents per ``encode`` call; must be positive.
        output_dir: Directory that receives the output files.

    Returns:
        ``(embeddings, ids)``: the stacked float32 embedding matrix (one row per
        document) and the assessment ids in matching row order.

    Raises:
        ValueError: If ``batch_size`` is not positive or the catalog has no rows.
        KeyError: If ``doc_text`` is missing, or ``assessment_id`` is missing
            and cannot be derived because ``url`` is missing too.
    """
    # Validate cheap arguments before reading the catalog or loading a model.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    # Normalise to str: supports Path inputs for the suffix check and keeps
    # the value JSON-serializable for the generation log below.
    catalog_path = str(catalog_path)
    df = pd.read_json(catalog_path, lines=True) if catalog_path.endswith(".jsonl") else pd.read_parquet(catalog_path)

    if "assessment_id" not in df.columns:
        if "url" in df.columns:
            df["assessment_id"] = df["url"].apply(make_assessment_id)
        else:
            raise KeyError("assessment_id not found and url missing to derive it.")
    if "doc_text" not in df.columns:
        raise KeyError("doc_text not found in catalog; an enriched catalog is required.")

    # Stable sort: rows sharing an assessment_id keep their file order, so the
    # id list and the embedding rows are reproducible between runs.
    df = df.sort_values("assessment_id", kind="stable")
    texts = df["doc_text"].tolist()
    ids = df["assessment_id"].tolist()
    if not texts:
        # np.vstack([]) would otherwise fail with an opaque error further down.
        raise ValueError(f"Catalog {catalog_path!r} contains no documents to embed.")

    model = EmbeddingModel(model_name)
    embeddings: List[np.ndarray] = []
    # Timing covers encoding only (model loading and disk writes are excluded)
    # so avg_time_per_doc_ms reflects per-document embedding cost.
    start = time.perf_counter()
    for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        batch = texts[i : i + batch_size]
        embeds = model.encode(batch, normalize=True, batch_size=batch_size, is_query=False)
        embeddings.append(embeds)
    embeddings_arr = np.vstack(embeddings).astype(np.float32)
    total_time = time.perf_counter() - start

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / "embeddings.npy", embeddings_arr)
    with open(out_dir / "assessment_ids.json", "w", encoding="utf-8") as f:
        json.dump(ids, f, indent=2)

    log = {
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "model_name": model_name,
        "num_documents": len(texts),
        "embedding_dim": embeddings_arr.shape[1],
        "batch_size": batch_size,
        "total_time_seconds": total_time,
        # texts is guaranteed non-empty by the check above.
        "avg_time_per_doc_ms": total_time / len(texts) * 1000,
        "normalized": True,
        "catalog_path": catalog_path,
    }
    with open(out_dir / "generation_log.json", "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)
    return embeddings_arr, ids
if __name__ == "__main__":
    # Command-line entry point: parse options and run a single embedding job.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--catalog", required=True, help="Enriched catalog with doc_text")
    cli.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    cli.add_argument("--batch-size", type=int, default=32)
    cli.add_argument("--output-dir", default="data/embeddings")
    opts = cli.parse_args()

    generate_embeddings(
        opts.catalog,
        opts.model,
        batch_size=opts.batch_size,
        output_dir=opts.output_dir,
    )