Spaces:
Runtime error
Runtime error
| import os | |
| import warnings | |
| import logging | |
| import time | |
| import json | |
| import hashlib | |
| from concurrent.futures import ThreadPoolExecutor | |
| from threading import Lock | |
| import re | |
# --- Environment configuration (must run before any ML library is imported) ---

# Keep HuggingFace model downloads inside the project directory.
cache_dir = os.path.join(os.getcwd(), ".cache")
os.makedirs(cache_dir, exist_ok=True)
os.environ['HF_HOME'] = cache_dir

# Silence TensorFlow's console noise (log level, oneDNN notice, deprecations).
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'TF_ENABLE_ONEDNN_OPTS': '0',
    'TF_LOGGING_LEVEL': 'ERROR',
    'TF_ENABLE_DEPRECATION_WARNINGS': '0',
})
warnings.filterwarnings('ignore', category=DeprecationWarning, module='tensorflow')
logging.getLogger('tensorflow').setLevel(logging.ERROR)
| from fastapi import FastAPI, HTTPException, Depends, Header, Query, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from content_readers import parse_document_url, parse_document_file | |
| from embedder import build_faiss_index, preload_model | |
| from retriever import retrieve_chunks | |
| from llm import query_gemini | |
| import uvicorn | |
| from contextlib import asynccontextmanager | |
| # Import Supabase logger | |
| from db_logger import log_query | |
| # Helper to get real client IP | |
def get_client_ip(request: Request):
    """Best-effort real client IP: X-Forwarded-For, then X-Real-IP, then the socket peer."""
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # The header may carry a proxy chain; the first entry is the client.
        return forwarded.split(",", 1)[0].strip()
    # Fall back to X-Real-IP, then to the direct connection's address.
    return request.headers.get("x-real-ip") or request.client.host
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: preload the sentence-transformer model once at startup.

    NOTE(review): the @asynccontextmanager decorator was missing. FastAPI's
    `lifespan=` parameter expects an async context-manager factory; a bare
    async generator is only accepted through a deprecated Starlette fallback.
    """
    print("Starting up HackRx Insurance Policy Assistant...")
    print("Preloading sentence transformer model...")
    preload_model()  # warm the embedding model so the first request isn't slow
    print("Model preloading completed. API is ready to serve requests!")
    yield
    # No shutdown cleanup required.
app = FastAPI(
    title="HackRx Insurance Policy Assistant",
    version="3.2.6",
    lifespan=lifespan,
)

# Wide-open CORS: any origin, method, and header (public hackathon API).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): the route decorator was missing, so this handler was never
# registered — restored; confirm the original path.
@app.get("/")
async def root():
    """Liveness message for the API root."""
    return {"message": "HackRx Insurance Policy Assistant API is running!"}
# NOTE(review): the route decorator was missing, so this handler was never
# registered — restored; confirm the original path.
@app.get("/health")
async def health_check():
    """Health probe for uptime monitors / load balancers."""
    return {"status": "healthy"}
class QueryRequest(BaseModel):
    # Request body for the main query endpoint: a document URL plus the
    # questions to answer (the URL is fed to parse_document_url).
    documents: str
    questions: list[str]
class LocalQueryRequest(BaseModel):
    # Request body for the local-file variant: a filesystem path plus the
    # questions to answer (the path is fed to parse_document_file).
    document_path: str
    questions: list[str]
def verify_token(authorization: str = Header(None)):
    """Validate the `Authorization: Bearer <token>` header.

    Returns the bearer token string; raises HTTP 401 when the header is
    absent, malformed, or carries an empty token.

    NOTE(review): the token is not compared against any expected value —
    any non-empty bearer token is accepted. Confirm this is intentional.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header")
    # Slice off the prefix instead of str.replace(), which would also strip
    # any "Bearer " substring occurring *inside* the token itself.
    token = authorization[len("Bearer "):]
    if not token:
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
def process_batch(batch_questions, context_chunks):
    # Thin wrapper so ThreadPoolExecutor can fan out per-batch LLM calls.
    return query_gemini(batch_questions, context_chunks)
def get_document_id_from_url(url: str) -> str:
    """Stable cache key for a document URL: the hex MD5 digest of the URL text."""
    digest = hashlib.md5(url.encode())
    return digest.hexdigest()
