import argparse
import json
import os
import sys
import time
import numpy as np
from sentence_transformers import SentenceTransformer
#Args
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the index build script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what a plain ``parse_args()``
            call has always done; passing a list allows programmatic use.

    Returns:
        Namespace with the attributes ``upload``, ``papers``, ``skip_data``,
        ``chunks_file``, ``only`` and ``batch_size``.
    """
    p = argparse.ArgumentParser(description="Build all RAG pipeline indexes.")
    p.add_argument("--upload", action="store_true",
                   help="Upload indexes to hugging face repository")
    p.add_argument("--papers", type=int, default=10_000,
                   help="Number of arXiv papers to load (default: 10000)")
    p.add_argument("--skip-data", action="store_true",
                   help="Skip dataset download; load CHUNKS_PATH directly")
    p.add_argument("--chunks-file", type=str, default=None,
                   help="Path to a pre-built chunks JSON file")
    p.add_argument("--only", type=str, default=None,
                   choices=["embeddings", "bm25", "qdrant"],
                   help="Rebuild only one index (requires chunks + embeddings to exist)")
    p.add_argument("--batch-size", type=int, default=512,
                   help="Qdrant upsert batch size (default: 512)")
    return p.parse_args(argv)
#Helper functions
def step(msg: str):
    """Print ``msg`` as a banner framed above and below by a 60-character ``=`` rule."""
    rule = "=" * 60
    # Leading empty item yields the blank line that precedes the banner.
    print("", rule, f" {msg}", rule, sep="\n")
def upload_to_hub(config):
    """Upload the built artifacts to the Hugging Face dataset repo ``config.HF_REPO_ID``.

    The repository is created as a public dataset repo if it does not exist
    yet. Three uploads follow, in this order: the folder at
    ``config.BM25_INDEX_PATH``, the file at ``config.CHUNKS_PATH`` and the
    file at ``config.EMBEDDINGS_PATH``.

    Args:
        config: Object exposing ``HF_REPO_ID``, ``BM25_INDEX_PATH``,
            ``CHUNKS_PATH`` and ``EMBEDDINGS_PATH``.
    """
    # Imported lazily so the dependency is only needed when uploading.
    from huggingface_hub import HfApi

    hub = HfApi()
    # Every call below targets the same dataset repository.
    target = {"repo_id": config.HF_REPO_ID, "repo_type": "dataset"}

    hub.create_repo(exist_ok=True, private=False, **target)

    print("[INFO] uploading bm25 indexes")
    hub.upload_folder(
        folder_path=config.BM25_INDEX_PATH,
        path_in_repo="bm25_index_10k",
        **target,
    )

    print("[INFO] Uploading chunks(JSON)")
    hub.upload_file(
        path_or_fileobj=config.CHUNKS_PATH,
        path_in_repo="all_chunks_slim.json",
        **target,
    )

    print("[INFO] uploading the embeddings")
    hub.upload_file(
        path_or_fileobj=config.EMBEDDINGS_PATH,
        path_in_repo="chunk_embeddings_10k (1).npy",
        **target,
    )

    print(f"[INFO] Done all files uploaded to {config.HF_REPO_ID}")
def check_env():
    """Print a warning listing required environment variables that are unset or empty.

    Only ``GROQ_API_KEY`` is checked. Nothing is printed when it has a
    non-empty value, and the function never aborts the run.
    """
    required = ("GROQ_API_KEY",)
    absent = [name for name in required if not os.environ.get(name)]
    if not absent:
        return
    print(
        f"[WARN] Missing env vars: {', '.join(absent)}",
        " The pipeline will fail at query time without these.",
        " Set them in a .env file or your shell before running queries.\n",
        sep="\n",
    )
#Loading and chunking
def build_chunks(args, config) -> list[dict]:
    """Return the text chunks, either loaded from a JSON file or built from arXiv.

    If ``args.chunks_file`` is set, or ``args.skip_data`` is true (in which
    case ``config.CHUNKS_PATH`` is used), the chunks are read from that JSON
    file. Otherwise ``args.papers`` papers are fetched with
    ``load_arxiv_subset``, turned into chunks with ``process_dataset`` and
    written to ``config.CHUNKS_PATH``.

    Args:
        args: Parsed CLI namespace; reads ``chunks_file``, ``skip_data`` and
            ``papers``.
        config: Object exposing ``CHUNKS_PATH``.

    Returns:
        The chunks as loaded from JSON or returned by ``process_dataset``.

    Raises:
        SystemExit: If a chunks file was requested but does not exist.
    """
    # --chunks-file or --skip-data: load existing JSON
    source = args.chunks_file or (config.CHUNKS_PATH if args.skip_data else None)
    if source:
        if not os.path.exists(source):
            sys.exit(f"[ERROR] Chunks file not found: {source}")
        step(f"Loading chunks from {source}")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(source, encoding="utf-8") as f:
            chunks = json.load(f)
        print(f"[INFO] Loaded {len(chunks)} chunks")
        return chunks

    step(f"Downloading arXiv dataset ({args.papers} papers)")
    # Imported lazily: only needed when the dataset is actually rebuilt.
    from rag_pipeline.data.load_data import load_arxiv_subset, process_dataset
    data = load_arxiv_subset(n=args.papers)
    chunks = process_dataset(data)

    out = config.CHUNKS_PATH
    # Create the target directory first so the processed dataset is not lost
    # to a missing-folder error at write time.
    parent = os.path.dirname(out)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(out, "w", encoding="utf-8") as f:
        json.dump(chunks, f)
    print(f"[INFO] Saved {len(chunks)} chunks → {out}")
    return chunks
#Embedding
def build_embeddings(chunks: list[dict], config) -> "np.ndarray":
    """Return embeddings for ``chunks``, reusing the cached file when present.

    If ``config.EMBEDDINGS_PATH`` exists it is loaded with ``np.load`` and
    returned without re-encoding. A warning is printed when the cached array's
    first dimension differs from ``len(chunks)``, since that suggests the
    cache was built from a different chunk set. Otherwise each chunk's
    ``"text"`` is encoded with ``SentenceTransformer(config.BI_ENCODER_MODEL)``
    and the result is saved to ``config.EMBEDDINGS_PATH``.

    Args:
        chunks: Chunk dicts; each must have a ``"text"`` key when encoding.
        config: Object exposing ``EMBEDDINGS_PATH`` and ``BI_ENCODER_MODEL``.

    Returns:
        The loaded or freshly computed embeddings array.
    """
    out = config.EMBEDDINGS_PATH
    if os.path.exists(out):
        print(f"[INFO] Embeddings already exist at {out} — skipping.")
        print(" Delete the file and re-run if you want to rebuild.")
        cached = np.load(out)
        # A stale cache would silently misalign vectors and chunks downstream.
        cached_rows = cached.shape[0] if cached.ndim else 0
        if cached_rows != len(chunks):
            print(f"[WARN] Cached embeddings have {cached_rows} rows but "
                  f"{len(chunks)} chunks were given; the cache may be stale.")
        return cached

    step(f"Encoding {len(chunks)} chunks with {config.BI_ENCODER_MODEL}")
    model = SentenceTransformer(config.BI_ENCODER_MODEL)
    texts = [c["text"] for c in chunks]
    t0 = time.time()
    embeddings = model.encode(
        texts,
        batch_size=256,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    print(f"[INFO] Encoded in {time.time()-t0:.1f}s shape={embeddings.shape}")
    np.save(out, embeddings)
    print(f"[INFO] Saved embeddings → {out}")
    return embeddings
#BM25 index
def build_bm25(chunks: list[dict], config):
    """Build the BM25 index from ``chunks`` unless one already exists.

    When ``config.BM25_INDEX_PATH`` exists the function prints a notice and
    returns without doing anything. Otherwise it calls
    ``create_bm25(chunks)``.

    Args:
        chunks: Chunk dicts handed unchanged to ``create_bm25``.
        config: Object exposing ``BM25_INDEX_PATH``.
    """
    out = config.BM25_INDEX_PATH
    if os.path.exists(out):
        print(f"[INFO] BM25 index already exists at {out} — skipping.")
        print(" Delete the folder and re-run if you want to rebuild.")
        return

    step(f"Building BM25 index → {out}")
    # Imported lazily: only needed when the index is actually built.
    from rag_pipeline.retrieval.bm25 import create_bm25
    # NOTE(review): the output path is not passed in; presumably create_bm25
    # resolves it from the project config itself. Confirm.
    create_bm25(chunks)
#Qdrant Vector DB
def build_qdrant(chunks: list[dict], embeddings, config, batch_size=None):
    """Create the Qdrant collection from ``chunks`` and ``embeddings``.

    Connects to a Qdrant server on ``localhost:6333``, verifies the
    connection with ``get_collections()`` and then delegates to
    ``create_qdrant_index``.

    Args:
        chunks: Chunk dicts passed through to ``create_qdrant_index``.
        embeddings: Embeddings passed through to ``create_qdrant_index``.
        config: Object exposing ``COLLECTION_NAME`` (used in the banner).
        batch_size: Upsert batch size. When ``None``, the value is taken from
            the module-level ``args`` namespace if ``main()`` has set one,
            else 512 (the ``--batch-size`` default).

    Raises:
        SystemExit: If ``qdrant_client`` is not installed or the server
            cannot be reached.
    """
    if batch_size is None:
        # Previously this read the global `args` unconditionally, which raised
        # NameError whenever the function was used without running main().
        cli_args = globals().get("args")
        batch_size = getattr(cli_args, "batch_size", 512)

    step(f"Building Qdrant collection '{config.COLLECTION_NAME}'")
    print("[INFO] Make sure Qdrant is running: docker run -p 6333:6333 qdrant/qdrant")

    # Keep the import outside the connection check so that a missing package
    # is not misreported as a connection failure.
    try:
        from qdrant_client import QdrantClient
    except ImportError as e:
        sys.exit(f"[ERROR] qdrant-client is not installed: {e}\n"
                 f" Install it with: pip install qdrant-client")

    try:
        client = QdrantClient(host="localhost", port=6333)
        client.get_collections()  # connection check
    except Exception as e:
        sys.exit(f"[ERROR] Cannot connect to Qdrant: {e}\n"
                 f" Start it with: docker run -p 6333:6333 qdrant/qdrant")

    from rag_pipeline.retrieval.semantic import create_qdrant_index
    create_qdrant_index(chunks, embeddings, client, batch_size=batch_size)
def main():
    """Entry point: parse the CLI and run the requested build or upload.

    Modes, in order of precedence:

    * ``--upload``: upload existing artifacts and return.
    * ``--only {embeddings,bm25,qdrant}``: rebuild a single index from the
      existing chunks file (``--chunks-file`` if given, else
      ``config.CHUNKS_PATH``). The embeddings file or BM25 folder is deleted
      first to force a rebuild; ``qdrant`` needs the embeddings file to exist.
    * default: build chunks, embeddings, BM25 index and Qdrant collection.

    Raises:
        SystemExit: In ``--only`` mode when a required input file is missing.
    """
    # build_qdrant() looks up the module-level `args` for --batch-size, so the
    # parsed namespace has to be published as a global.
    global args
    args = parse_args()
    check_env()

    # Make the `rag_pipeline` package next to this script importable,
    # regardless of the current working directory.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    import rag_pipeline.config as config

    if args.upload:
        upload_to_hub(config)
        return

    if args.only:
        # Honour an explicit --chunks-file instead of silently ignoring it.
        chunks_path = args.chunks_file or config.CHUNKS_PATH
        if not os.path.exists(chunks_path):
            sys.exit(f"[ERROR] Chunks file not found: {chunks_path}")
        with open(chunks_path, encoding="utf-8") as f:
            chunks = json.load(f)

        if args.only == "embeddings":
            # Force rebuild by removing existing file
            if os.path.exists(config.EMBEDDINGS_PATH):
                os.remove(config.EMBEDDINGS_PATH)
            build_embeddings(chunks, config)
        elif args.only == "bm25":
            import shutil
            # Force rebuild by removing the existing index folder.
            if os.path.exists(config.BM25_INDEX_PATH):
                shutil.rmtree(config.BM25_INDEX_PATH)
            build_bm25(chunks, config)
        elif args.only == "qdrant":
            if not os.path.exists(config.EMBEDDINGS_PATH):
                sys.exit(f"[ERROR] Embeddings file not found: {config.EMBEDDINGS_PATH}")
            embeddings = np.load(config.EMBEDDINGS_PATH)
            build_qdrant(chunks, embeddings, config)
        print("\n[DONE]")
        return

    chunks = build_chunks(args, config)
    embeddings = build_embeddings(chunks, config)
    build_bm25(chunks, config)
    build_qdrant(chunks, embeddings, config)
    print(f"""
{'='*60}
Setup complete!
Files created:
{config.CHUNKS_PATH}
{config.EMBEDDINGS_PATH}
{config.BM25_INDEX_PATH}/
Qdrant collection: '{config.COLLECTION_NAME}'
Next steps:
1. Set GROQ_API_KEY in your environment
2. from rag_pipeline.app import Pipeline
pipe = Pipeline()
print(pipe.query("what is swarm optimization"))
{'='*60}
""")
# Run the build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()