# ragbackend/setup.py
# Commit 95d9f3c: "phase 1 deployment"
import argparse
import json
import os
import sys
import time
import numpy as np
from sentence_transformers import SentenceTransformer
# Command-line arguments
def parse_args(argv=None):
    """Parse the command-line options of the index-building script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers without arguments have always had.

    Returns:
        argparse.Namespace with the attributes ``upload``, ``papers``,
        ``skip_data``, ``chunks_file``, ``only`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description="Build all RAG pipeline indexes.")
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload indexes to hugging face repository",
    )
    parser.add_argument(
        "--papers",
        type=int,
        default=10_000,
        help="Number of arXiv papers to load (default: 10000)",
    )
    parser.add_argument(
        "--skip-data",
        action="store_true",
        help="Skip dataset download; load CHUNKS_PATH directly",
    )
    parser.add_argument(
        "--chunks-file",
        type=str,
        default=None,
        help="Path to a pre-built chunks JSON file",
    )
    parser.add_argument(
        "--only",
        type=str,
        default=None,
        choices=["embeddings", "bm25", "qdrant"],
        help="Rebuild only one index (requires chunks + embeddings to exist)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=512,
        help="Qdrant upsert batch size (default: 512)",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
# Helper functions
def step(msg: str):
    """Print ``msg`` as a banner framed by two rules of 60 ``=`` characters."""
    rule = "=" * 60
    print(f"\n{rule}\n {msg}\n{rule}")
def upload_to_hub(config):
    """Publish the locally built artifacts to a Hugging Face dataset repo.

    Ensures the public dataset repository ``config.HF_REPO_ID`` exists
    (``exist_ok=True``), then uploads, in order: the BM25 index folder,
    the chunks JSON file and the embeddings file. Progress is reported on
    stdout.

    Args:
        config: Object exposing ``HF_REPO_ID``, ``BM25_INDEX_PATH``,
            ``CHUNKS_PATH`` and ``EMBEDDINGS_PATH``.
    """
    from huggingface_hub import HfApi

    hub = HfApi()
    # Every call below targets the same dataset repository.
    target = {"repo_id": config.HF_REPO_ID, "repo_type": "dataset"}

    hub.create_repo(exist_ok=True, private=False, **target)

    print("[INFO] uploading bm25 indexes")
    hub.upload_folder(
        folder_path=config.BM25_INDEX_PATH,
        path_in_repo="bm25_index_10k",
        **target,
    )

    print("[INFO] Uploading chunks(JSON)")
    hub.upload_file(
        path_or_fileobj=config.CHUNKS_PATH,
        path_in_repo="all_chunks_slim.json",
        **target,
    )

    print("[INFO] uploading the embeddings")
    hub.upload_file(
        path_or_fileobj=config.EMBEDDINGS_PATH,
        path_in_repo="chunk_embeddings_10k (1).npy",
        **target,
    )

    print(f"[INFO] Done all files uploaded to {config.HF_REPO_ID}")
def check_env():
    """Warn on stdout when a required environment variable is unset or empty.

    Only prints a warning; it never raises or aborts the run.
    """
    required = ("GROQ_API_KEY",)
    absent = ", ".join(name for name in required if not os.getenv(name))
    if not absent:
        return
    print(f"[WARN] Missing env vars: {absent}")
    print(" The pipeline will fail at query time without these.")
    print(" Set them in a .env file or your shell before running queries.\n")
# Loading and chunking
def build_chunks(args, config) -> list[dict]:
    """Return the chunk list, loading it from disk or building it from arXiv.

    If ``args.chunks_file`` is set, or ``args.skip_data`` is true (in which
    case ``config.CHUNKS_PATH`` is used), the chunks are read from that JSON
    file and returned as-is; nothing is written. Otherwise ``args.papers``
    arXiv papers are loaded and processed, and the result is saved to
    ``config.CHUNKS_PATH``.

    Args:
        args: Parsed CLI options; reads ``chunks_file``, ``skip_data`` and
            ``papers``.
        config: Object exposing ``CHUNKS_PATH``.

    Returns:
        The loaded or freshly built chunks.

    Raises:
        SystemExit: If the chunks file selected for loading does not exist.
    """
    # --chunks-file or --skip-data: load existing JSON
    source = args.chunks_file or (config.CHUNKS_PATH if args.skip_data else None)
    if source:
        if not os.path.exists(source):
            sys.exit(f"[ERROR] Chunks file not found: {source}")
        step(f"Loading chunks from {source}")
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on non-ASCII text inside the chunks.
        with open(source, encoding="utf-8") as f:
            chunks = json.load(f)
        print(f"[INFO] Loaded {len(chunks)} chunks")
        return chunks

    step(f"Downloading arXiv dataset ({args.papers} papers)")
    # Imported lazily so the load-from-file path does not need the dataset
    # tooling to be importable.
    from rag_pipeline.data.load_data import load_arxiv_subset, process_dataset

    data = load_arxiv_subset(n=args.papers)
    chunks = process_dataset(data)

    out = config.CHUNKS_PATH
    with open(out, "w", encoding="utf-8") as f:
        json.dump(chunks, f)
    print(f"[INFO] Saved {len(chunks)} chunks → {out}")
    return chunks
# Embeddings
def build_embeddings(chunks: list[dict], config) -> "np.ndarray":
    """Return embeddings for ``chunks``, reusing the cached file when present.

    If ``config.EMBEDDINGS_PATH`` already exists, it is loaded and returned
    without re-encoding. A warning is printed when the cached array's first
    dimension does not match ``len(chunks)``, since that indicates a stale
    cache. Otherwise each chunk's ``"text"`` is encoded with the
    ``config.BI_ENCODER_MODEL`` sentence-transformer and the result is saved
    to ``config.EMBEDDINGS_PATH``.

    Args:
        chunks: Chunk dicts; each must have a ``"text"`` key.
        config: Object exposing ``EMBEDDINGS_PATH`` and ``BI_ENCODER_MODEL``.

    Returns:
        The cached or freshly computed embeddings array.
    """
    out = config.EMBEDDINGS_PATH
    if os.path.exists(out):
        print(f"[INFO] Embeddings already exist at {out} — skipping.")
        print(" Delete the file and re-run if you want to rebuild.")
        cached = np.load(out)
        # A cache built from a different chunk set would silently pair
        # vectors with the wrong chunks downstream, so at least flag it.
        if cached.shape[:1] != (len(chunks),):
            print(f"[WARN] Cached embeddings have shape {cached.shape} "
                  f"but there are {len(chunks)} chunks; the cache may be stale.")
        return cached

    step(f"Encoding {len(chunks)} chunks with {config.BI_ENCODER_MODEL}")
    model = SentenceTransformer(config.BI_ENCODER_MODEL)
    texts = [c["text"] for c in chunks]
    t0 = time.time()
    embeddings = model.encode(
        texts,
        batch_size=256,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    print(f"[INFO] Encoded in {time.time()-t0:.1f}s shape={embeddings.shape}")
    np.save(out, embeddings)
    print(f"[INFO] Saved embeddings → {out}")
    return embeddings
# BM25 index
def build_bm25(chunks: list[dict], config):
    """Build the BM25 index unless ``config.BM25_INDEX_PATH`` already exists.

    Args:
        chunks: Chunk dicts handed to ``create_bm25``.
        config: Object exposing ``BM25_INDEX_PATH``.

    Returns:
        None. When the index path already exists, nothing is rebuilt.
    """
    out = config.BM25_INDEX_PATH
    if os.path.exists(out):
        print(f"[INFO] BM25 index already exists at {out} — skipping.")
        print(" Delete the folder and re-run if you want to rebuild.")
        return

    step(f"Building BM25 index → {out}")
    from rag_pipeline.retrieval.bm25 import create_bm25

    # NOTE(review): create_bm25 is not given `out`; presumably it writes to
    # the same configured BM25_INDEX_PATH on its own. Confirm in
    # rag_pipeline.retrieval.bm25.
    create_bm25(chunks)
# Qdrant vector DB
def build_qdrant(chunks: list[dict], embeddings, config, batch_size: "int | None" = None):
    """Create the Qdrant collection from ``chunks`` and their ``embeddings``.

    Connects to a Qdrant server on ``localhost:6333`` and delegates the
    indexing to ``create_qdrant_index``.

    Args:
        chunks: Chunk dicts to index.
        embeddings: Embeddings passed alongside ``chunks`` to
            ``create_qdrant_index``.
        config: Object exposing ``COLLECTION_NAME`` (used in the banner).
        batch_size: Upsert batch size. When ``None``, the value comes from
            the module-level ``args`` set by ``main()`` if present, else
            512 (the ``--batch-size`` default).

    Raises:
        SystemExit: If ``qdrant_client`` cannot be imported or the server
            cannot be reached.
    """
    if batch_size is None:
        # Previously this function read the global `args` unconditionally and
        # raised NameError when called before main() had run.
        batch_size = getattr(globals().get("args"), "batch_size", 512)

    step(f"Building Qdrant collection '{config.COLLECTION_NAME}'")
    print("[INFO] Make sure Qdrant is running: docker run -p 6333:6333 qdrant/qdrant")

    try:
        from qdrant_client import QdrantClient
    except ImportError as e:
        # A missing package is not a connection problem; say so explicitly.
        sys.exit(f"[ERROR] qdrant-client is not installed: {e}")

    try:
        client = QdrantClient(host="localhost", port=6333)
        client.get_collections()  # connection check
    except Exception as e:
        sys.exit(f"[ERROR] Cannot connect to Qdrant: {e}\n"
                 f" Start it with: docker run -p 6333:6333 qdrant/qdrant")

    from rag_pipeline.retrieval.semantic import create_qdrant_index

    create_qdrant_index(chunks, embeddings, client, batch_size=batch_size)
def main():
    """Run the setup script according to the command-line options.

    Modes, checked in this order:

    * ``--upload``: upload the existing artifacts and return.
    * ``--only {embeddings,bm25,qdrant}``: load the chunks from
      ``--chunks-file`` (or ``config.CHUNKS_PATH``) and rebuild just that
      index; existing embeddings / BM25 output is deleted first to force a
      rebuild.
    * default: build chunks, embeddings, the BM25 index and the Qdrant
      collection, then print a summary.

    Raises:
        SystemExit: In ``--only`` mode, if the chunks file (or, for
            ``qdrant``, the embeddings file) does not exist.
    """
    # build_qdrant falls back to this module-level name for --batch-size,
    # so the parsed options must stay published as a global.
    global args
    args = parse_args()
    check_env()

    # Make the sibling `rag_pipeline` package importable regardless of the
    # current working directory; abspath avoids inserting '' when __file__
    # is a bare relative filename.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    import rag_pipeline.config as config

    if args.upload:
        upload_to_hub(config)
        return

    if args.only:
        # Honour --chunks-file here as well; it was previously ignored.
        chunks_source = args.chunks_file or config.CHUNKS_PATH
        if not os.path.exists(chunks_source):
            sys.exit(f"[ERROR] Chunks file not found: {chunks_source}")
        with open(chunks_source, encoding="utf-8") as f:
            chunks = json.load(f)

        if args.only == "embeddings":
            # Force rebuild by removing existing file
            if os.path.exists(config.EMBEDDINGS_PATH):
                os.remove(config.EMBEDDINGS_PATH)
            build_embeddings(chunks, config)
        elif args.only == "bm25":
            import shutil

            if os.path.exists(config.BM25_INDEX_PATH):
                shutil.rmtree(config.BM25_INDEX_PATH)
            build_bm25(chunks, config)
        elif args.only == "qdrant":
            if not os.path.exists(config.EMBEDDINGS_PATH):
                sys.exit(f"[ERROR] Embeddings file not found: {config.EMBEDDINGS_PATH}")
            embeddings = np.load(config.EMBEDDINGS_PATH)
            build_qdrant(chunks, embeddings, config)
        print("\n[DONE]")
        return

    chunks = build_chunks(args, config)
    embeddings = build_embeddings(chunks, config)
    build_bm25(chunks, config)
    build_qdrant(chunks, embeddings, config)

    print(f"""
{'='*60}
Setup complete!
Files created:
{config.CHUNKS_PATH}
{config.EMBEDDINGS_PATH}
{config.BM25_INDEX_PATH}/
Qdrant collection: '{config.COLLECTION_NAME}'
Next steps:
1. Set GROQ_API_KEY in your environment
2. from rag_pipeline.app import Pipeline
pipe = Pipeline()
print(pipe.query("what is swarm optimization"))
{'='*60}
""")
# Run the setup only when executed as a script, not when imported.
if __name__ == "__main__":
    main()