"""Flask backend for a RAG service over The Odyssey.

On import it connects to Qdrant, loads the dense/sparse/reranker models,
(re)indexes the episode corpus if needed, and then serves:

* ``GET  /health`` / ``GET /status`` — readiness + point count
* ``POST /search``                   — hybrid retrieval + LLM answer
"""

import json
import os

import torch
from fastembed.rerank.cross_encoder import TextCrossEncoder
from fastembed.sparse.bm25 import Bm25
from flask import Flask, jsonify, request
from flask_cors import CORS
from qdrant_client import QdrantClient, models
from qdrant_client.models import (
    Distance,
    Modifier,
    PointStruct,
    SparseVector,
    SparseVectorParams,
    VectorParams,
)
from sentence_transformers import SentenceTransformer

import config
from RAG_core.retriever import (
    generate_answer_with_groq,
    hybrid_search_with_rerank,
    rewrite_query_with_groq,
    setup_reranker,
)

app = Flask(__name__)
CORS(app)

# Globals populated once by initialize_database() and read by the handlers.
client = None        # QdrantClient connection
model = None         # SentenceTransformer dense embedder
reranker = None      # cross-encoder reranker returned by setup_reranker()
bm25_model = None    # fastembed BM25 sparse embedder
is_db_ready = False  # flipped to True only after the collection is populated


def ensure_payload_indexes():
    """Ensure keyword payload indexes exist so /search filters can be applied.

    Creates "book_id" and "main_characters" keyword indexes on the collection
    if missing. Failures are logged and swallowed deliberately: a missing
    index only degrades filtering and should not abort startup.
    """
    global client
    try:
        info = client.get_collection(config.COLLECTION_NAME)
        existing = (
            list(info.payload_schema.keys())
            if hasattr(info, "payload_schema")
            else []
        )
        for field in ("book_id", "main_characters"):
            if field not in existing:
                client.create_payload_index(
                    collection_name=config.COLLECTION_NAME,
                    field_name=field,
                    field_schema=models.KeywordIndexParams(type="keyword"),
                )
    except Exception as e:
        print(f"Payload index creation issue: {e}")


def initialize_database():
    """Connect to Qdrant, load the models, and (re)index the episode corpus.

    Populates the module-level globals. On any unrecoverable failure it
    returns early and leaves ``is_db_ready`` False, so the API answers 503
    instead of crashing the process.
    """
    global client, model, reranker, bm25_model, is_db_ready
    print("Initializing Backend System...")

    try:
        client = QdrantClient(url=config.QDRANT_URL, api_key=config.QDRANT_API_KEY)
        print("Connected to Qdrant.")
    except Exception as e:
        print(f"Qdrant Connection Failed: {e}")
        return

    # Load the corpus; without it there is nothing to index or search.
    if not os.path.exists(config.DATA_FILE):
        print(f"Data file '{config.DATA_FILE}' not found.")
        return
    with open(config.DATA_FILE, 'r', encoding='utf-8') as f:
        episodes = json.load(f)
    print(f"Loaded {len(episodes)} episodes.")

    # Models: dense embedder (GPU when available), reranker, BM25 sparse.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(config.SENTENCE_TRANSFORMER_MODEL, device=device)
    print(f"Embedding Model Loaded ({device}).")
    reranker = setup_reranker()
    print("Reranker Loaded.")
    bm25_model = Bm25(config.BM25_MODEL_NAME)
    print("BM25 Loaded.")

    # Reuse a populated collection; drop and rebuild an empty one.
    collection_exists = client.collection_exists(config.COLLECTION_NAME)
    if collection_exists:
        count = client.get_collection(config.COLLECTION_NAME).points_count
        if count > 0:
            print(f"Database contains {count} points. Skipping indexing.")
            ensure_payload_indexes()
            is_db_ready = True
            return
        client.delete_collection(config.COLLECTION_NAME)
        collection_exists = False

    if not collection_exists:
        # Probe the embedder once to discover the dense vector dimension.
        sample_vec = model.encode(episodes[0]["summary"])
        embed_dim = len(sample_vec)
        client.create_collection(
            collection_name=config.COLLECTION_NAME,
            vectors_config={
                "text_vector": VectorParams(size=embed_dim, distance=Distance.COSINE)
            },
            sparse_vectors_config={
                "metadata_sparse": SparseVectorParams(modifier=Modifier.IDF)
            },
        )
        print(f"Collection '{config.COLLECTION_NAME}' created.")
        ensure_payload_indexes()

    # Index: dense vectors over summary+text, BM25 sparse over metadata.
    print("Indexing data...")
    episode_texts = [ep["summary"] + " " + ep["episode_text"] for ep in episodes]
    metadata_texts = [ep["metadata_string"] for ep in episodes]
    text_vectors = model.encode(episode_texts, batch_size=8, show_progress_bar=True)
    sparse_embeddings = list(bm25_model.embed(metadata_texts, batch_size=16))
    bm25_vectors = [
        SparseVector(indices=sp.indices.tolist(), values=sp.values.tolist())
        for sp in sparse_embeddings
    ]

    points = [
        PointStruct(
            id=i,
            vector={
                "text_vector": text_vectors[i].tolist(),
                "metadata_sparse": bm25_vectors[i],
            },
            payload=ep,
        )
        for i, ep in enumerate(episodes)
    ]

    # Upload in batches so one oversized request can't time out.
    batch_size = 50
    total_batches = (len(points) - 1) // batch_size + 1
    for i in range(0, len(points), batch_size):
        batch = points[i:i + batch_size]
        client.upsert(collection_name=config.COLLECTION_NAME, points=batch, wait=True)
        print(f"Batch {i//batch_size + 1}/{total_batches} uploaded.")

    print(f"Successfully indexed {len(points)} points.")
    is_db_ready = True
    print("Database Build Complete.")


