# Spaces: Sleeping — HuggingFace Spaces status banner captured by the page
# scrape; not part of the program. Kept as a comment so the file stays valid.
import os
import json
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient, models
from qdrant_client.models import (
    VectorParams, Distance, PointStruct, SparseVectorParams,
    SparseVector, Modifier
)
from fastembed.sparse.bm25 import Bm25
from fastembed.rerank.cross_encoder import TextCrossEncoder  # NOTE(review): not referenced in this file — presumably used via setup_reranker; confirm before removing
from RAG_core.retriever import (
    rewrite_query_with_groq,
    hybrid_search_with_rerank,
    generate_answer_with_groq,
    setup_reranker
)
import config

app = Flask(__name__)
CORS(app)  # allow cross-origin requests (browser frontend on another origin)

# Global variables — populated once by initialize_database() at import time.
client = None        # QdrantClient connection
model = None         # SentenceTransformer dense embedder
reranker = None      # cross-encoder reranker from setup_reranker()
bm25_model = None    # fastembed BM25 sparse embedder for metadata
is_db_ready = False  # flipped to True only after the collection is populated
def ensure_payload_indexes():
    """Ensure keyword payload indexes exist for filtering.

    Creates a ``keyword`` payload index on the collection for each filterable
    field ("book_id", "main_characters") that is not already indexed.
    Best-effort: any failure is printed and swallowed so startup can continue.
    """
    global client
    try:
        collection_info = client.get_collection(config.COLLECTION_NAME)
        # payload_schema maps field name -> index params for already-indexed
        # fields; older client versions may lack the attribute entirely.
        existing_indexes = set(collection_info.payload_schema.keys()) if hasattr(collection_info, 'payload_schema') else set()
        # One loop instead of two copy-pasted branches per field.
        for field_name in ("book_id", "main_characters"):
            if field_name not in existing_indexes:
                client.create_payload_index(
                    collection_name=config.COLLECTION_NAME,
                    field_name=field_name,
                    field_schema=models.KeywordIndexParams(type="keyword")
                )
    except Exception as e:
        # Deliberately non-fatal: a missing index degrades filtering, it
        # should not abort initialization.
        print(f"Payload index creation issue: {e}")
def initialize_database():
    """Connect to Qdrant, load models, and (re)build the episode collection.

    Side effects: sets the module globals ``client``, ``model``, ``reranker``
    and ``bm25_model``, and flips ``is_db_ready`` to True on success. Returns
    early (leaving ``is_db_ready`` False) if the Qdrant connection fails or
    the data file is missing.
    """
    global client, model, reranker, bm25_model, is_db_ready
    print("Initializing Backend System...")
    try:
        client = QdrantClient(url=config.QDRANT_URL, api_key=config.QDRANT_API_KEY)
        print("Connected to Qdrant.")
    except Exception as e:
        print(f"Qdrant Connection Failed: {e}")
        return
    # Check for data file
    if not os.path.exists(config.DATA_FILE):
        print(f"Data file '{config.DATA_FILE}' not found.")
        return
    with open(config.DATA_FILE, 'r', encoding='utf-8') as f:
        # assumes a JSON array of episode dicts with "summary",
        # "episode_text" and "metadata_string" keys — TODO confirm schema
        episodes = json.load(f)
    print(f"Loaded {len(episodes)} episodes.")
    # Setup models
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(config.SENTENCE_TRANSFORMER_MODEL, device=device)
    print(f"Embedding Model Loaded ({device}).")
    reranker = setup_reranker()
    print("Reranker Loaded.")
    bm25_model = Bm25(config.BM25_MODEL_NAME)
    print("BM25 Loaded.")
    # Setup Collection: reuse a populated collection, rebuild an empty one.
    collection_exists = client.collection_exists(config.COLLECTION_NAME)
    if collection_exists:
        count = client.get_collection(config.COLLECTION_NAME).points_count
        if count > 0:
            print(f"Database contains {count} points. Skipping indexing.")
            ensure_payload_indexes()
            is_db_ready = True
            return
        else:
            # Empty collection: drop it so it is recreated with fresh config.
            client.delete_collection(config.COLLECTION_NAME)
            collection_exists = False
    if not collection_exists:
        # Probe the embedder once to discover the dense vector dimension.
        sample_vec = model.encode(episodes[0]["summary"])
        embed_dim = len(sample_vec)
        client.create_collection(
            collection_name=config.COLLECTION_NAME,
            vectors_config={
                "text_vector": VectorParams(size=embed_dim, distance=Distance.COSINE)
            },
            sparse_vectors_config={
                # IDF modifier tells Qdrant to apply BM25-style term weighting.
                "metadata_sparse": SparseVectorParams(modifier=Modifier.IDF)
            }
        )
        print(f"Collection '{config.COLLECTION_NAME}' created.")
    ensure_payload_indexes()
    # Index data: dense vectors over summary+text, sparse BM25 over metadata.
    print("Indexing data...")
    episode_texts = [ep["summary"] + " " + ep["episode_text"] for ep in episodes]
    metadata_texts = [ep["metadata_string"] for ep in episodes]
    text_vectors = model.encode(episode_texts, batch_size=8, show_progress_bar=True)
    sparse_embeddings = list(bm25_model.embed(metadata_texts, batch_size=16))
    bm25_vectors = [SparseVector(indices=sp.indices.tolist(), values=sp.values.tolist()) for sp in sparse_embeddings]
    points = []
    for i, ep in enumerate(episodes):
        points.append(PointStruct(
            id=i,
            vector={
                "text_vector": text_vectors[i].tolist(),
                "metadata_sparse": bm25_vectors[i]
            },
            payload=ep  # full episode dict stored as the point payload
        ))
    # Upload in batches; wait=True blocks until each batch is persisted.
    batch_size = 50
    for i in range(0, len(points), batch_size):
        batch = points[i:i + batch_size]
        client.upsert(collection_name=config.COLLECTION_NAME, points=batch, wait=True)
        print(f"Batch {i//batch_size + 1}/{(len(points)-1)//batch_size + 1} uploaded.")
    print(f"Successfully indexed {len(points)} points.")
    is_db_ready = True
    print("Database Build Complete.")
