# Spaces: Running
| import asyncio | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import uuid | |
| import uvicorn | |
| from typing import List, Optional | |
| import traceback | |
| # Import your existing modules | |
| from database import db | |
| from models import ChatMessage, ChatRequest, ChatResponse, Product, SearchRequest, Conversation, KnowledgeDocument, Document, SourceInfo | |
| from config import settings | |
| from rag_system import rag_pipeline | |
# The ASGI application object; the metadata below is passed straight to the
# FastAPI constructor.
app = FastAPI(
    version="1.0.0",
    title="RAG Chatbot API",
    description="Lightweight RAG Chatbot using MongoDB Atlas and Gemini",
)

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True; confirm whether credentialed
# cross-origin requests are really needed, or list the origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Module-level flag, initialised to False. Nothing in this section of the
# file reads or updates it.
embeddings_generated = False
async def startup_event():
    """Connect to the database and report how many documents have embeddings.

    No embeddings are generated here: the function only opens the connection
    via ``db.connect()``, fetches ``db.get_collection_stats()`` and prints a
    status message that depends on the ``documents_with_embeddings`` count.

    NOTE(review): no registration (decorator or lifespan hook) is visible in
    this section; confirm how this coroutine is wired to application startup.

    Raises:
        Exception: Any failure during the steps above is printed and then
            re-raised unchanged.
    """
    try:
        print("π Starting RAG Chatbot API...")

        # Open the connection first; the stats query below depends on it.
        await db.connect()

        # Read-only status check.
        stats = await db.get_collection_stats()
        print(f"π Database status: {stats}")

        embedded_count = stats["documents_with_embeddings"]
        if embedded_count != 0:
            print(f"β Ready! Using {stats['documents_with_embeddings']} documents with embeddings from MongoDB Atlas")
        else:
            print("β οΈ No embeddings found in database. Please pre-compute embeddings separately.")
            print("π‘ Run the embedding generation script locally and upload to MongoDB Atlas.")

        print("β RAG Chatbot API is ready!")
    except Exception as e:
        print(f"β Startup error: {e}")
        raise
async def root():
    """Return a fixed payload with a greeting message and a healthy status.

    NOTE(review): no route decorator is visible on this handler in this
    section; confirm which path it is registered under.
    """
    return dict(message="RAG Chatbot API is running!", status="healthy")
async def health_check():
    """Return a fixed payload naming the service and reporting it healthy.

    The payload is constant; no dependency (database, pipeline) is probed.

    NOTE(review): no route decorator is visible on this handler in this
    section; confirm which path it is registered under.
    """
    payload = {"status": "healthy", "service": "rag-chatbot"}
    return payload
def _source_info_from(product):
    """Build a ``SourceInfo`` from one source dict returned by the pipeline."""
    # "metadata" may be absent or explicitly None; treat both as empty so a
    # single malformed source does not fail the whole request.
    metadata = product.get("metadata") or {}
    return SourceInfo(
        id=product.get("id", ""),
        name=product.get("source", "Product"),
        category=metadata.get("category", "N/A"),
        price=str(metadata.get("price", "N/A")),
        similarity_score=metadata.get("similarity_score", 0),
    )


async def chat_with_assistant(request: ChatRequest):
    """Answer a chat message through the RAG pipeline.

    The message is passed to ``rag_pipeline.generate_response``. The sources
    it returns are used to build follow-up questions and are converted to
    ``SourceInfo`` entries for the response.

    NOTE(review): no route decorator is visible on this handler in this
    section; the log message below suggests it serves ``/chat`` - confirm
    the registration.

    Args:
        request: Incoming chat payload; ``message`` and ``conversation_id``
            are the fields read here.

    Returns:
        A ``ChatResponse`` with the answer text, the converted sources, the
        suggested follow-up questions and the caller's ``conversation_id``
        passed through unchanged.

    Raises:
        HTTPException: Status 500 if any step fails. The traceback is
            printed and the original exception is chained as the cause.
    """
    try:
        print(f"π¬ Received chat request: {request.message}")
        response, sources = await rag_pipeline.generate_response(request.message)
        suggested_questions = rag_pipeline.generate_followup_questions(
            request.message,
            sources,
        )
        source_objects = [_source_info_from(product) for product in sources]
        return ChatResponse(
            response=response,
            sources=source_objects,
            suggested_questions=suggested_questions,
            conversation_id=request.conversation_id,
        )
    except Exception as e:
        print(f"β Error in /chat endpoint: {traceback.format_exc()}")
        # Chain the cause so the original failure stays attached to the 500.
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
async def debug_database_stats():
    """Report collection statistics plus a small sample of embedded documents.

    Fetches ``db.get_collection_stats()`` and summarises up to three
    documents that have an ``embedding`` field (id, title, category,
    embedding length and a short content preview).

    NOTE(review): no route decorator is visible on this handler in this
    section; confirm which path it is registered under.

    Returns:
        A dict with the configured database and collection names, the
        statistics and the sampled documents. On any exception, a dict with
        a single ``"error"`` key holding the exception text is returned
        instead of raising.
    """
    try:
        stats = await db.get_collection_stats()

        # Only documents that carry an "embedding" field, capped at three.
        embedded_cursor = db.collection.find({"embedding": {"$exists": True}}).limit(3)
        sample_docs = [
            {
                "id": str(doc["_id"]),
                "title": doc.get("title", "N/A"),
                "category": doc.get("category", "N/A"),
                "has_embedding": "embedding" in doc,
                "embedding_length": len(doc.get("embedding", [])),
                "content_preview": f"{doc.get('title', '')} - {doc.get('product_description', '')[:50]}...",
            }
            async for doc in embedded_cursor
        ]

        return {
            "database": settings.DATABASE_NAME,
            "collection": settings.COLLECTION_NAME,
            "statistics": stats,
            "sample_documents_with_embeddings": sample_docs,
        }
    except Exception as e:
        return {"error": str(e)}
# Script entry point: serve the app with uvicorn on all interfaces, port 7860,
# with auto-reload enabled.
# NOTE(review): the "main:app" import string presumes this module is importable
# as ``main`` - confirm the file name.
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        port=7860,
        host="0.0.0.0",
        reload=True,
    )