MusaR committed on
Commit
dcc1b8b
Β·
verified Β·
1 Parent(s): 9b86fc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -115
app.py CHANGED
@@ -1,38 +1,33 @@
1
- # app.py (DEBUGGING VERSION)
2
 
3
  print("--- Python script starting ---")
4
-
5
- import streamlit as st
6
  import os
7
-
8
- import langchain
9
- langchain.debug = True
10
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
11
- os.environ['HF_HOME'] = '/app/huggingface_cache' # For transformers and datasets
12
  os.environ['TRANSFORMERS_CACHE'] = '/app/huggingface_cache/transformers'
13
  os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/huggingface_cache/sentence_transformers'
14
- # Create the directory if it doesn't exist, with permissions
15
  if not os.path.exists('/app/huggingface_cache'):
16
  os.makedirs('/app/huggingface_cache', exist_ok=True)
 
 
 
 
 
17
  from dotenv import load_dotenv
18
  from pinecone import Pinecone
19
 
20
- # --- Standard Imports ---
21
  from langchain_pinecone import PineconeVectorStore
22
  from langchain_community.embeddings import SentenceTransformerEmbeddings
23
  from langchain_groq import ChatGroq
24
- from langchain_core.prompts import PromptTemplate
25
  from langchain_core.runnables import RunnablePassthrough
26
- from langchain_core.output_parsers import PydanticOutputParser
27
- from pydantic import BaseModel, Field
28
  from langchain.retrievers import ContextualCompressionRetriever
29
  from langchain.retrievers.document_compressors import CohereRerank
30
 
31
  print("--- All imports successful ---")
32
 
33
- # We wrap the ENTIRE app in a try/except block to catch any startup error
34
  try:
35
- # --- Load Environment Variables ---
36
  print("Step 1: Loading environment variables...")
37
  load_dotenv()
38
  PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
@@ -41,47 +36,23 @@ try:
41
  INDEX_NAME = "rag-chatbot"
42
  print("Step 1: SUCCESS")
43
 
44
- # --- Page Configuration ---
45
- st.set_page_config(page_title="Production RAG System", page_icon="πŸš€", layout="wide")
46
- st.title("πŸš€ Production-Grade RAG System")
47
 
48
- # --- Pydantic Model ---
49
- class StructuredAnswer(BaseModel):
50
- summary: str = Field(description="A concise summary.")
51
- key_points: list[str] = Field(description="A list of key bullet points.")
52
- confidence_score: float = Field(description="A 0.0 to 1.0 confidence score.")
53
-
54
- # --- Caching and Initialization ---
55
  @st.cache_resource
56
  def initialize_services():
57
  print("Step 2: Entering initialize_services function...")
58
  if not all([PINECONE_API_KEY, GROQ_API_KEY, COHERE_API_KEY]):
59
  raise ValueError("An API key is missing!")
60
-
61
- print("Step 2a: Initializing embedding model...")
62
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
63
- print("Step 2a: SUCCESS")
64
-
65
- print("Step 2b: Initializing Pinecone client...")
66
  pinecone = Pinecone(api_key=PINECONE_API_KEY)
67
- host = "https://rag-chatbot-sg8t88c.svc.aped-4627-b74a.pinecone.io"
68
  index = pinecone.Index(host=host)
69
- print("Step 2b: SUCCESS")
70
-
71
- print("Step 2c: Creating PineconeVectorStore object...")
72
  vectorstore = PineconeVectorStore(index=index, embedding=embeddings)
73
- print("Step 2c: SUCCESS")
74
-
75
- print("Step 2d: Initializing Cohere Re-ranker...")
76
- base_retriever = vectorstore.as_retriever(search_kwargs={'k': 20})
77
- compressor = CohereRerank(cohere_api_key=COHERE_API_KEY, top_n=5, model="rerank-english-v3.0")
78
  reranking_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
79
- print("Step 2d: SUCCESS")
80
-
81
- print("Step 2e: Initializing Groq LLM...")
82
- llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
83
- print("Step 2e: SUCCESS")
84
-
85
  print("Step 2: All services initialized successfully.")
86
  return reranking_retriever, llm
87
 
@@ -89,92 +60,96 @@ try:
89
  retriever, llm = initialize_services()