@app.route("/health", methods=["GET"])
def health():
    """Report readiness and the current collection point count.

    NOTE(review): no ``@app.route`` registration was present in the source,
    so this handler was never reachable; the "/health" path is inferred from
    the function name — confirm against the frontend.
    """
    return jsonify({
        "status": "ready" if is_db_ready else "initializing",
        "collection": config.COLLECTION_NAME,
        # Only hit Qdrant for the count once the system is fully ready.
        "points": client.get_collection(config.COLLECTION_NAME).points_count if client and is_db_ready else 0
    })
@app.route("/status", methods=["GET"])
def status():
    """Alias for health().

    NOTE(review): registration was missing in the source; the "/status" path
    is inferred from the function name — confirm against the frontend.
    """
    return health()
@app.route("/search", methods=["POST"])
def search():
    """Answer a user query via hybrid retrieval, reranking, and Groq generation.

    Expects a JSON body ``{"query": "..."}``. Returns the generated answer,
    its sources, and the top reranked passages. Responds 503 while the
    database is initializing and 400 on a missing query.

    NOTE(review): no ``@app.route`` registration was present in the source;
    the "/search" path is inferred from the function name — confirm against
    the frontend.
    """
    if not is_db_ready or not client:
        return jsonify({"error": "Database not ready"}), 503
    # get_json(silent=True) yields None (instead of raising) on a missing or
    # malformed JSON body; fall back to {} so .get() below is always safe.
    data = request.get_json(silent=True) or {}
    user_query = data.get('query', '')
    if not user_query:
        return jsonify({"error": "No query provided"}), 400
    try:
        # LLM rewrite: semantic query text, metadata hint, structured filters.
        rewritten = rewrite_query_with_groq(user_query)
        search_results = hybrid_search_with_rerank(
            semantic_query=rewritten["semantic_query"],
            metadata_query=rewritten["metadata_hint"],
            filters=rewritten["filters"],
            client=client,
            model=model,
            reranker=reranker,
            bm25_model=bm25_model,
            collection_name=config.COLLECTION_NAME,
            initial_k=50,
            final_k=5,
            use_rrf=True,
            rerank_weight=0.8,
            retrieval_weight=0.2
        )
        if not search_results:
            # NOTE(review): this branch returns the full rewrite dict while
            # the success branch returns only semantic_query — kept as-is in
            # case the frontend depends on either shape.
            return jsonify({
                "query": user_query,
                "rewritten_query": rewritten,
                "answer": "No relevant passages found in The Odyssey.",
                "results": []
            })
        answer_data = generate_answer_with_groq(
            query=user_query,
            retrieved_results=search_results,
            groq_api_key=config.GROQ_API_KEY,
            rewritten_query=rewritten
        )
        formatted_results = []
        for r in search_results:
            episode_text = r["payload"].get("episode_text", "")
            formatted_results.append({
                "episode_id": r["episode_id"],
                "score": float(r["score"]),
                "reranker_score": float(r["reranker_score"]),
                "retrieval_score": float(r["retrieval_score"]),
                # Append an ellipsis only when the text was actually truncated
                # (the original added "..." unconditionally, even to short texts).
                "text": episode_text[:300] + "..." if len(episode_text) > 300 else episode_text,
                "summary": r["payload"].get("summary", ""),
                "book_id": r["payload"].get("book_id"),
                "main_characters": r["payload"].get("main_characters", [])
            })
        return jsonify({
            "query": user_query,
            "rewritten_query": rewritten.get("semantic_query"),
            "metadata_hint": rewritten.get("metadata_hint"),
            "filters_applied": rewritten.get("filters"),
            "answer": answer_data["answer"],
            "sources": answer_data["sources"],
            "results": formatted_results
        })
    except Exception as e:
        # Top-level boundary: log the full traceback, return a 500 payload.
        print(f"Search error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
# Initialize on startup — this runs at import time (module-level side effect):
# it connects to Qdrant, downloads/loads the models, and may index the whole
# dataset before the first request can be served.
initialize_database()
if __name__ == '__main__':
    # SECURITY: the original ran with debug=True while bound to 0.0.0.0,
    # which exposes the Werkzeug interactive debugger (arbitrary code
    # execution) to the whole network. Debug mode is now opt-in via the
    # FLASK_DEBUG environment variable for local development only.
    app.run(host="0.0.0.0", port=5000, debug=os.environ.get("FLASK_DEBUG") == "1")