def question_has_https_link(q: str) -> bool:
    """Return True when *q* contains an https:// URL (http:// does not count)."""
    return re.search(r"https://[^\s]+", q) is not None
# Process-wide cache of parsed documents keyed by MD5 of the source URL.
# Each entry holds {"chunks", "index", "texts"}; guarded by doc_cache_lock
# because FastAPI may serve requests from multiple worker threads.
doc_cache = {}
doc_cache_lock = Lock()
# NOTE(review): the route decorator was missing, so this handler was never
# registered — restored; confirm the original path and HTTP method.
@app.get("/clear-cache")
async def clear_cache(doc_id: str = Query(None),
                      url: str = Query(None),
                      doc_only: bool = Query(False)):
    """Evict one cached document (by `doc_id` or `url`) or the whole cache.

    Query params:
        doc_id:   cache key to evict (MD5 hex of the document URL).
        url:      document URL; converted to a doc_id when given.
        doc_only: when True, nothing is cleared — presumably a legacy flag
                  from when more than one cache existed; TODO confirm.
    """
    cleared = {}
    if url:
        doc_id = get_document_id_from_url(url)
    if doc_id:
        # Targeted eviction of a single document.
        if not doc_only:
            with doc_cache_lock:
                if doc_id in doc_cache:
                    del doc_cache[doc_id]
                    cleared["doc_cache"] = f"Cleared document {doc_id}"
    else:
        # No selector given: wipe everything.
        if not doc_only:
            with doc_cache_lock:
                doc_cache.clear()
                cleared["doc_cache"] = "Cleared ALL documents"
    return {"status": "success", "cleared": cleared}
