AmaanP314 commited on
Commit
9852f5c
·
verified ·
1 Parent(s): f13ed75

initial commit

Browse files
Files changed (4) hide show
  1. Dockerfile +16 -0
  2. app.py +29 -0
  3. chatbot.py +150 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+ WORKDIR /code
3
+ COPY ./requirements.txt /code/requirements.txt
4
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
5
+
6
+ RUN useradd user
7
+ USER user
8
+
9
+ ENV HOME=/home/user \
10
+ PATH=/home/user/.local/bin:$PATH
11
+
12
+ WORKDIR $HOME/app
13
+
14
+ COPY --chown=user . $HOME/app
15
+
16
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from chatbot import chat
5
+
6
+ app = FastAPI()
7
+ app.add_middleware(
8
+ CORSMiddleware,
9
+ allow_origins=["*"],
10
+ allow_credentials=True,
11
+ allow_methods=["*"],
12
+ allow_headers=["*"],
13
+ )
14
+
15
+ class ChatRequest(BaseModel):
16
+ message: str
17
+ session_id: str
18
+
19
+ @app.post("/chat")
20
+ async def chat_endpoint(chat_req: ChatRequest):
21
+ message = chat_req.message.strip()
22
+ session_id = chat_req.session_id.strip()
23
+ if not message or not session_id:
24
+ return {"response": "Both message and session_id are required."}
25
+ try:
26
+ response = chat(message, session_id)
27
+ return {"response": response}
28
+ except Exception as e:
29
+ return {"response": f"Error: {str(e)}"}
chatbot.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
3
+ from pinecone import Pinecone
4
+ from pinecone_text.sparse import BM25Encoder
5
+ from typing import List, Dict, Any, Optional
6
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
7
+ from langchain_core.documents import Document
8
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
9
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
10
+ from langchain.chains import create_retrieval_chain
11
+ from langchain_core.messages import BaseMessage
12
+ from langchain_core.runnables import RunnableLambda
13
+ from langchain_community.chat_message_histories import ChatMessageHistory
14
+ from langchain_core.chat_history import BaseChatMessageHistory
15
+ from langchain_core.runnables.history import RunnableWithMessageHistory
16
+ from langchain.chains import create_history_aware_retriever
17
+ from langchain.chains.combine_documents import create_stuff_documents_chain
18
+ from dotenv import load_dotenv
19
+ load_dotenv()
20
+
21
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
22
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
23
+ embed_model = os.getenv("EMBEDDING_MODEL")
24
+ llm_model = os.getenv("LLM_MODEL")
25
+
26
+ embeddings = GoogleGenerativeAIEmbeddings(
27
+ google_api_key=GOOGLE_API_KEY,
28
+ model=embed_model
29
+ )
30
+ bm25_encoder = BM25Encoder().default()
31
+
32
+ index_name = "personal-assistant"
33
+ pc = Pinecone(api_key=PINECONE_API_KEY)
34
+ index = pc.Index(index_name)
35
+
36
+ class SafeHybridSearchRetriever(PineconeHybridSearchRetriever):
37
+ def _get_relevant_documents(
38
+ self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
39
+ ) -> List[Document]:
40
+ """Get documents relevant to the query using hybrid search with fallback to dense-only."""
41
+ try:
42
+ # Try hybrid search first
43
+ return super()._get_relevant_documents(query, run_manager=run_manager)
44
+ except Exception as e:
45
+ # If sparse encoding fails, fall back to dense-only search
46
+ if "Sparse vector must contain at least one value" in str(e):
47
+ print("Falling back to dense-only search for query:", query)
48
+ # Generate dense embeddings
49
+ embedding = self.embeddings.embed_query(query)
50
+ # Search with only dense vectors
51
+ results = self.index.query(
52
+ vector=embedding,
53
+ top_k=self.top_k,
54
+ include_metadata=True,
55
+ namespace=self.namespace,
56
+ )
57
+ # Convert Pinecone results to LangChain documents
58
+ return self._process_pinecone_results(results)
59
+ else:
60
+ # If it's a different error, re-raise it
61
+ raise e
62
+
63
+ def _process_pinecone_results(self, results):
64
+ """Process Pinecone results into Document objects."""
65
+ docs = []
66
+ for result in results.matches:
67
+ metadata = result.metadata or {}
68
+ # Create Document with page content and metadata
69
+ doc = Document(
70
+ page_content=metadata.pop("text", ""),
71
+ metadata=metadata,
72
+ )
73
+ docs.append(doc)
74
+ return docs
75
+
76
+ retriever = SafeHybridSearchRetriever(
77
+ embeddings=embeddings,
78
+ sparse_encoder=bm25_encoder,
79
+ index=index,
80
+ top_k=3
81
+ )
82
+
83
+ llm = ChatGoogleGenerativeAI(
84
+ model=llm_model,
85
+ google_api_key=GOOGLE_API_KEY,
86
+ temperature=1.0,
87
+ )
88
+
89
+ store = {}
90
+
91
+ def get_full_session_history(session_id: str) -> BaseChatMessageHistory:
92
+ if session_id not in store:
93
+ print(f"INFO: Creating new chat history for session: {session_id}")
94
+ store[session_id] = ChatMessageHistory()
95
+ return store[session_id]
96
+
97
+ MAX_HISTORY_TURNS = 3
98
+ MAX_HISTORY_MESSAGES = MAX_HISTORY_TURNS * 2
99
+
100
+
101
+ def limit_history_for_rag_chain(input_dict: Dict[str, Any]) -> Dict[str, Any]:
102
+ modified_input = input_dict.copy()
103
+
104
+ if "chat_history" in modified_input:
105
+ history = modified_input["chat_history"]
106
+ if isinstance(history, list) and all(isinstance(m, BaseMessage) for m in history):
107
+ limited_history = history[-MAX_HISTORY_MESSAGES:]
108
+ modified_input["chat_history"] = limited_history
109
+ else:
110
+ print("WARN: 'chat_history' in input_dict is not a list of BaseMessages. Passing as is.")
111
+ return modified_input
112
+
113
+ retriever_prompt_template = os.getenv("RETRIEVER_PROMPT").format(max_turns=MAX_HISTORY_TURNS)
114
+
115
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
116
+ [
117
+ ("system", retriever_prompt_template),
118
+ MessagesPlaceholder(variable_name="chat_history"), # This will receive the limited history
119
+ ("human", "{input}"),
120
+ ]
121
+ )
122
+ history_aware_retriever = create_history_aware_retriever(
123
+ llm, retriever, contextualize_q_prompt
124
+ )
125
+
126
+ SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT").format(max_turns=MAX_HISTORY_TURNS)
127
+
128
+ qa_prompt = ChatPromptTemplate.from_messages(
129
+ [
130
+ ("system", SYSTEM_PROMPT),
131
+ ("human", "{input}"),
132
+ ]
133
+ )
134
+
135
+ qa_chain = create_stuff_documents_chain(llm, qa_prompt)
136
+ rag_chain = create_retrieval_chain(history_aware_retriever, qa_chain)
137
+
138
+ conversational_rag_chain = RunnableWithMessageHistory(
139
+ runnable=RunnableLambda(limit_history_for_rag_chain) | rag_chain,
140
+ get_session_history=get_full_session_history,
141
+ input_messages_key="input",
142
+ history_messages_key="chat_history",
143
+ output_messages_key="answer",
144
+ )
145
+ def chat(query: str, session_id: str):
146
+ response = conversational_rag_chain.invoke(
147
+ {"input": query},
148
+ config={"configurable": {"session_id": session_id}}
149
+ )
150
+ return response.get("answer", "No answer found.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ pinecone>=6.0.0
2
+ pinecone-text>=0.1.0
3
+ langchain>=0.0.242
4
+ langchain-core>=0.0.3
5
+ langchain-community>=0.0.1
6
+ langchain-google-genai>=2.1.4
7
+ fastapi
8
+ uvicorn
9
+ python-dotenv