viraj committed
Commit 52e9bab · 1 Parent(s): aa870e0

enhancements

Files changed (3)
  1. .gitignore +3 -1
  2. app.py +50 -14
  3. rag_pipeline.py +88 -19
.gitignore CHANGED
@@ -1,2 +1,4 @@
 .env
-__pycache__
+__pycache__
+chroma_db
+files
app.py CHANGED
@@ -61,20 +61,56 @@ async def query_endpoint(request = Body(...)):
         raise HTTPException(status_code=422, detail="Missing file_id or question")
 
     retriever_path = f"{CHROMA_DIR}/{file_id}"
-    # Load retriever from disk
     if not os.path.exists(retriever_path):
-        return {"error": "Vectorstore for this file_id not found."}
+        raise HTTPException(status_code=404, detail="Vectorstore for this file_id not found.")
 
-    vectorstore = Chroma(
-        embedding_function=embedding_model,
-        persist_directory=retriever_path
-    )
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
-    retrieved_docs = retriever.invoke(selected_text or question)
-    retrieved_context = "\n\n".join(
-        re.sub(r"\s+", " ", doc.page_content.strip()) for doc in retrieved_docs
-    )
-
-    combined_context = f"User selected this:\n\"{selected_text}\"\n\nRelated parts from the document:\n{retrieved_context}"
-    answer = answer_query(question, combined_context, explain_like_5)
-    return {"answer": answer}
+    try:
+        # Load the persisted vectorstore for this file from disk
+        vectorstore = Chroma(
+            embedding_function=embedding_model,
+            persist_directory=retriever_path
+        )
+
+        # Configure the retriever to use MMR search
+        retriever = vectorstore.as_retriever(
+            search_type="mmr",
+            search_kwargs={
+                "k": 4,
+                "fetch_k": 8,
+                "lambda_mult": 0.7,
+            }
+        )
+
+        # First, retrieve context around the selected text if it exists
+        contexts = []
+        if selected_text:
+            selected_results = retriever.invoke(selected_text)
+            contexts.extend([doc.page_content for doc in selected_results])
+
+        # Then retrieve context for the question itself
+        question_results = retriever.invoke(question)
+        contexts.extend([doc.page_content for doc in question_results])
+
+        # Remove duplicates while preserving order
+        contexts = list(dict.fromkeys(contexts))
+
+        # Format the context with clear section separation
+        formatted_context = ""
+        if selected_text:
+            formatted_context += f"Selected Text Context:\n{selected_text}\n\n"
+
+        formatted_context += "Related Document Contexts:\n" + "\n---\n".join(
+            re.sub(r"\s+", " ", context.strip())
+            for context in contexts
+        )
+
+        # Answer the question using the assembled context
+        answer = answer_query(question, formatted_context, explain_like_5)
+
+        return {
+            "answer": answer,
+            "context_used": formatted_context  # optionally returned for debugging
+        }
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
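Aside from the diff itself: the endpoint now distinguishes a missing vectorstore (404) from processing failures (500) and can echo back the context it used. A minimal sketch of calling it, assuming the app is served locally on port 8000 and the handler is mounted at a /query route (the host, route path, and payload values are illustrative assumptions, not shown in this commit):

import requests

payload = {
    "file_id": "abc123",  # hypothetical id returned when the file was uploaded
    "question": "What is the refund policy?",
    "selected_text": "Refunds are processed within 14 days.",  # optional
    "explain_like_5": False,
}
resp = requests.post("http://localhost:8000/query", json=payload)  # assumed host/route
if resp.status_code == 404:
    print("No vectorstore for this file_id")  # the new 404 branch
else:
    resp.raise_for_status()
    print(resp.json()["answer"])

Because a missing vectorstore now surfaces as a real 404 instead of a 200 with an error body, clients can branch on the status code rather than sniffing the JSON.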
rag_pipeline.py CHANGED
@@ -1,6 +1,6 @@
 import tempfile
 from langchain_chroma import Chroma
-from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import PyPDFLoader, TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 import os
 from langchain_huggingface import HuggingFaceEmbeddings
@@ -19,36 +19,105 @@ def process_file(file_bytes, filename, file_id):
         tmp.write(file_bytes)
         tmp_path = tmp.name
 
-    loader = PyPDFLoader(tmp_path) if ext == 'pdf' else None
-    docs = loader.load()
+    print("Processing file:", filename)
+    if ext == 'pdf':
+        loader = PyPDFLoader(tmp_path)
+    elif ext == 'txt':
+        loader = TextLoader(tmp_path, encoding='utf-8')
+    else:
+        os.unlink(tmp_path)
+        raise ValueError(f"Unsupported file type: .{ext}")
 
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+    docs = loader.load()
+
+    # Enhanced text-splitting strategy
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=500,  # smaller chunks for more precise retrieval
+        chunk_overlap=50,  # reduced overlap while still preserving context
+        length_function=len,
+        separators=["\n\n", "\n", " ", ""],
+        add_start_index=True  # records each chunk's start position in its metadata
+    )
+
     chunks = text_splitter.split_documents(docs)
+
+    # Enhance the metadata of each chunk
+    for i, chunk in enumerate(chunks):
+        chunk.metadata.update({
+            "chunk_id": i,
+            "file_id": file_id,
+            "filename": filename,
+            "source": filename,
+            "chunk_type": "document"
+        })
 
+    # Create the Chroma collection with collection-level metadata
     vectorstore = Chroma.from_documents(
         documents=chunks,
         embedding=embedding_model,
-        persist_directory=f"{CHROMA_DIR}/{file_id}"
+        persist_directory=f"{CHROMA_DIR}/{file_id}",
+        collection_metadata={
+            "file_id": file_id,
+            "filename": filename,
+            "hnsw:space": "cosine",  # explicitly set the distance metric
+            "document_type": ext
+        }
     )
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
+
+    # Configure the retriever to use MMR for diversity in results
+    retriever = vectorstore.as_retriever(
+        search_type="mmr",
+        search_kwargs={
+            "k": 4,
+            "fetch_k": 8,  # fetch more candidates for MMR to choose from
+            "lambda_mult": 0.7,  # balance between relevance and diversity
+        }
+    )
+
     os.unlink(tmp_path)
     return retriever
 
 
 def answer_query(question, context, explain_like_5=False):
+    # Validate inputs
+    if not question or not context:
+        raise ValueError("Question and context must not be empty")
+
+    if not isinstance(context, (str, list)):
+        raise TypeError("Context must be a string or a list")
+
+    # Join the context into a single string if it arrives as a list
+    if isinstance(context, list):
+        context = "\n\n".join(str(c) for c in context)
+
     system_prompt = (
-        "You are a helpful assistant answering user queries based on provided document chunks.\n"
-        "Only use the given context. If the answer is not found, respond with 'I don't know.'"
+        "You are a helpful assistant answering user queries based STRICTLY on the provided document chunks.\n"
+        "IMPORTANT RULES:\n"
+        "1. ONLY use information from the given context. Do not use any external knowledge.\n"
+        "2. If the answer cannot be fully derived from the context, say 'I cannot answer this question based on the provided context.'\n"
+        "3. If you're unsure about any part of the answer, acknowledge the uncertainty.\n"
+        "4. Do not make assumptions beyond what's explicitly stated in the context.\n"
+        "5. Quote relevant parts of the context to support your answers when possible."
    )
+
     if explain_like_5:
-        system_prompt += "\nExplain the answer in a simple way, like you're talking to a 5-year-old."
-
-    # Step 2: Send to LLM
-    response = client.chat.completions.create(
-        model="llama-3.3-70b-versatile",
-        messages=[
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
-        ]
-    )
-    return response.choices[0].message.content
+        system_prompt += "\nExplain the answer in a simple way, like you're talking to a 5-year-old, but still only use information from the context."
+    print("Context:", context)
+    try:
+        # Send the formatted prompt to the LLM
+        response = client.chat.completions.create(
+            model="llama-3.3-70b-versatile",
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": (
+                    f"Context:\n{context}\n\n"
+                    f"Question: {question}\n\n"
+                    "Remember to answer ONLY based on the information provided in the context above. "
+                    "If you cannot find the answer in the context, say so explicitly."
+                )}
+            ],
+            temperature=0.3  # lower temperature for more focused answers
+        )
        return response.choices[0].message.content
+    except Exception as e:
+        raise Exception(f"Error generating answer: {str(e)}")
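
Since the retriever returned by process_file lives only in the current process, the query path rebuilds it from the persisted directory. A minimal sketch of reloading a persisted store with the same MMR settings, assuming CHROMA_DIR matches the module's configuration and using a placeholder embedding model (the actual model name is not visible in this diff):

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

CHROMA_DIR = "chroma_db"  # matches the directory newly added to .gitignore
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"  # assumed; not shown in this commit
)

vectorstore = Chroma(
    embedding_function=embedding_model,
    persist_directory=f"{CHROMA_DIR}/some-file-id",  # hypothetical file_id
)
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 4, "fetch_k": 8, "lambda_mult": 0.7},
)
for doc in retriever.invoke("What does the document say about pricing?"):
    print(doc.metadata.get("chunk_id"), doc.page_content[:80])

With these settings, MMR fetches fetch_k=8 candidates by similarity and greedily keeps k=4 of them, with lambda_mult=0.7 weighting relevance over diversity; the per-chunk metadata added in process_file (chunk_id, file_id, and the start_index from add_start_index=True) travels with each retrieved document.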
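answer_query now also accepts a list of context strings and joins them before prompting. A small sketch of calling it directly, assuming rag_pipeline.py is importable and the LLM client it uses (not shown in this diff) is configured; the context strings are invented for illustration:

from rag_pipeline import answer_query

chunks = [
    "The warranty lasts 12 months from the date of purchase.",
    "Warranty claims require the original receipt.",
]
# A list is joined with blank lines; a plain string is used as-is.
answer = answer_query("How long is the warranty?", chunks, explain_like_5=True)
print(answer)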