90
  print("Step 3: SUCCESS, services are loaded.")
91
 
92
- # --- RAG Chain Definition ---
93
  print("Step 4: Defining RAG chain...")
94
- pydantic_parser = PydanticOutputParser(pydantic_object=StructuredAnswer)
95
- format_instructions = pydantic_parser.get_format_instructions()
96
- template = """
97
- You are a world-class analysis engine. Your task is to provide a structured, factual answer based *only* on the following context.
98
- Synthesize the information from all context snippets. Do not use any outside knowledge.
 
99
 
100
  Context:
101
  {context}
102
-
103
- Question:
104
- {question}
105
-
106
- Follow these formatting instructions precisely:
107
- {format_instructions}
108
  """
109
- prompt = PromptTemplate(
110
- template=template,
111
- input_variables=["context", "question"],
112
- partial_variables={"format_instructions": format_instructions}
113
- )
114
-
115
- # --- NEW: Break down the chain for debugging ---
116
- def retrieve_and_rerank(input_dict):
117
- print(f"--- RAG DEBUG: Retrieving for question: {input_dict['question']} ---")
118
- docs = retriever.invoke(input_dict['question'])
119
- print(f"--- RAG DEBUG: Retrieved {len(docs)} docs after reranking ---")
120
- for i, doc in enumerate(docs):
121
- print(f" Doc {i} (source: {doc.metadata.get('source', 'N/A')}, page: {doc.metadata.get('page', 'N/A')}): {doc.page_content[:100]}...")
122
- return {"context": docs, "question": input_dict['question']}
123
-
124
- def format_prompt(input_dict):
125
- print(f"--- RAG DEBUG: Formatting prompt with context ---")
126
- # Manually construct the context string to see it clearly
127
- context_str = "\n\n---\n\n".join([doc.page_content for doc in input_dict['context']])
128
- print(f"--- RAG DEBUG: Context fed to LLM: {context_str[:500]}... ---") # Print first 500 chars of context
129
- return prompt.invoke({"context": context_str, "question": input_dict['question']})
130
 
131
- def call_llm(formatted_prompt):
132
- print(f"--- RAG DEBUG: Calling LLM ---")
133
- llm_output = llm.invoke(formatted_prompt)
134
- print(f"--- RAG DEBUG: Raw LLM Output: {llm_output} ---") # See exactly what Groq returns
135
- return llm_output
136
-
137
- def parse_output(llm_output_str):
138
- print(f"--- RAG DEBUG: Attempting to parse LLM output with Pydantic ---")
139
- try:
140
- parsed = pydantic_parser.invoke(llm_output_str)
141
- print(f"--- RAG DEBUG: Pydantic parsing successful ---")
142
- return parsed
143
- except Exception as e_parse:
144
- print(f"!!!!!!!!!! PYDANTIC PARSING ERROR !!!!!!!!!!")
145
- print(f"Raw LLM Output that failed to parse: {llm_output_str}")
146
- print(traceback.format_exc())
147
- # Fallback: return a dictionary indicating failure, or just the raw string
148
- return StructuredAnswer(summary="LLM output parsing failed. See logs.", key_points=[], confidence_score=0.0)
149
-
150
-
151
-
152
-
153
 
154
-
155
  rag_chain = (
156
- {"context": retriever, "question": RunnablePassthrough()}
157
  | prompt
158
  | llm
159
- | pydantic_parser
160
  )
161
  print("Step 4: SUCCESS")
162
 
