# odyssey-rag / api.py
import os
import json
import traceback

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient, models
from qdrant_client.models import (
    VectorParams, Distance, PointStruct, SparseVectorParams,
    SparseVector, Modifier
)
from fastembed.sparse.bm25 import Bm25
from RAG_core.retriever import (
    rewrite_query_with_groq,
    hybrid_search_with_rerank,
    generate_answer_with_groq,
    setup_reranker
)
import config
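
# The local `config` module is assumed to define the names used throughout this
# file: QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, DATA_FILE,
# SENTENCE_TRANSFORMER_MODEL, BM25_MODEL_NAME, and GROQ_API_KEY.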

app = Flask(__name__)
CORS(app)

# Global state, populated by initialize_database()
client = None          # QdrantClient connection
model = None           # SentenceTransformer dense embedder
reranker = None        # cross-encoder from setup_reranker()
bm25_model = None      # fastembed Bm25 sparse embedder
is_db_ready = False    # flips to True once the collection is usable


def ensure_payload_indexes():
    """Ensure payload indexes exist for filtering."""
    global client
    try:
        collection_info = client.get_collection(config.COLLECTION_NAME)
        existing_indexes = (
            list(collection_info.payload_schema.keys())
            if hasattr(collection_info, 'payload_schema') else []
        )
        if "book_id" not in existing_indexes:
            client.create_payload_index(
                collection_name=config.COLLECTION_NAME,
                field_name="book_id",
                field_schema=models.KeywordIndexParams(type="keyword")
            )
        if "main_characters" not in existing_indexes:
            client.create_payload_index(
                collection_name=config.COLLECTION_NAME,
                field_name="main_characters",
                field_schema=models.KeywordIndexParams(type="keyword")
            )
    except Exception as e:
        print(f"Payload index creation issue: {e}")
def initialize_database():
    """Initialize the Qdrant connection, load models, and index data if needed."""
    global client, model, reranker, bm25_model, is_db_ready
    print("Initializing Backend System...")
    try:
        client = QdrantClient(url=config.QDRANT_URL, api_key=config.QDRANT_API_KEY)
        print("Connected to Qdrant.")
    except Exception as e:
        print(f"Qdrant Connection Failed: {e}")
        return
    # Check for the data file before loading any models.
    if not os.path.exists(config.DATA_FILE):
        print(f"Data file '{config.DATA_FILE}' not found.")
        return
    with open(config.DATA_FILE, 'r', encoding='utf-8') as f:
        episodes = json.load(f)
    print(f"Loaded {len(episodes)} episodes.")
    # Set up models: dense embedder, cross-encoder reranker, BM25 sparse embedder.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(config.SENTENCE_TRANSFORMER_MODEL, device=device)
    print(f"Embedding Model Loaded ({device}).")
    reranker = setup_reranker()
    print("Reranker Loaded.")
    bm25_model = Bm25(config.BM25_MODEL_NAME)
    print("BM25 Loaded.")
    # Set up the collection; reuse it if it is already populated.
    collection_exists = client.collection_exists(config.COLLECTION_NAME)
    if collection_exists:
        count = client.get_collection(config.COLLECTION_NAME).points_count
        if count > 0:
            print(f"Database contains {count} points. Skipping indexing.")
            ensure_payload_indexes()
            is_db_ready = True
            return
        else:
            # Empty collection: drop it and rebuild from scratch.
            client.delete_collection(config.COLLECTION_NAME)
            collection_exists = False
    if not collection_exists:
        # Infer the dense vector size from a sample embedding.
        sample_vec = model.encode(episodes[0]["summary"])
        embed_dim = len(sample_vec)
        client.create_collection(
            collection_name=config.COLLECTION_NAME,
            vectors_config={
                "text_vector": VectorParams(size=embed_dim, distance=Distance.COSINE)
            },
            sparse_vectors_config={
                "metadata_sparse": SparseVectorParams(modifier=Modifier.IDF)
            }
        )
        print(f"Collection '{config.COLLECTION_NAME}' created.")
        ensure_payload_indexes()
    # Index data: dense vectors over summary + text, sparse BM25 over metadata.
    print("Indexing data...")
    episode_texts = [ep["summary"] + " " + ep["episode_text"] for ep in episodes]
    metadata_texts = [ep["metadata_string"] for ep in episodes]
    text_vectors = model.encode(episode_texts, batch_size=8, show_progress_bar=True)
    sparse_embeddings = list(bm25_model.embed(metadata_texts, batch_size=16))
    bm25_vectors = [
        SparseVector(indices=sp.indices.tolist(), values=sp.values.tolist())
        for sp in sparse_embeddings
    ]
    points = []
    for i, ep in enumerate(episodes):
        points.append(PointStruct(
            id=i,
            vector={
                "text_vector": text_vectors[i].tolist(),
                "metadata_sparse": bm25_vectors[i]
            },
            payload=ep
        ))
    # Upsert in batches so large uploads stay under request size limits.
    batch_size = 50
    for i in range(0, len(points), batch_size):
        batch = points[i:i + batch_size]
        client.upsert(collection_name=config.COLLECTION_NAME, points=batch, wait=True)
        print(f"Batch {i//batch_size + 1}/{(len(points)-1)//batch_size + 1} uploaded.")
    print(f"Successfully indexed {len(points)} points.")
    is_db_ready = True
    print("Database Build Complete.")
@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        "status": "ready" if is_db_ready else "initializing",
        "collection": config.COLLECTION_NAME,
        "points": client.get_collection(config.COLLECTION_NAME).points_count if client and is_db_ready else 0
    })


@app.route('/status', methods=['GET'])
def status():
    return health()
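
# Example (assuming a local run on port 5000):
#
#     curl http://localhost:5000/health
#     -> {"status": "ready", "collection": "<COLLECTION_NAME>", "points": <count>}
#
# "points" stays 0 until initialize_database() has marked the system ready.

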
@app.route('/search', methods=['POST'])
def search():
    if not is_db_ready or not client:
        return jsonify({"error": "Database not ready"}), 503
    data = request.get_json(silent=True) or {}
    user_query = data.get('query', '')
    if not user_query:
        return jsonify({"error": "No query provided"}), 400
    try:
        # Rewrite the raw question into a semantic query, a metadata hint, and filters.
        rewritten = rewrite_query_with_groq(user_query)
        search_results = hybrid_search_with_rerank(
            semantic_query=rewritten["semantic_query"],
            metadata_query=rewritten["metadata_hint"],
            filters=rewritten["filters"],
            client=client,
            model=model,
            reranker=reranker,
            bm25_model=bm25_model,
            collection_name=config.COLLECTION_NAME,
            initial_k=50,
            final_k=5,
            use_rrf=True,
            rerank_weight=0.8,
            retrieval_weight=0.2
        )
        if not search_results:
            return jsonify({
                "query": user_query,
                "rewritten_query": rewritten.get("semantic_query"),
                "answer": "No relevant passages found in The Odyssey.",
                "results": []
            })
        answer_data = generate_answer_with_groq(
            query=user_query,
            retrieved_results=search_results,
            groq_api_key=config.GROQ_API_KEY,
            rewritten_query=rewritten
        )
        formatted_results = []
        for r in search_results:
            full_text = r["payload"].get("episode_text", "")
            formatted_results.append({
                "episode_id": r["episode_id"],
                "score": float(r["score"]),
                "reranker_score": float(r["reranker_score"]),
                "retrieval_score": float(r["retrieval_score"]),
                # Truncate long passages; only add the ellipsis when text was cut.
                "text": full_text[:300] + ("..." if len(full_text) > 300 else ""),
                "summary": r["payload"].get("summary", ""),
                "book_id": r["payload"].get("book_id"),
                "main_characters": r["payload"].get("main_characters", [])
            })
        return jsonify({
            "query": user_query,
            "rewritten_query": rewritten.get("semantic_query"),
            "metadata_hint": rewritten.get("metadata_hint"),
            "filters_applied": rewritten.get("filters"),
            "answer": answer_data["answer"],
            "sources": answer_data["sources"],
            "results": formatted_results
        })
    except Exception as e:
        print(f"Search error: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
# Initialize at import time so the app is also ready under a WSGI server,
# not only when run as a script.
initialize_database()

if __name__ == '__main__':
    # The debug reloader would re-import this module and re-run the expensive
    # initialization above, so it is disabled explicitly.
    app.run(host="0.0.0.0", port=5000, debug=True, use_reloader=False)
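
# Local run (assumed): `python api.py` serves on port 5000. Under a production
# WSGI server (e.g. `gunicorn api:app`), initialize_database() still runs at
# import time, so the first request may wait on model loading and indexing.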