File size: 6,335 Bytes
15a08d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from langchain_pinecone import PineconeVectorStore
from src.config import Config
from src.helper import download_embeddings
from src.utility import QueryClassifier, StreamingHandler
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from src.prompt import system_prompt
import uuid


# Fail fast if required configuration (API keys, index name, ...) is missing,
# before any external client is constructed.
Config.validate()


PINECONE_API_KEY = Config.PINECONE_API_KEY
GEMINI_API_KEY = Config.GEMINI_API_KEY

# Jinja2 templates for the HTML chat UI (served from ./templates).
templates = Jinja2Templates(directory="templates")

# Initialize FastAPI app
app = FastAPI(title="Medical Chatbot", version="0.0.0")

# Store for session-based chat histories, keyed by session id.
# In-memory only — resets on server restart.
chat_histories: dict = {}

# Initialize the embedding model used for query-time vector search.
print("Loading the Embedding model...")
embeddings = download_embeddings()

# Connect to the existing Pinecone index (assumed already populated).
index_name = Config.PINECONE_INDEX_NAME
print(f"Connecting to PineCone index: {index_name}")
docsearch = PineconeVectorStore.from_existing_index(
    index_name=index_name, embedding=embeddings
)

# Create a retriever over the vector store; RETRIEVAL_K controls how many
# documents are fetched per query.
retriever = docsearch.as_retriever(
    search_type=Config.SEARCH_TYPE, search_kwargs={"k": Config.RETRIEVAL_K}
)

# Initialize Google Gemini chat model
print("Initializing Gemini model...")
llm = ChatGoogleGenerativeAI(
    model=Config.GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=Config.LLM_TEMPERATURE,
    # Gemini models may reject standalone system messages; fold the system
    # prompt into the human turn instead.
    convert_system_message_to_human=True,
)

# Chat prompt layout: system instructions, prior turns, then the new input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)

# Chain that stuffs the retrieved documents into the prompt for the LLM.
question_answer_chain = create_stuff_documents_chain(llm, prompt)

# Full RAG chain: retrieve relevant documents, then answer from them.
rag_chain = create_retrieval_chain(retriever, question_answer_chain)


# Function to get chat history for a session
# Function to get chat history for a session
def get_chat_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating one on first use."""
    try:
        return chat_histories[session_id]
    except KeyError:
        history = ChatMessageHistory()
        chat_histories[session_id] = history
        return history


# Function to maintain conversation window buffer (keep last 5 messages)
# Function to maintain conversation window buffer (keep last 5 messages)
def manage_memory_window(session_id: str, max_messages: int = 10):
    """Keep only the last max_messages (5 pairs = 10 messages)"""
    history = chat_histories.get(session_id)
    if history is None:
        return
    excess = len(history.messages) - max_messages
    if excess > 0:
        # Drop the oldest messages, retaining the most recent max_messages.
        history.messages = history.messages[excess:]


# Startup banner — typos fixed ("Intialized Medical Chabot successfuly").
print("Initialized Medical Chatbot successfully!")
print("Vector Store connected")


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chatbot interface with a fresh session id.

    Args:
        request: Incoming request (required by Jinja2Templates).

    Returns:
        The rendered index.html with a unique session_id for this page load.
    """
    # Bound memory growth WITHOUT destroying every concurrent user's
    # conversation: the previous behavior cleared ALL histories on each page
    # load, so one visitor opening the page wiped other active sessions.
    # Only reset once the store grows past a generous cap.
    if len(chat_histories) > 100:
        chat_histories.clear()

    # A new session id per page load keeps conversations isolated per tab.
    session_id = str(uuid.uuid4())
    return templates.TemplateResponse(
        "index.html", {"request": request, "session_id": session_id}
    )


@app.post("/get")
async def chat(msg: str = Form(...), session_id: str = Form(...)):
    """Handle chat messages and return streaming AI responses with conversation memory.

    Args:
        msg: The user's message (form field).
        session_id: Session identifier issued by the index page (form field).

    Returns:
        A text/event-stream StreamingResponse of SSE chunks.
    """
    # Hoisted: previously `import json` ran inside the per-chunk loop.
    import json

    # Get chat history for this session
    history = get_chat_history(session_id)

    # Classify query to decide between RAG retrieval and a canned reply.
    needs_retrieval, reason = QueryClassifier.needs_retrieval(msg)

    async def generate_response():
        """Yield SSE chunks, then record the exchange in the session history."""
        full_answer = ""

        try:
            if needs_retrieval:
                # Stream RAG chain response for medical queries
                print(f"✓ [RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")

                async for chunk in StreamingHandler.stream_rag_response(
                    rag_chain, {"input": msg, "chat_history": history.messages}
                ):
                    yield chunk
                    # The terminal SSE chunk carries the assembled answer;
                    # capture it so it can be stored in the history below.
                    # (Substring test directly on the str — no encode() needed.)
                    if '"done": true' in chunk:
                        data = json.loads(chunk.replace("data: ", "").strip())
                        if "full_answer" in data:
                            full_answer = data["full_answer"]
            else:
                # Stream simple response for greetings/acknowledgments
                print(f"[NO RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
                simple_resp = QueryClassifier.get_simple_response(msg)
                full_answer = simple_resp

                async for chunk in StreamingHandler.stream_simple_response(simple_resp):
                    yield chunk

            # Persist the exchange only after streaming completes so a failed
            # stream does not pollute the conversation memory.
            history.add_user_message(msg)
            history.add_ai_message(full_answer)

            # Manage memory window
            manage_memory_window(session_id, max_messages=10)

        except Exception as e:
            # Surface a generic error to the client; details go to the server log.
            print(f"Error during streaming: {str(e)}")
            yield f"data: {json.dumps({'error': 'An error occurred', 'done': True})}\n\n"

    # SSE headers: disable caching and proxy buffering so chunks flush promptly.
    return StreamingResponse(
        generate_response(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


if __name__ == "__main__":
    import os

    import uvicorn

    # Platforms inject PORT (7860 on HF Spaces, 8080 on Render);
    # default to 7860 for local runs.
    server_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=server_port)