# NOTE(review): the route decorator was missing, so this handler was never
# registered — restored; /hackrx/run is the HackRx submission endpoint.
@app.post("/hackrx/run")
async def run_query(request: QueryRequest, fastapi_request: Request, token: str = Depends(verify_token)):
    """Answer a batch of questions about a document given by URL.

    Parses and FAISS-indexes the document (cached per URL hash), retrieves
    relevant chunks for every question, fans the questions out to the LLM in
    batches of 10 via a thread pool, logs each Q/A pair to Supabase, and
    returns {"answers": [...]} in question order. Any failure maps to HTTP 500.
    """
    start_time = time.time()
    timing_data = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        print("=== INPUT JSON ===")
        print(json.dumps({"documents": request.documents, "questions": request.questions}, indent=2))
        print("==================\n")
        doc_id = get_document_id_from_url(request.documents or "")
        # NOTE(review): parsing/indexing runs while holding the global lock,
        # so concurrent requests serialize on a cache miss.
        with doc_cache_lock:
            if doc_id in doc_cache:
                print("✅ Using cached document...")
                cached = doc_cache[doc_id]
                text_chunks = cached["chunks"]
                index = cached["index"]
                texts = cached["texts"]
            else:
                print("⚙️ Parsing and indexing new document...")
                pdf_start = time.time()
                text_chunks = parse_document_url(request.documents)
                timing_data['pdf_parsing'] = round(time.time() - pdf_start, 2)
                index_start = time.time()
                index, texts = build_faiss_index(text_chunks)
                timing_data['faiss_index_building'] = round(time.time() - index_start, 2)
                doc_cache[doc_id] = {
                    "chunks": text_chunks,
                    "index": index,
                    "texts": texts
                }
        # Union of top chunks across all questions forms one shared context.
        retrieval_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        timing_data['chunk_retrieval'] = round(time.time() - retrieval_start, 2)
        context_chunks = list(all_chunks)
        # Fan questions out to the LLM in fixed-size batches.
        batch_size = 10
        batches = [(i, request.questions[i:i + batch_size]) for i in range(0, len(request.questions), batch_size)]
        llm_start = time.time()
        results_dict = {}
        # max(1, ...) guards against ThreadPoolExecutor(max_workers=0), which
        # raises ValueError when the question list is empty.
        with ThreadPoolExecutor(max_workers=min(5, max(1, len(batches)))) as executor:
            futures = [executor.submit(process_batch, batch, context_chunks) for _, batch in batches]
            for (start_idx, batch), future in zip(batches, futures):
                try:
                    result = future.result()
                    if isinstance(result, dict) and "answers" in result:
                        for j, answer in enumerate(result["answers"]):
                            results_dict[start_idx + j] = answer
                    else:
                        for j in range(len(batch)):
                            results_dict[start_idx + j] = "Error in response"
                except Exception as e:
                    # One failed batch must not sink the whole request.
                    for j in range(len(batch)):
                        results_dict[start_idx + j] = f"Error: {str(e)}"
        timing_data['llm_processing'] = round(time.time() - llm_start, 2)
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        total_time = time.time() - start_time
        timing_data['total_time'] = round(total_time, 2)
        # Log to Supabase with user_agent + geo_location
        for q, a in zip(request.questions, responses):
            log_query(
                document_source=request.documents or "UNKNOWN",
                question=q,
                answer=a,
                ip_address=user_ip,
                user_agent=user_agent,
                response_time=total_time
            )
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# NOTE(review): the route decorator was missing, so this handler was never
# registered — restored; confirm the original path. No bearer-token check
# here, unlike run_query — presumably intentional for local testing.
@app.post("/hackrx/run-local")
async def run_local_query(request: LocalQueryRequest, fastapi_request: Request):
    """Answer a batch of questions about a document on the local filesystem.

    Same pipeline as run_query but reads the document from `document_path`,
    skips the document cache, and uses LLM batches of 20. Returns
    {"answers": [...]} in question order; any failure maps to HTTP 500.
    """
    start_time = time.time()
    timing_data = {}
    try:
        user_ip = get_client_ip(fastapi_request)
        user_agent = fastapi_request.headers.get("user-agent", "Unknown")
        print("=== INPUT JSON ===")
        print(json.dumps({"document_path": request.document_path, "questions": request.questions}, indent=2))
        print("==================\n")
        pdf_start = time.time()
        text_chunks = parse_document_file(request.document_path)
        timing_data['pdf_parsing'] = round(time.time() - pdf_start, 2)
        index_start = time.time()
        index, texts = build_faiss_index(text_chunks)
        timing_data['faiss_index_building'] = round(time.time() - index_start, 2)
        # Union of top chunks across all questions forms one shared context.
        retrieval_start = time.time()
        all_chunks = set()
        for question in request.questions:
            all_chunks.update(retrieve_chunks(index, texts, question))
        timing_data['chunk_retrieval'] = round(time.time() - retrieval_start, 2)
        context_chunks = list(all_chunks)
        # Fan questions out to the LLM in fixed-size batches.
        batch_size = 20
        batches = [(i, request.questions[i:i + batch_size]) for i in range(0, len(request.questions), batch_size)]
        llm_start = time.time()
        results_dict = {}
        # max(1, ...) guards against ThreadPoolExecutor(max_workers=0), which
        # raises ValueError when the question list is empty.
        with ThreadPoolExecutor(max_workers=min(5, max(1, len(batches)))) as executor:
            futures = [executor.submit(process_batch, batch, context_chunks) for _, batch in batches]
            for (start_idx, batch), future in zip(batches, futures):
                try:
                    result = future.result()
                    if isinstance(result, dict) and "answers" in result:
                        for j, answer in enumerate(result["answers"]):
                            results_dict[start_idx + j] = answer
                    else:
                        for j in range(len(batch)):
                            results_dict[start_idx + j] = "Error in response"
                except Exception as e:
                    # One failed batch must not sink the whole request.
                    for j in range(len(batch)):
                        results_dict[start_idx + j] = f"Error: {str(e)}"
        timing_data['llm_processing'] = round(time.time() - llm_start, 2)
        responses = [results_dict.get(i, "Not Found") for i in range(len(request.questions))]
        total_time = time.time() - start_time
        timing_data['total_time'] = round(total_time, 2)
        # Log to Supabase with user_agent + geo_location
        for q, a in zip(request.questions, responses):
            log_query(
                document_source=request.document_path or "UNKNOWN",
                question=q,
                answer=a,
                ip_address=user_ip,
                user_agent=user_agent,
                response_time=total_time
            )
        return {"answers": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
if __name__ == "__main__":
    # Hosting platforms (e.g. HuggingFace Spaces) inject PORT; 7860 is the
    # Spaces default. "app:app" assumes this file is named app.py.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)