Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,38 +1,33 @@
|
|
| 1 |
-
|
| 2 |
|
| 3 |
print("--- Python script starting ---")
|
| 4 |
-
|
| 5 |
-
import streamlit as st
|
| 6 |
import os
|
| 7 |
-
|
| 8 |
-
import langchain
|
| 9 |
-
langchain.debug = True
|
| 10 |
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 11 |
-
os.environ['HF_HOME'] = '/app/huggingface_cache'
|
| 12 |
os.environ['TRANSFORMERS_CACHE'] = '/app/huggingface_cache/transformers'
|
| 13 |
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/huggingface_cache/sentence_transformers'
|
| 14 |
-
# Create the directory if it doesn't exist, with permissions
|
| 15 |
if not os.path.exists('/app/huggingface_cache'):
|
| 16 |
os.makedirs('/app/huggingface_cache', exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from dotenv import load_dotenv
|
| 18 |
from pinecone import Pinecone
|
| 19 |
|
| 20 |
-
# --- Standard Imports ---
|
| 21 |
from langchain_pinecone import PineconeVectorStore
|
| 22 |
from langchain_community.embeddings import SentenceTransformerEmbeddings
|
| 23 |
from langchain_groq import ChatGroq
|
| 24 |
-
from langchain_core.prompts import
|
| 25 |
from langchain_core.runnables import RunnablePassthrough
|
| 26 |
-
from langchain_core.output_parsers import
|
| 27 |
-
from pydantic import BaseModel, Field
|
| 28 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 29 |
from langchain.retrievers.document_compressors import CohereRerank
|
| 30 |
|
| 31 |
print("--- All imports successful ---")
|
| 32 |
|
| 33 |
-
# We wrap the ENTIRE app in a try/except block to catch any startup error
|
| 34 |
try:
|
| 35 |
-
# --- Load Environment Variables ---
|
| 36 |
print("Step 1: Loading environment variables...")
|
| 37 |
load_dotenv()
|
| 38 |
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
|
|
@@ -41,47 +36,23 @@ try:
|
|
| 41 |
INDEX_NAME = "rag-chatbot"
|
| 42 |
print("Step 1: SUCCESS")
|
| 43 |
|
| 44 |
-
|
| 45 |
-
st.
|
| 46 |
-
st.title("π Production-Grade RAG System")
|
| 47 |
|
| 48 |
-
# --- Pydantic Model ---
|
| 49 |
-
class StructuredAnswer(BaseModel):
|
| 50 |
-
summary: str = Field(description="A concise summary.")
|
| 51 |
-
key_points: list[str] = Field(description="A list of key bullet points.")
|
| 52 |
-
confidence_score: float = Field(description="A 0.0 to 1.0 confidence score.")
|
| 53 |
-
|
| 54 |
-
# --- Caching and Initialization ---
|
| 55 |
@st.cache_resource
|
| 56 |
def initialize_services():
|
| 57 |
print("Step 2: Entering initialize_services function...")
|
| 58 |
if not all([PINECONE_API_KEY, GROQ_API_KEY, COHERE_API_KEY]):
|
| 59 |
raise ValueError("An API key is missing!")
|
| 60 |
-
|
| 61 |
-
print("Step 2a: Initializing embedding model...")
|
| 62 |
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
|
| 63 |
-
print("Step 2a: SUCCESS")
|
| 64 |
-
|
| 65 |
-
print("Step 2b: Initializing Pinecone client...")
|
| 66 |
pinecone = Pinecone(api_key=PINECONE_API_KEY)
|
| 67 |
-
host = "https://rag-chatbot-sg8t88c.svc.aped-4627-b74a.pinecone.io"
|
| 68 |
index = pinecone.Index(host=host)
|
| 69 |
-
print("Step 2b: SUCCESS")
|
| 70 |
-
|
| 71 |
-
print("Step 2c: Creating PineconeVectorStore object...")
|
| 72 |
vectorstore = PineconeVectorStore(index=index, embedding=embeddings)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
print("Step 2d: Initializing Cohere Re-ranker...")
|
| 76 |
-
base_retriever = vectorstore.as_retriever(search_kwargs={'k': 20})
|
| 77 |
-
compressor = CohereRerank(cohere_api_key=COHERE_API_KEY, top_n=5, model="rerank-english-v3.0")
|
| 78 |
reranking_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
print("Step 2e: Initializing Groq LLM...")
|
| 82 |
-
llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
|
| 83 |
-
print("Step 2e: SUCCESS")
|
| 84 |
-
|
| 85 |
print("Step 2: All services initialized successfully.")
|
| 86 |
return reranking_retriever, llm
|
| 87 |
|
|
@@ -89,92 +60,96 @@ try:
|
|
| 89 |
retriever, llm = initialize_services()
|
| 90 |
print("Step 3: SUCCESS, services are loaded.")
|
| 91 |
|
| 92 |
-
# --- RAG
|
| 93 |
print("Step 4: Defining RAG chain...")
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
| 99 |
|
| 100 |
Context:
|
| 101 |
{context}
|
| 102 |
-
|
| 103 |
-
Question:
|
| 104 |
-
{question}
|
| 105 |
-
|
| 106 |
-
Follow these formatting instructions precisely:
|
| 107 |
-
{format_instructions}
|
| 108 |
"""
|
| 109 |
-
prompt = PromptTemplate(
|
| 110 |
-
template=template,
|
| 111 |
-
input_variables=["context", "question"],
|
| 112 |
-
partial_variables={"format_instructions": format_instructions}
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
# --- NEW: Break down the chain for debugging ---
|
| 116 |
-
def retrieve_and_rerank(input_dict):
|
| 117 |
-
print(f"--- RAG DEBUG: Retrieving for question: {input_dict['question']} ---")
|
| 118 |
-
docs = retriever.invoke(input_dict['question'])
|
| 119 |
-
print(f"--- RAG DEBUG: Retrieved {len(docs)} docs after reranking ---")
|
| 120 |
-
for i, doc in enumerate(docs):
|
| 121 |
-
print(f" Doc {i} (source: {doc.metadata.get('source', 'N/A')}, page: {doc.metadata.get('page', 'N/A')}): {doc.page_content[:100]}...")
|
| 122 |
-
return {"context": docs, "question": input_dict['question']}
|
| 123 |
-
|
| 124 |
-
def format_prompt(input_dict):
|
| 125 |
-
print(f"--- RAG DEBUG: Formatting prompt with context ---")
|
| 126 |
-
# Manually construct the context string to see it clearly
|
| 127 |
-
context_str = "\n\n---\n\n".join([doc.page_content for doc in input_dict['context']])
|
| 128 |
-
print(f"--- RAG DEBUG: Context fed to LLM: {context_str[:500]}... ---") # Print first 500 chars of context
|
| 129 |
-
return prompt.invoke({"context": context_str, "question": input_dict['question']})
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
# Fallback: return a dictionary indicating failure, or just the raw string
|
| 148 |
-
return StructuredAnswer(summary="LLM output parsing failed. See logs.", key_points=[], confidence_score=0.0)
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
|
| 154 |
-
|
| 155 |
rag_chain = (
|
| 156 |
-
{"context": retriever, "question": RunnablePassthrough()}
|
| 157 |
| prompt
|
| 158 |
| llm
|
| 159 |
-
|
|
| 160 |
)
|
| 161 |
print("Step 4: SUCCESS")
|
| 162 |
|
| 163 |
-
# ---
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
with st.
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
except Exception as e:
|
| 176 |
-
|
| 177 |
-
print(f"!!!!!!!!!! A FATAL ERROR OCCURRED !!!!!!!!!!")
|
| 178 |
import traceback
|
| 179 |
print(traceback.format_exc())
|
| 180 |
st.error(f"A fatal error occurred during startup. Please check the container logs. Error: {e}")
|
|
|
|
| 1 |
+
# NOTE: the Jupyter cell magic "%%writefile app.py" that used to be the first
# line of this file has been removed. It is only meaningful inside a notebook
# cell; in a plain .py file it is a syntax error and the script never starts.
print("--- Python script starting ---")
import os

# Cache/tokenizer settings are exported before the ML libraries further down
# are imported (presumably so those libraries pick the values up at import
# time -- confirm against the library documentation).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['HF_HOME'] = '/app/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/huggingface_cache/transformers'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/huggingface_cache/sentence_transformers'

# exist_ok=True already makes this a no-op when the directory exists, so no
# separate os.path.exists() check is needed.
os.makedirs('/app/huggingface_cache', exist_ok=True)
|
| 11 |
+
|
| 12 |
+
import langchain
|
| 13 |
+
langchain.debug = False # Turn off verbose RAG chain logging for production
|
| 14 |
+
|
| 15 |
+
import streamlit as st
|
| 16 |
from dotenv import load_dotenv
|
| 17 |
from pinecone import Pinecone
|
| 18 |
|
|
|
|
| 19 |
from langchain_pinecone import PineconeVectorStore
|
| 20 |
from langchain_community.embeddings import SentenceTransformerEmbeddings
|
| 21 |
from langchain_groq import ChatGroq
|
| 22 |
+
from langchain_core.prompts import ChatPromptTemplate # Use ChatPromptTemplate
|
| 23 |
from langchain_core.runnables import RunnablePassthrough
|
| 24 |
+
from langchain_core.output_parsers import StrOutputParser # Simpler string output
|
|
|
|
| 25 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 26 |
from langchain.retrievers.document_compressors import CohereRerank
|
| 27 |
|
| 28 |
print("--- All imports successful ---")
|
| 29 |
|
|
|
|
| 30 |
try:
|
|
|
|
| 31 |
print("Step 1: Loading environment variables...")
|
| 32 |
load_dotenv()
|
| 33 |
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
|
|
|
|
| 36 |
INDEX_NAME = "rag-chatbot"
|
| 37 |
print("Step 1: SUCCESS")
|
| 38 |
|
| 39 |
+
st.set_page_config(page_title="Advanced RAG Chatbot", page_icon="π", layout="wide")
|
| 40 |
+
st.title("π Production-Grade RAG Chatbot")
|
|
|
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
@st.cache_resource
def initialize_services():
    """Build the reranking retriever and the chat LLM used by the RAG chain.

    Decorated with ``st.cache_resource``, so Streamlit is expected to reuse
    the returned objects instead of rebuilding them on every script rerun.

    Returns:
        tuple: ``(reranking_retriever, llm)`` -- a
        ``ContextualCompressionRetriever`` (Pinecone retrieval followed by
        Cohere reranking) and a ``ChatGroq`` chat model.

    Raises:
        ValueError: if any of the Pinecone, Groq or Cohere API keys is
            unset or empty.
    """
    print("Step 2: Entering initialize_services function...")

    # Report which keys are missing rather than only that one is.
    required_keys = {
        "PINECONE_API_KEY": PINECONE_API_KEY,
        "GROQ_API_KEY": GROQ_API_KEY,
        "COHERE_API_KEY": COHERE_API_KEY,
    }
    missing = [name for name, value in required_keys.items() if not value]
    if missing:
        raise ValueError(f"An API key is missing! Missing: {', '.join(missing)}")

    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    pinecone = Pinecone(api_key=PINECONE_API_KEY)
    host = "https://rag-chatbot-sg8t88c.svc.aped-4627-b74a.pinecone.io"  # Your host
    index = pinecone.Index(host=host)
    vectorstore = PineconeVectorStore(index=index, embedding=embeddings)

    # Fetch 10 candidates from the vector store, then let the reranker keep
    # the best 3.
    base_retriever = vectorstore.as_retriever(search_kwargs={'k': 10})
    # "rerank-english-02" does not appear to be a published Cohere model id;
    # use the "rerank-english-v3.0" id that the previous revision of this
    # file used.
    compressor = CohereRerank(cohere_api_key=COHERE_API_KEY, top_n=3, model="rerank-english-v3.0")
    reranking_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)

    llm = ChatGroq(temperature=0.1, model_name="llama3-70b-8192", api_key=GROQ_API_KEY)

    print("Step 2: All services initialized successfully.")
    return reranking_retriever, llm
|
| 58 |
|
|
|
|
| 60 |
retriever, llm = initialize_services()
|
| 61 |
print("Step 3: SUCCESS, services are loaded.")
|
| 62 |
|
| 63 |
+
# --- NEW RAG CHAIN with simpler output and source handling ---
|
| 64 |
print("Step 4: Defining RAG chain...")
|
| 65 |
+
|
| 66 |
+
# System prompt to guide the LLM for chat-like, sourced answers
|
| 67 |
+
system_prompt = """You are a helpful AI assistant that answers questions based ONLY on the provided context.
|
| 68 |
+
Your answer should be concise and directly address the question.
|
| 69 |
+
After your answer, list the numbers of the sources you used, like this: [1][2].
|
| 70 |
+
Do not make up information. If the answer is not in the context, say "I cannot answer this based on the provided documents."
|
| 71 |
|
| 72 |
Context:
|
| 73 |
{context}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 77 |
+
("system", system_prompt),
|
| 78 |
+
("human", "{question}")
|
| 79 |
+
])
|
| 80 |
+
|
| 81 |
+
def format_docs_with_numbers(docs, max_doc_length=1500):
    """Render retrieved documents as one numbered context string.

    Each document becomes ``"Source [n]:\\n<content>"`` with ``n`` starting
    at 1, so the LLM can cite sources as ``[1][2]``. Entries are joined with
    a blank line.

    Args:
        docs: iterable of objects exposing a ``page_content`` string.
        max_doc_length: maximum number of characters kept per document;
            longer content is cut at this length and ``"..."`` is appended.
            Defaults to 1500. Pass ``None`` to disable truncation.

    Returns:
        str: the numbered documents, or ``""`` when ``docs`` is empty.
    """
    numbered_docs = []
    for number, doc in enumerate(docs, start=1):
        content = doc.page_content
        # Cap each chunk so a single long document cannot dominate the prompt.
        if max_doc_length is not None and len(content) > max_doc_length:
            content = content[:max_doc_length] + "..."
        numbered_docs.append(f"Source [{number}]:\n{content}")
    return "\n\n".join(numbered_docs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
|
|
|
| 93 |
rag_chain = (
|
| 94 |
+
{"context": retriever | format_docs_with_numbers, "question": RunnablePassthrough()}
|
| 95 |
| prompt
|
| 96 |
| llm
|
| 97 |
+
| StrOutputParser()
|
| 98 |
)
|
| 99 |
print("Step 4: SUCCESS")
|
| 100 |
|
| 101 |
+
# --- Initialize chat history ---
|
| 102 |
+
if "messages" not in st.session_state:
|
| 103 |
+
st.session_state.messages = [{"role": "assistant", "content": "Hello! I'm ready to answer questions about your documents."}]
|
| 104 |
+
|
| 105 |
+
# Display chat messages
|
| 106 |
+
for message in st.session_state.messages:
|
| 107 |
+
with st.chat_message(message["role"]):
|
| 108 |
+
st.markdown(message["content"])
|
| 109 |
+
|
| 110 |
+
# Chat input
|
| 111 |
+
if user_query := st.chat_input("Ask a question about your documents"):
|
| 112 |
+
st.session_state.messages.append({"role": "user", "content": user_query})
|
| 113 |
+
with st.chat_message("user"):
|
| 114 |
+
st.markdown(user_query)
|
| 115 |
+
|
| 116 |
+
with st.chat_message("assistant"):
|
| 117 |
+
with st.spinner("Thinking..."):
|
| 118 |
+
try:
|
| 119 |
+
print(f"--- UI DEBUG: Invoking RAG chain with query: {user_query} ---")
|
| 120 |
+
answer = rag_chain.invoke(user_query)
|
| 121 |
+
print(f"--- UI DEBUG: Raw LLM Answer: {answer} ---")
|
| 122 |
+
|
| 123 |
+
st.markdown(answer) # Display the LLM's answer directly
|
| 124 |
+
|
| 125 |
+
# Retrieve sources again just for display (not ideal for performance but simple)
|
| 126 |
+
# In a more complex app, you'd pass source objects through the chain.
|
| 127 |
+
with st.expander("Sources"):
|
| 128 |
+
source_docs = retriever.invoke(user_query)
|
| 129 |
+
if source_docs:
|
| 130 |
+
for i, doc in enumerate(source_docs):
|
| 131 |
+
source_filename = os.path.basename(doc.metadata.get('source', 'Unknown'))
|
| 132 |
+
page_number = doc.metadata.get('page', 'N/A')
|
| 133 |
+
st.markdown(f"**[{i+1}] Source:** `{source_filename}` (Page: {page_number})")
|
| 134 |
+
st.markdown(f"> {doc.page_content[:300]}...") # Show a snippet
|
| 135 |
+
st.markdown("---")
|
| 136 |
+
else:
|
| 137 |
+
st.write("No specific sources were retrieved for this part of the answer.")
|
| 138 |
+
|
| 139 |
+
st.session_state.messages.append({"role": "assistant", "content": answer}) # Add LLM's answer to history
|
| 140 |
+
|
| 141 |
+
except Exception as e_invoke:
|
| 142 |
+
error_message = f"Error processing your query: {e_invoke}"
|
| 143 |
+
print(f"!!!!!!!!!! ERROR DURING RAG CHAIN INVOCATION (UI Level) !!!!!!!!!!")
|
| 144 |
+
import traceback
|
| 145 |
+
print(traceback.format_exc())
|
| 146 |
+
st.error(error_message)
|
| 147 |
+
st.session_state.messages.append({"role": "assistant", "content": f"Sorry, I encountered an error: {error_message}"})
|
| 148 |
+
|
| 149 |
+
print("--- app.py script finished a run ---")
|
| 150 |
|
| 151 |
except Exception as e:
|
| 152 |
+
print(f"!!!!!!!!!! A FATAL ERROR OCCURRED DURING STARTUP !!!!!!!!!!")
|
|
|
|
| 153 |
import traceback
|
| 154 |
print(traceback.format_exc())
|
| 155 |
st.error(f"A fatal error occurred during startup. Please check the container logs. Error: {e}")
|