163
- # --- UI Rendering ---
164
- print("Step 5: Starting to render Streamlit UI...")
165
- st.success("System is ready. Ask your question below.")
166
- query = st.text_input("Enter your question:", key="query_input")
167
-
168
- if query:
169
- with st.spinner("Processing..."):
170
- structured_answer = rag_chain.invoke(query)
171
- st.write("### Answer")
172
- # ... rest of UI ...
173
- print("Step 5: SUCCESS, UI is rendered.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  except Exception as e:
176
- # If ANY error happens during startup, it will be printed here
177
- print(f"!!!!!!!!!! A FATAL ERROR OCCURRED !!!!!!!!!!")
178
  import traceback
179
  print(traceback.format_exc())
180
  st.error(f"A fatal error occurred during startup. Please check the container logs. Error: {e}")
 
1
+ %%writefile app.py
2
 
3
  print("--- Python script starting ---")
 
 
4
  import os
 
 
 
5
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+ os.environ['HF_HOME'] = '/app/huggingface_cache'
7
  os.environ['TRANSFORMERS_CACHE'] = '/app/huggingface_cache/transformers'
8
  os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/huggingface_cache/sentence_transformers'
 
9
  if not os.path.exists('/app/huggingface_cache'):
10
  os.makedirs('/app/huggingface_cache', exist_ok=True)
11
+
12
+ import langchain
13
+ langchain.debug = False # Turn off verbose RAG chain logging for production
14
+
15
+ import streamlit as st
16
  from dotenv import load_dotenv
17
  from pinecone import Pinecone
18
 
 
19
  from langchain_pinecone import PineconeVectorStore
20
  from langchain_community.embeddings import SentenceTransformerEmbeddings
21
  from langchain_groq import ChatGroq
22
+ from langchain_core.prompts import ChatPromptTemplate # Use ChatPromptTemplate
23
  from langchain_core.runnables import RunnablePassthrough
24
+ from langchain_core.output_parsers import StrOutputParser # Simpler string output
 
25
  from langchain.retrievers import ContextualCompressionRetriever
26
  from langchain.retrievers.document_compressors import CohereRerank
27
 
28
  print("--- All imports successful ---")
29
 
 
30
  try:
 
31
  print("Step 1: Loading environment variables...")
32
  load_dotenv()
33
  PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
 
36
  INDEX_NAME = "rag-chatbot"
37
  print("Step 1: SUCCESS")
38
 
39
+ st.set_page_config(page_title="Advanced RAG Chatbot", page_icon="πŸš€", layout="wide")
40
+ st.title("πŸš€ Production-Grade RAG Chatbot")
 
41
 
 
 
 
 
 
 
 
42
@st.cache_resource
def initialize_services():
    """Initialize and cache the retrieval and LLM services (runs once per session).

    Returns:
        tuple: ``(reranking_retriever, llm)`` — a retriever that fetches 10
        candidate chunks from Pinecone and reranks them down to the top 3
        with Cohere, and a Groq-hosted Llama 3 70B chat model.

    Raises:
        ValueError: if any of the required API keys (Pinecone, Groq, Cohere)
            is missing from the environment.
    """
    print("Step 2: Entering initialize_services function...")
    if not all([PINECONE_API_KEY, GROQ_API_KEY, COHERE_API_KEY]):
        raise ValueError("An API key is missing!")
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    pinecone = Pinecone(api_key=PINECONE_API_KEY)
    # NOTE(review): host is hard-coded to this project's Pinecone index;
    # consider moving it to an environment variable alongside the API keys.
    host = "https://rag-chatbot-sg8t88c.svc.aped-4627-b74a.pinecone.io"  # Your host
    index = pinecone.Index(host=host)
    vectorstore = PineconeVectorStore(index=index, embedding=embeddings)
    base_retriever = vectorstore.as_retriever(search_kwargs={'k': 10})  # Fetch 10 for reranker
    # BUG FIX: "rerank-english-02" is not a valid Cohere rerank model id and
    # fails at request time; use "rerank-english-v3.0" (the id the previous
    # revision of this file used).
    compressor = CohereRerank(cohere_api_key=COHERE_API_KEY, top_n=3, model="rerank-english-v3.0")  # Rerank to top 3
    reranking_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
    llm = ChatGroq(temperature=0.1, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
    print("Step 2: All services initialized successfully.")
    return reranking_retriever, llm
58
 
 
60
  retriever, llm = initialize_services()
61
  print("Step 3: SUCCESS, services are loaded.")
62
 
63
+ # --- NEW RAG CHAIN with simpler output and source handling ---
64
  print("Step 4: Defining RAG chain...")
65
+
66
+ # System prompt to guide the LLM for chat-like, sourced answers
67
+ system_prompt = """You are a helpful AI assistant that answers questions based ONLY on the provided context.
68
+ Your answer should be concise and directly address the question.
69
+ After your answer, list the numbers of the sources you used, like this: [1][2].
70
+ Do not make up information. If the answer is not in the context, say "I cannot answer this based on the provided documents."
71
 
72
  Context:
73
  {context}
 
 
 
 
 
 
74
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ prompt = ChatPromptTemplate.from_messages([
77
+ ("system", system_prompt),
78
+ ("human", "{question}")
79
+ ])
80
+
81
def format_docs_with_numbers(docs):
    """Render retrieved documents as one numbered context string.

    Each document becomes a ``Source [n]:`` section so the LLM can cite
    sources by number in its answer; over-long chunks are truncated so a
    single document cannot overwhelm the prompt.
    """
    MAX_DOC_LENGTH = 1500  # Max characters per document chunk
    sections = []
    for position, document in enumerate(docs, start=1):
        text = document.page_content
        # Truncate long chunks, marking the cut with an ellipsis.
        if len(text) > MAX_DOC_LENGTH:
            text = text[:MAX_DOC_LENGTH] + "..."
        sections.append(f"Source [{position}]:\n{text}")
    return "\n\n".join(sections)
 
 
 
 
 
 
92
 
 
93
  rag_chain = (
94
+ {"context": retriever | format_docs_with_numbers, "question": RunnablePassthrough()}
95
  | prompt
96
  | llm
97
+ | StrOutputParser()
98
  )
99
  print("Step 4: SUCCESS")
100
 
101
+ # --- Initialize chat history ---
102
+ if "messages" not in st.session_state:
103
+ st.session_state.messages = [{"role": "assistant", "content": "Hello! I'm ready to answer questions about your documents."}]
104
+
105
+ # Display chat messages
106
+ for message in st.session_state.messages:
107
+ with st.chat_message(message["role"]):
108
+ st.markdown(message["content"])
109
+
110
+ # Chat input
111
+ if user_query := st.chat_input("Ask a question about your documents"):
112
+ st.session_state.messages.append({"role": "user", "content": user_query})
113
+ with st.chat_message("user"):
114
+ st.markdown(user_query)
115
+
116
+ with st.chat_message("assistant"):
117
+ with st.spinner("Thinking..."):
118
+ try:
119
+ print(f"--- UI DEBUG: Invoking RAG chain with query: {user_query} ---")
120
+ answer = rag_chain.invoke(user_query)
121
+ print(f"--- UI DEBUG: Raw LLM Answer: {answer} ---")
122
+
123
+ st.markdown(answer) # Display the LLM's answer directly
124
+
125
+ # Retrieve sources again just for display (not ideal for performance but simple)
126
+ # In a more complex app, you'd pass source objects through the chain.
127
+ with st.expander("Sources"):
128
+ source_docs = retriever.invoke(user_query)
129
+ if source_docs:
130
+ for i, doc in enumerate(source_docs):
131
+ source_filename = os.path.basename(doc.metadata.get('source', 'Unknown'))
132
+ page_number = doc.metadata.get('page', 'N/A')
133
+ st.markdown(f"**[{i+1}] Source:** `{source_filename}` (Page: {page_number})")
134
+ st.markdown(f"> {doc.page_content[:300]}...") # Show a snippet
135
+ st.markdown("---")
136
+ else:
137
+ st.write("No specific sources were retrieved for this part of the answer.")
138
+
139
+ st.session_state.messages.append({"role": "assistant", "content": answer}) # Add LLM's answer to history
140
+
141
+ except Exception as e_invoke:
142
+ error_message = f"Error processing your query: {e_invoke}"
143
+ print(f"!!!!!!!!!! ERROR DURING RAG CHAIN INVOCATION (UI Level) !!!!!!!!!!")
144
+ import traceback
145
+ print(traceback.format_exc())
146
+ st.error(error_message)
147
+ st.session_state.messages.append({"role": "assistant", "content": f"Sorry, I encountered an error: {error_message}"})
148
+
149
+ print("--- app.py script finished a run ---")
150
 
151
  except Exception as e:
152
+ print(f"!!!!!!!!!! A FATAL ERROR OCCURRED DURING STARTUP !!!!!!!!!!")
 
153
  import traceback
154
  print(traceback.format_exc())
155
  st.error(f"A fatal error occurred during startup. Please check the container logs. Error: {e}")