| | from fastapi import FastAPI, HTTPException, File, UploadFile, Form |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel |
| | from typing import Optional, List, Dict |
| | from pymongo import MongoClient |
| | from datetime import datetime |
| | import numpy as np |
| | import os |
| | from huggingface_hub import InferenceClient |
| |
|
| | from embedding_service import JinaClipEmbeddingService |
| | from qdrant_service import QdrantVectorService |
| |
|
| |
|
| | |
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoint."""

    # Text of the user's message.
    message: str
    # When true, context is retrieved from the vector store before generating.
    use_rag: bool = True
    # Number of context documents to retrieve.
    top_k: int = 3
    # System prompt handed to the LLM.
    system_message: Optional[str] = "You are a helpful AI assistant."
    # Generation parameters forwarded to the LLM call.
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    # Optional per-request Hugging Face token.
    hf_token: Optional[str] = None
| |
|
| |
|
class ChatResponse(BaseModel):
    """Response body returned by the chat endpoint."""

    # Generated answer text.
    response: str
    # Context documents that were retrieved for this request.
    context_used: List[Dict]
    # Timestamp of the response, as a string.
    timestamp: str
| |
|
| |
|
class AddDocumentRequest(BaseModel):
    """Request body for adding a document to the knowledge base."""

    # Document text.
    text: str
    # Optional additional metadata stored with the document.
    metadata: Optional[Dict] = None
| |
|
| |
|
class AddDocumentResponse(BaseModel):
    """Response body returned after a document has been added."""

    # Whether the document was added.
    success: bool
    # Identifier of the stored document.
    doc_id: str
    # Human-readable status message.
    message: str
| |
|
| |
|
class SearchRequest(BaseModel):
    """Request body for searching the knowledge base."""

    # Search query text.
    query: str
    # Maximum number of results to return.
    top_k: int = 5
    # Minimum score a result must reach.
    score_threshold: Optional[float] = 0.5
| |
|
| |
|
class SearchResponse(BaseModel):
    """Response body returned by the search endpoint."""

    # Matching documents.
    results: List[Dict]
| |
|
| |
|
| | |
# FastAPI application instance with its title, description and version.
app = FastAPI(
    title="ChatbotRAG API",
    description="API for RAG Chatbot with GPT-OSS-20B + Jina CLIP v2 + MongoDB + Qdrant",
    version="1.0.0",
)
| |
|
| | |
# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; consider listing explicit
# origins if credentialed requests are really needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| |
|
| | |
class ChatbotRAGService:
    """Service layer behind the ChatbotRAG API.

    Combines MongoDB (documents and chat history), the Jina CLIP embedding
    service, the Qdrant vector service and the Hugging Face Inference API.
    """

    def __init__(
        self,
        # NOTE(security): this default embeds database credentials in source
        # code. It is kept only so existing callers keep working; the
        # credential should be rotated and supplied through configuration.
        mongodb_uri: str = "mongodb+srv://truongtn7122003:7KaI9OT5KTUxWjVI@truongtn7122003.xogin4q.mongodb.net/",
        db_name: str = "chatbot_rag",
        collection_name: str = "documents",
        hf_token: Optional[str] = None
    ):
        """Connect to the backing stores and set up the embedding service.

        Args:
            mongodb_uri: MongoDB connection string.
            db_name: Name of the MongoDB database.
            collection_name: MongoDB collection that stores documents.
            hf_token: Hugging Face token; when omitted, the
                HUGGINGFACE_TOKEN environment variable is used.
        """
        print("Initializing ChatbotRAG Service...")

        # MongoDB: documents plus a fixed "chat_history" collection.
        self.mongo_client = MongoClient(mongodb_uri)
        self.db = self.mongo_client[db_name]
        self.documents_collection = self.db[collection_name]
        self.chat_history_collection = self.db["chat_history"]

        self.embedding_service = JinaClipEmbeddingService(
            model_path="jinaai/jina-clip-v2"
        )

        # The Qdrant collection name comes from the environment, not from the
        # `collection_name` parameter (which names the MongoDB collection), so
        # it gets its own local name instead of rebinding the parameter.
        qdrant_collection = os.getenv("COLLECTION_NAME", "event_social_media")
        self.qdrant_service = QdrantVectorService(
            collection_name=qdrant_collection,
            vector_size=self.embedding_service.get_embedding_dimension()
        )

        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN")
        if self.hf_token:
            print("✓ Hugging Face token configured")
        else:
            print("⚠ No Hugging Face token - LLM generation will use placeholder")

        print("✓ ChatbotRAG Service initialized")

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> str:
        """Store a document in MongoDB and index its embedding in Qdrant.

        If embedding or indexing fails, the MongoDB document inserted by this
        call is removed again so the two stores do not drift apart, and the
        original exception is re-raised.

        Args:
            text: Document text.
            metadata: Optional extra metadata stored with the document.

        Returns:
            The MongoDB document ID as a string.
        """
        extra = metadata or {}

        inserted = self.documents_collection.insert_one({
            "text": text,
            "metadata": extra,
            "created_at": datetime.utcnow()
        })
        doc_id = str(inserted.inserted_id)

        try:
            embedding = self.embedding_service.encode_text(text)
            # NOTE(review): keys named "text" or "source" in the caller's
            # metadata override the values set here (later keys win).
            self.qdrant_service.index_data(
                doc_id=doc_id,
                embedding=embedding,
                metadata={
                    "text": text,
                    "source": "api",
                    **extra
                }
            )
        except Exception:
            # Roll back the MongoDB insert; otherwise the document would
            # exist without a searchable vector.
            self.documents_collection.delete_one({"_id": inserted.inserted_id})
            raise

        return doc_id

    def retrieve_context(self, query: str, top_k: int = 3, score_threshold: float = 0.5) -> List[Dict]:
        """Embed the query and search the vector store.

        Args:
            query: Query text.
            top_k: Maximum number of results.
            score_threshold: Minimum score passed to the vector search.

        Returns:
            Whatever the Qdrant service's search returns for the query.
        """
        query_embedding = self.embedding_service.encode_text(query)
        return self.qdrant_service.search(
            query_embedding=query_embedding,
            limit=top_k,
            score_threshold=score_threshold
        )

    def generate_response(
        self,
        message: str,
        context: List[Dict],
        system_message: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
        hf_token: Optional[str] = None
    ) -> str:
        """Generate an answer with the Hugging Face LLM.

        Args:
            message: User message.
            context: Retrieved documents; each must provide a "metadata"
                dict and a numeric "confidence".
            system_message: System prompt; None falls back to the default
                assistant prompt.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            hf_token: Per-call token; falls back to the service token.

        Returns:
            The generated text, a placeholder when no token is available,
            or an error description when the LLM call fails.
        """
        if system_message is None:
            # Avoid rendering the literal text "None" into the prompt.
            system_message = "You are a helpful AI assistant."

        # Only mention "the above context" when some context actually exists.
        if context:
            sections = []
            for i, doc in enumerate(context, 1):
                doc_text = doc["metadata"].get("text", "")
                confidence = doc["confidence"]
                sections.append(f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n")
            context_text = "\n\nRelevant Context:\n" + "".join(sections)
            system_message = f"{system_message}\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."

        token = hf_token or self.hf_token

        if not token:
            return f"""[LLM Response Placeholder]