@app.route('/health', methods=['GET'])
def health():
    """Report readiness and the collection's point count for monitoring."""
    points = (
        client.get_collection(config.COLLECTION_NAME).points_count
        if client and is_db_ready
        else 0
    )
    return jsonify({
        "status": "ready" if is_db_ready else "initializing",
        "collection": config.COLLECTION_NAME,
        "points": points,
    })


@app.route('/status', methods=['GET'])
def status():
    """Alias for /health."""
    return health()


@app.route('/search', methods=['POST'])
def search():
    """Hybrid RAG search: rewrite the query, retrieve + rerank, then answer.

    Request JSON: ``{"query": "<user question>"}``.
    Responses: 503 while initializing, 400 when the query is missing,
    500 on unexpected failures, otherwise 200 with answer, sources, and
    the reranked passages.
    """
    if not is_db_ready or not client:
        return jsonify({"error": "Database not ready"}), 503

    # get_json(silent=True) avoids a 500 when the body is absent or not JSON
    # (request.json would be None and the .get below would raise).
    data = request.get_json(silent=True) or {}
    user_query = data.get('query', '')
    if not user_query:
        return jsonify({"error": "No query provided"}), 400

    try:
        rewritten = rewrite_query_with_groq(user_query)
        search_results = hybrid_search_with_rerank(
            semantic_query=rewritten["semantic_query"],
            metadata_query=rewritten["metadata_hint"],
            filters=rewritten["filters"],
            client=client,
            model=model,
            reranker=reranker,
            bm25_model=bm25_model,
            collection_name=config.COLLECTION_NAME,
            initial_k=50,
            final_k=5,
            use_rrf=True,
            rerank_weight=0.8,
            retrieval_weight=0.2,
        )

        if not search_results:
            return jsonify({
                "query": user_query,
                "rewritten_query": rewritten,
                "answer": "No relevant passages found in The Odyssey.",
                "results": [],
            })

        answer_data = generate_answer_with_groq(
            query=user_query,
            retrieved_results=search_results,
            groq_api_key=config.GROQ_API_KEY,
            rewritten_query=rewritten,
        )

        formatted_results = [
            {
                "episode_id": r["episode_id"],
                "score": float(r["score"]),
                "reranker_score": float(r["reranker_score"]),
                "retrieval_score": float(r["retrieval_score"]),
                # Truncated preview; full text stays in the payload store.
                "text": r["payload"].get("episode_text", "")[:300] + "...",
                "summary": r["payload"].get("summary", ""),
                "book_id": r["payload"].get("book_id"),
                "main_characters": r["payload"].get("main_characters", []),
            }
            for r in search_results
        ]

        return jsonify({
            "query": user_query,
            "rewritten_query": rewritten.get("semantic_query"),
            "metadata_hint": rewritten.get("metadata_hint"),
            "filters_applied": rewritten.get("filters"),
            "answer": answer_data["answer"],
            "sources": answer_data["sources"],
            "results": formatted_results,
        })
    except Exception as e:
        # Broad catch at the API boundary: log the traceback, return 500.
        print(f"Search error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


# Build/connect at import time so the app is ready before the first request
# (this also covers WSGI servers that import the module without __main__).
initialize_database()

if __name__ == '__main__':
    # use_reloader=False: the Werkzeug debug reloader re-imports this module,
    # which would run the heavy initialize_database() a second time.
    app.run(host="0.0.0.0", port=5000, debug=True, use_reloader=False)