Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from dotenv import load_dotenv | |
| import sys | |
| import os | |
# --------------------------------------------------------
# Fix Python module paths
# --------------------------------------------------------
# Normalise to an absolute path so the sys.path entries below never depend on
# the current working directory if __file__ happens to be relative
# (os.path.dirname("main.py") is "", which would make them CWD-relative).
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, current_dir)  # For embeddings + database
sys.path.insert(0, os.path.join(parent_dir, "api"))  # For utils

# --------------------------------------------------------
# Load environment variables
# --------------------------------------------------------
# Loaded BEFORE the project imports below, so any configuration those modules
# might read from the environment at import time (e.g. values kept in a .env
# file) is already present when they are imported.
load_dotenv()

# --------------------------------------------------------
# Imports AFTER adjusting paths (and loading the environment)
# --------------------------------------------------------
from embeddings import embed_text
from database import (
    insert_document_chunks,
    search_vectors,
    list_all_documents,
    initialize_database,
)
from utils.text_extractor import extract_text
# --------------------------------------------------------
# FastAPI App
# --------------------------------------------------------
app = FastAPI(
    title="RAG MCP Server",
    description="Provides semantic search + ingestion for tenant knowledge bases",
    version="1.0.0",
)

# --------------------------------------------------------
# Enable CORS
# --------------------------------------------------------
# NOTE(review): every origin, method and header is allowed, and credentials
# are enabled at the same time. Browsers refuse credentialed responses that
# carry a literal "*" origin, so check how the installed Starlette version
# handles this combination, and restrict allow_origins if this API ever relies
# on cookies or other credentials.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --------------------------------------------------------
# Startup Event - Initialize Database
# --------------------------------------------------------
# Registered as an application startup hook. Without this registration nothing
# in this module ever calls startup_event, so the schema would never be
# initialised despite what the docstring promises.
@app.on_event("startup")
async def startup_event():
    """Initialize the database schema when the server starts.

    Failures are printed and deliberately NOT re-raised: the server keeps
    starting, but any later request that touches the database may fail.
    """
    try:
        print("Initializing database schema...")
        # initialize_database() is a plain (synchronous) call, so it runs to
        # completion before startup finishes.
        initialize_database()
    except Exception as e:
        print(f"Warning: Database initialization failed: {e}")
        print("Server will continue, but database operations may fail.")
# --------------------------------------------------------
# Request Models
# --------------------------------------------------------
# Request body for `ingest`: raw text to be chunked, embedded and stored for
# a single tenant. (Documented with comments rather than a class docstring so
# the generated JSON schema stays unchanged.)
class IngestPayload(BaseModel):
    tenant_id: str  # tenant that will own the stored chunks
    content: str  # raw text; ingest hands it to extract_text for chunking
# Request body for `search`: a free-text query scoped to one tenant.
# (Comments instead of a class docstring keep the JSON schema unchanged.)
class SearchPayload(BaseModel):
    query: str  # text that search embeds and passes to search_vectors
    tenant_id: str  # tenant whose stored chunks are searched
# --------------------------------------------------------
# Health Check
# --------------------------------------------------------
# NOTE(review): no route decorator is attached to this handler in this file.
# Unless it is registered some other way (presumably @app.get("/")), FastAPI
# never exposes it. Confirm and register.
def root():
    """Health check: return a fixed status payload showing the server is up."""
    running_message = "RAG MCP SERVER RUNNING"
    return {"status": running_message}
# --------------------------------------------------------
# Ingest Route
# --------------------------------------------------------
# NOTE(review): no route decorator is attached to this handler in this file.
# Unless it is registered some other way, FastAPI never exposes it. Confirm
# the intended path and register it.
def ingest(payload: IngestPayload):
    """Chunk, embed and store raw text for one tenant.

    Steps:
      - split ``payload.content`` into chunks with ``extract_text``
      - embed each chunk with ``embed_text``
      - store each (tenant_id, chunk, embedding) via ``insert_document_chunks``
        (Postgres, per the original documentation)

    Returns:
        ``{"status": "ok", "tenant_id": ..., "chunks_stored": <count>}``.

    Raises:
        HTTPException: 400 when ``extract_text`` yields no chunks; 500 for any
        other failure, with the original error message as ``detail``.
    """
    try:
        chunks = extract_text(payload.content)
        if not chunks:
            raise HTTPException(400, "No text found to ingest.")
        inserted = 0
        # Each chunk is stored by a separate call, so a failure part-way
        # through may leave earlier chunks stored. Whether
        # insert_document_chunks commits per call is not visible here.
        for chunk in chunks:
            embedding = embed_text(chunk)
            insert_document_chunks(payload.tenant_id, chunk, embedding)
            inserted += 1
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must pass through
        # unchanged. Previously the generic handler below rewrapped them as a
        # 500 whose detail read "400: No text found to ingest.".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "ok",
        "tenant_id": payload.tenant_id,
        "chunks_stored": inserted,
    }
# --------------------------------------------------------
# Search Route
# --------------------------------------------------------
# NOTE(review): no route decorator is attached to this handler in this file.
# Unless it is registered some other way, FastAPI never exposes it. Confirm
# the intended path and register it.
def search(payload: SearchPayload):
    """Run a semantic search for ``payload.query`` within one tenant.

    The query text is embedded with ``embed_text``. The vector is then passed
    to ``search_vectors`` together with the tenant id and ``limit=5``. The
    original documentation describes this as pgvector + MiniLM embeddings.

    Any exception from either step is reported as HTTP 500 with the error
    message as ``detail``.
    """
    try:
        query_vector = embed_text(payload.query)
        matches = search_vectors(payload.tenant_id, query_vector, limit=5)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "tenant_id": payload.tenant_id,
        "query": payload.query,
        "results": matches,
    }
# --------------------------------------------------------
# List All Documents Route
# --------------------------------------------------------
# NOTE(review): no route decorator is attached to this handler in this file.
# Unless it is registered some other way, FastAPI never exposes it. Confirm
# the intended path and register it.
def list_documents(tenant_id: str, limit: int = 1000, offset: int = 0):
    """Return one page of a tenant's stored documents.

    ``limit`` and ``offset`` are forwarded unchanged to ``list_all_documents``,
    and its result is returned as-is. Any exception is reported as HTTP 500
    with the error message as ``detail``.
    """
    try:
        return list_all_documents(tenant_id, limit=limit, offset=offset)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# --------------------------------------------------------
# Allow "python main.py" to start server
# --------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    _banner = (
        "Starting RAG MCP Server on http://0.0.0.0:8001",
        "API Documentation: http://localhost:8001/docs",
        "Note: Reload mode disabled when running directly",
    )
    for _line in _banner:
        print(_line)

    # uvicorn can only auto-reload when given an import string such as
    # "main:app". Because the app object itself is passed here, reload has to
    # stay off.
    uvicorn.run(app, host="0.0.0.0", port=8001, reload=False)