Context retrieved: {len(context)} documents
User question: {message}

To enable actual LLM generation:
1. Set HUGGINGFACE_TOKEN environment variable, OR
2. Pass hf_token in request body

Example:
{{
  "message": "Your question",
  "hf_token": "hf_xxxxxxxxxxxxx"
}}
"""

        try:
            client = InferenceClient(
                token=token,
                model="openai/gpt-oss-20b"
            )

            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": message}
            ]

            # Collect the streamed deltas and join them once at the end.
            parts = []
            for chunk in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                choices = chunk.choices
                if choices and choices[0].delta.content:
                    parts.append(choices[0].delta.content)

            return "".join(parts)

        except Exception as e:
            # Best effort: report the failure as the response text.
            return f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."

    def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
        """Persist one chat exchange to the chat history collection."""
        self.chat_history_collection.insert_one({
            "user_message": user_message,
            "assistant_response": assistant_response,
            "context_used": context_used,
            "timestamp": datetime.utcnow()
        })

    def get_stats(self) -> Dict:
        """Return document and chat counts plus Qdrant collection info."""
        return {
            "documents_count": self.documents_collection.count_documents({}),
            "chat_history_count": self.chat_history_collection.count_documents({}),
            "qdrant_info": self.qdrant_service.get_collection_info()
        }
| |
|
| |
|
| | |
# Module-level service instance used by the endpoints; it is created with
# the constructor defaults when the module is imported.
rag_service = ChatbotRAGService()
| |
|
| |
|
| | |
| |
|
@app.get("/")
async def root():
    """Report service status and list the main endpoints."""
    endpoints = {
        "POST /chat": "Chat with RAG",
        "POST /documents": "Add document to knowledge base",
        "POST /search": "Search in knowledge base",
        "GET /stats": "Get statistics",
        "GET /history": "Get chat history",
    }
    return {
        "status": "running",
        "service": "ChatbotRAG API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
| |
|
| |
|
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """
    Answer a user message, optionally using retrieved context.

    Body:
    - message: User message
    - use_rag: Enable RAG retrieval (default: true)
    - top_k: Number of documents to retrieve (default: 3)
    - system_message: System prompt (optional)
    - max_tokens: Max tokens for response (default: 512)
    - temperature: Temperature for generation (default: 0.7)
    - top_p: Nucleus sampling parameter (default: 0.95)
    - hf_token: Hugging Face token for this request (optional)

    Returns:
    - response: Generated response
    - context_used: Retrieved context documents
    - timestamp: Response timestamp
    """
    try:
        # Retrieval is skipped entirely when RAG is disabled.
        retrieved = (
            rag_service.retrieve_context(
                query=request.message,
                top_k=request.top_k,
            )
            if request.use_rag
            else []
        )

        answer = rag_service.generate_response(
            message=request.message,
            context=retrieved,
            system_message=request.system_message,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            hf_token=request.hf_token,
        )

        # Record the exchange before replying.
        rag_service.save_chat_history(
            user_message=request.message,
            assistant_response=answer,
            context_used=retrieved,
        )

        return ChatResponse(
            response=answer,
            context_used=retrieved,
            timestamp=datetime.utcnow().isoformat(),
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error: {str(exc)}")
| |
|
| |
|
@app.post("/documents", response_model=AddDocumentResponse)
async def add_document(request: AddDocumentRequest):
    """
    Add a document to the knowledge base.

    Body:
    - text: Document text
    - metadata: Additional metadata (optional)

    Returns:
    - success: True/False
    - doc_id: MongoDB document ID
    - message: Status message
    """
    try:
        new_id = rag_service.add_document(
            text=request.text,
            metadata=request.metadata,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error: {str(exc)}")

    return AddDocumentResponse(
        success=True,
        doc_id=new_id,
        message=f"Document added successfully with ID: {new_id}",
    )
| |
|
| |
|
@app.post("/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """
    Search the knowledge base.

    Body:
    - query: Search query
    - top_k: Number of results (default: 5)
    - score_threshold: Minimum score (default: 0.5)

    Returns:
    - results: List of matching documents
    """
    try:
        matches = rag_service.retrieve_context(
            query=request.query,
            top_k=request.top_k,
            score_threshold=request.score_threshold,
        )
        return SearchResponse(results=matches)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error: {str(exc)}")
| |
|
| |
|
@app.get("/stats")
async def get_stats():
    """
    Return service statistics.

    Returns:
    - documents_count: Number of documents in MongoDB
    - chat_history_count: Number of chat messages
    - qdrant_info: Qdrant collection info
    """
    try:
        stats = rag_service.get_stats()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error: {str(exc)}")
    return stats
| |
|
| |
|
@app.get("/history")
async def get_history(limit: int = 10, skip: int = 0):
    """
    Get chat history, newest first.

    Query params:
    - limit: Number of messages to return (default: 10)
    - skip: Number of messages to skip (default: 0)

    Returns:
    - history: List of chat messages
    - total: Total number of stored chat messages

    Raises:
    - 400 if limit or skip is negative
    - 500 on any other failure
    """
    # Validate outside the try block so the 400 is not rewrapped as a 500.
    # A negative skip would otherwise be rejected by the driver and surface
    # as an internal error.
    if limit < 0 or skip < 0:
        raise HTTPException(
            status_code=400,
            detail="limit and skip must be non-negative",
        )

    try:
        collection = rag_service.chat_history_collection
        history = list(
            collection
            .find({}, {"_id": 0})
            .sort("timestamp", -1)
            .skip(skip)
            .limit(limit)
        )

        # Convert datetimes to ISO strings; values of any other type are
        # left as stored instead of failing the whole request.
        for msg in history:
            stamp = msg.get("timestamp")
            if isinstance(stamp, datetime):
                msg["timestamp"] = stamp.isoformat()

        return {"history": history, "total": collection.count_documents({})}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
| |
|
| |
|
@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str):
    """
    Delete a document from the knowledge base.

    Args:
    - doc_id: Document ID (MongoDB ObjectId, as a string)

    Returns:
    - success: True/False
    - message: Status message

    Raises:
    - 404 if no document with this ID exists
    - 500 on any other failure
    """
    # bson ships with pymongo, which this module already depends on.
    from bson import ObjectId

    try:
        # Documents are inserted with driver-generated ObjectId `_id`s, so a
        # filter on the raw string never matches them. Match the ObjectId form
        # when the string is a valid one, and still accept a plain string `_id`.
        id_candidates = [doc_id]
        if ObjectId.is_valid(doc_id):
            id_candidates.append(ObjectId(doc_id))

        result = rag_service.documents_collection.delete_one(
            {"_id": {"$in": id_candidates}}
        )

        if result.deleted_count == 0:
            raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")

        # Remove the vector only after the MongoDB document is gone.
        rag_service.qdrant_service.delete_by_id(doc_id)
        return {"success": True, "message": f"Document {doc_id} deleted"}

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
| |
|
| |
|
if __name__ == "__main__":
    import uvicorn

    # When run as a script, serve the app on all interfaces at port 8000.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)
| |
|