# Medical Chatbot — FastAPI application entry point (RAG over Pinecone + Gemini).
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from langchain_pinecone import PineconeVectorStore
from src.config import Config
from src.helper import download_embeddings
from src.utility import QueryClassifier, StreamingHandler
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from src.prompt import system_prompt
import uuid
# Fail fast at startup if required environment variables are missing.
Config.validate()
PINECONE_API_KEY = Config.PINECONE_API_KEY
GEMINI_API_KEY = Config.GEMINI_API_KEY
# Jinja2 templates for the chat UI, served from the local "templates" directory.
templates = Jinja2Templates(directory="templates")
# Initialize FastAPI app
app = FastAPI(title="Medical Chatbot", version="0.0.0")
# In-memory store of per-session chat histories: session_id -> ChatMessageHistory.
# Not persistent — resets on server restart (and is cleared on each page load).
chat_histories = {}
# Initialize the embedding model used to embed queries for vector search.
print("Loading the Embedding model...")
embeddings = download_embeddings()
# Connect to an existing Pinecone index (assumed to be already populated).
index_name = Config.PINECONE_INDEX_NAME
print(f"Connecting to PineCone index: {index_name}")
docsearch = PineconeVectorStore.from_existing_index(
    index_name=index_name, embedding=embeddings
)
# Build a retriever over the vector store; search strategy and the number of
# returned documents come from config.
retriever = docsearch.as_retriever(
    search_type=Config.SEARCH_TYPE, search_kwargs={"k": Config.RETRIEVAL_K}
)
# Initialize Google Gemini chat model.
# NOTE(review): convert_system_message_to_human folds the system prompt into a
# human message — presumably for Gemini variants that reject system messages;
# confirm it is still needed for the configured model.
print("Initializing Gemini model...")
llm = ChatGoogleGenerativeAI(
    model=Config.GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=Config.LLM_TEMPERATURE,
    convert_system_message_to_human=True,
)
# Prompt template: system instructions, then prior conversation, then the
# current user input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)
# Chain that stuffs retrieved documents into the prompt and calls the LLM.
question_answer_chain = create_stuff_documents_chain(llm, prompt)
# Full RAG chain: retrieve relevant documents, then answer using them.
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
# Lazily create per-session conversation memory.
def get_chat_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating an empty one if absent."""
    history = chat_histories.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        chat_histories[session_id] = history
    return history
# Conversation window buffer: cap stored history per session.
def manage_memory_window(session_id: str, max_messages: int = 10):
    """Trim a session's history to its most recent *max_messages* entries.

    The default of 10 keeps the last 5 user/AI exchange pairs. Unknown
    session ids are ignored.
    """
    history = chat_histories.get(session_id)
    if history is not None and len(history.messages) > max_messages:
        # Rebind to a fresh list holding only the newest messages.
        history.messages = history.messages[-max_messages:]
# Startup banner: confirms model and vector-store initialization completed.
# Fixed typos in the original messages ("Intialized", "Chabot", "successfuly").
print("Initialized Medical Chatbot successfully!")
print("Vector Store connected")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the chat UI with a fresh session identifier."""
    # Drop every stored conversation so memory cannot grow unbounded.
    # NOTE(review): this also wipes the sessions of any other concurrent
    # users — confirm single-user deployment is the intended scenario.
    chat_histories.clear()
    # Each page load starts a brand-new conversation.
    new_session = str(uuid.uuid4())
    context = {"request": request, "session_id": new_session}
    return templates.TemplateResponse("index.html", context)
@app.post("/get")
async def chat(msg: str = Form(...), session_id: str = Form(...)):
    """Handle a chat message and stream the AI response as server-sent events.

    Form fields:
        msg: the user's message.
        session_id: identifies the conversation whose memory is used.

    Returns a StreamingResponse of SSE frames; the final frame carries
    {"done": true} and, for retrieval answers, the assembled full answer.
    """
    # Single import instead of the two duplicated inline `import json` calls.
    import json

    # Per-session conversation memory.
    history = get_chat_history(session_id)
    # Decide whether the query needs document retrieval (medical question)
    # or a canned reply (greeting / acknowledgement).
    needs_retrieval, reason = QueryClassifier.needs_retrieval(msg)

    async def generate_response():
        """Yield SSE chunks, then persist the exchange into session memory."""
        full_answer = ""
        try:
            if needs_retrieval:
                # Stream RAG chain response for medical queries.
                print(f"✓ [RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
                async for chunk in StreamingHandler.stream_rag_response(
                    rag_chain, {"input": msg, "chat_history": history.messages}
                ):
                    yield chunk
                    # The final SSE frame carries the assembled answer.
                    # (Substring test on the str directly — no bytes round-trip.)
                    if '"done": true' in chunk:
                        try:
                            # removeprefix strips only the leading "data: " —
                            # str.replace would also mangle any "data: "
                            # occurring inside the answer text.
                            data = json.loads(chunk.removeprefix("data: ").strip())
                        except json.JSONDecodeError:
                            # Malformed final frame: keep what we streamed and
                            # do not emit a spurious error after success.
                            print("Warning: could not parse final stream frame")
                        else:
                            full_answer = data.get("full_answer", full_answer)
            else:
                # Stream a simple canned response for greetings/acknowledgments.
                print(f"[NO RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
                simple_resp = QueryClassifier.get_simple_response(msg)
                full_answer = simple_resp
                async for chunk in StreamingHandler.stream_simple_response(simple_resp):
                    yield chunk
            # Record the exchange only after streaming completes.
            history.add_user_message(msg)
            history.add_ai_message(full_answer)
            # Cap stored history at the last 5 exchange pairs.
            manage_memory_window(session_id, max_messages=10)
        except Exception as e:
            # Surface a generic error frame to the client; details go to logs.
            print(f"Error during streaming: {str(e)}")
            yield f"data: {json.dumps({'error': 'An error occurred', 'done': True})}\n\n"

    return StreamingResponse(
        generate_response(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy buffering (nginx) so chunks reach the client live.
            "X-Accel-Buffering": "no",
        },
    )
if __name__ == "__main__":
    import os

    import uvicorn

    # The host environment supplies PORT (7860 on HF Spaces, 8080 on Render);
    # default to 7860 for local runs.
    serve_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)