JumaRubea committed
Commit e82b521 · verified · 1 Parent(s): 1258b31

Update rag_components.py

Files changed (1):
  1. rag_components.py +159 -12
rag_components.py CHANGED
@@ -6,7 +6,11 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.document_loaders import TextLoader
 from langchain_huggingface import HuggingFacePipeline
 from langchain.chains import RetrievalQA
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from langchain.prompts import PromptTemplate
+from langchain.callbacks.base import BaseCallbackHandler
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
+import streamlit as st
+from typing import Any, Dict, List
 
 # Set cache directories for HuggingFace Spaces
 os.environ["HF_HOME"] = "/tmp/huggingface_cache"
@@ -18,6 +22,22 @@ os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers_cache"
 for cache_dir in ["/tmp/huggingface_cache", "/tmp/transformers_cache", "/tmp/hf_hub_cache", "/tmp/sentence_transformers_cache"]:
     os.makedirs(cache_dir, exist_ok=True)
 
+class StreamingCallbackHandler(BaseCallbackHandler):
+    """Callback handler for streaming responses."""
+
+    def __init__(self, placeholder):
+        self.placeholder = placeholder
+        self.text = ""
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Handle new token from LLM."""
+        self.text += token
+        self.placeholder.markdown(self.text + "▌")
+
+    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
+        """Handle end of LLM response."""
+        self.placeholder.markdown(self.text)
+
 def load_documents(file_path: str):
     """Loads documents from a specified file path."""
     loader = TextLoader(file_path)
@@ -57,8 +77,7 @@ def setup_vector_store(docs, embeddings, persist_directory="./chroma_db"):
     return db.as_retriever()
 
 def create_qa_chain(retriever, model_name="microsoft/DialoGPT-medium"):
-    """Creates the RetrievalQA chain with streaming capabilities.
-    Using a smaller, more reliable model for HuggingFace Spaces."""
+    """Creates an enhanced QA chain with better prompting and streaming capabilities."""
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -75,32 +94,52 @@ def create_qa_chain(retriever, model_name="microsoft/DialoGPT-medium"):
             cache_dir="/tmp/transformers_cache",
             device_map="auto",
             trust_remote_code=True,
-            torch_dtype="auto"  # Let it choose the best dtype
+            torch_dtype="auto"
         )
 
+        # Create pipeline with better parameters to reduce repetition
         pipe = pipeline(
             "text-generation",
             model=model,
             tokenizer=tokenizer,
-            max_new_tokens=256,  # Reduced for faster generation
+            max_new_tokens=150,
             temperature=0.7,
             top_p=0.9,
-            pad_token_id=tokenizer.eos_token_id
+            top_k=40,
+            repetition_penalty=1.2,  # Reduce repetition
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            return_full_text=False  # Only return new tokens
         )
 
         llm = HuggingFacePipeline(pipeline=pipe)
 
+        # Enhanced prompt template for better QA responses
+        prompt_template = """You're Juma's Assistant. Use the following context to answer the user's question. If you cannot answer based on the context, say so clearly.
+
+Context: {context}
+
+Question: {question}
+
+Answer: Let me help you with that based on the information provided."""
+
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question"]
+        )
+
         qa_chain = RetrievalQA.from_chain_type(
             llm=llm,
             retriever=retriever,
             chain_type="stuff",
-            return_source_documents=True
+            return_source_documents=True,
+            chain_type_kwargs={"prompt": prompt}
         )
         return qa_chain
 
     except Exception as e:
         print(f"Error loading model {model_name}: {e}")
-        # Try with an even smaller model as fallback
         try:
             print("Trying fallback model: distilgpt2")
             return create_qa_chain_fallback(retriever)
@@ -109,7 +148,7 @@ def create_qa_chain(retriever, model_name="microsoft/DialoGPT-medium"):
         raise e2
 
 def create_qa_chain_fallback(retriever):
-    """Fallback QA chain with a very small model."""
+    """Fallback QA chain with a very small model and better parameters."""
     tokenizer = AutoTokenizer.from_pretrained(
         "distilgpt2",
         cache_dir="/tmp/transformers_cache"
@@ -125,17 +164,125 @@ def create_qa_chain_fallback(retriever):
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=128,
+        max_new_tokens=100,
         temperature=0.7,
-        pad_token_id=tokenizer.eos_token_id
+        top_p=0.9,
+        top_k=40,
+        repetition_penalty=1.3,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        return_full_text=False
     )
 
     llm = HuggingFacePipeline(pipeline=pipe)
 
+    # Same enhanced prompt
+    prompt_template = """You're Juma's Assistant. Use the following context to answer the user's question. If you cannot answer based on the context, say so clearly.
+
+Context: {context}
+
+Question: {question}
+
+Answer: Let me help you with that based on the information provided."""
+
+    prompt = PromptTemplate(
+        template=prompt_template,
+        input_variables=["context", "question"]
+    )
+
     qa_chain = RetrievalQA.from_chain_type(
         llm=llm,
         retriever=retriever,
         chain_type="stuff",
-        return_source_documents=True
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt}
     )
     return qa_chain
+
+def create_streaming_response(qa_chain, question: str, placeholder):
+    """Create a streaming response using the QA chain."""
+    try:
+        # Get the response first
+        result = qa_chain.invoke({"query": question})
+
+        # Extract just the answer part
+        answer = result.get("result", "")
+
+        # Clean up the response
+        answer = clean_response(answer)
+
+        # Simulate streaming by displaying character by character
+        import time
+        displayed_text = ""
+
+        for i, char in enumerate(answer):
+            displayed_text += char
+            placeholder.markdown(displayed_text + "▌")
+
+            # Add small delay for streaming effect
+            if i % 3 == 0:  # Every 3 characters
+                time.sleep(0.02)  # 20ms delay
+
+        # Final display without cursor
+        placeholder.markdown(displayed_text)
+
+        return displayed_text
+
+    except Exception as e:
+        placeholder.error(f"Error generating response: {e}")
+        return "I apologize, but I encountered an error while processing your question."
+
+def clean_response(text: str) -> str:
+    """Clean up the response to remove repetition and improve quality."""
+    if not text:
+        return "I couldn't find relevant information to answer your question."
+
+    # Remove the prompt part if it's included in the response
+    if "Answer: Let me help you with that based on the information provided." in text:
+        text = text.split("Answer: Let me help you with that based on the information provided.", 1)[-1].strip()
+
+    # Remove common prefixes that models add
+    prefixes_to_remove = [
+        "Based on the context provided,",
+        "According to the document,",
+        "The document states that",
+        "From the information given,",
+        "Let me help you with that based on the information provided."
+    ]
+
+    for prefix in prefixes_to_remove:
+        if text.startswith(prefix):
+            text = text[len(prefix):].strip()
+
+    # Split into sentences and remove repetitive ones
+    sentences = text.split('.')
+    cleaned_sentences = []
+
+    for sentence in sentences:
+        sentence = sentence.strip()
+        if sentence and len(sentence) > 10:  # Filter out very short fragments
+            # Check if this sentence is too similar to recent ones
+            is_repetitive = False
+            for recent in cleaned_sentences[-2:]:
+                if len(set(sentence.split()) & set(recent.split())) > len(sentence.split()) * 0.7:
+                    is_repetitive = True
+                    break
+
+            if not is_repetitive:
+                cleaned_sentences.append(sentence)
+
+    # Join sentences back
+    result = '. '.join(cleaned_sentences)
+
+    # Ensure it ends properly
+    if result and not result.endswith('.'):
+        result += '.'
+
+    # Limit length and ensure quality
+    if len(result) > 500:
+        # Cut at sentence boundary
+        sentences = result[:500].split('.')
+        result = '. '.join(sentences[:-1]) + '.'
+
+    return result if result.strip() else "I couldn't generate a proper response. Please try rephrasing your question."
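As a usage sketch, the pieces above could be wired into a Streamlit front end roughly as follows. This is a minimal illustration, not part of the commit: the knowledge-file path and the embedding model are assumptions, and the text-splitting step suggested by the file's RecursiveCharacterTextSplitter import is skipped for brevity.

import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings

from rag_components import (
    load_documents,
    setup_vector_store,
    create_qa_chain,
    create_streaming_response,
)

@st.cache_resource  # build the model and index once per process, not per rerun
def build_chain():
    docs = load_documents("data/knowledge.txt")  # hypothetical path
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"  # assumed model
    )
    retriever = setup_vector_store(docs, embeddings)
    return create_qa_chain(retriever)

qa_chain = build_chain()
question = st.chat_input("Ask Juma's Assistant")
if question:
    with st.chat_message("assistant"):
        placeholder = st.empty()
        create_streaming_response(qa_chain, question, placeholder)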
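The new StreamingCallbackHandler implements LangChain's BaseCallbackHandler interface, but nothing in this commit actually attaches it to the chain. A sketch of how it could be passed at call time, continuing from the app sketch above; note that HuggingFacePipeline only emits on_llm_new_token events when the underlying LLM streams, so this is an assumption about future wiring, not current behavior.

import streamlit as st
from rag_components import StreamingCallbackHandler

placeholder = st.empty()
handler = StreamingCallbackHandler(placeholder)

# Callbacks supplied per invocation via the standard LangChain config dict.
result = qa_chain.invoke(
    {"query": "What projects has Juma built?"},
    config={"callbacks": [handler]},
)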
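The substance of the pipeline change is the decoding configuration: sampling with top_k/top_p, a repetition_penalty above 1.0 to discourage loops, and return_full_text=False so the prompt is not echoed back. These settings can be sanity-checked with a bare transformers pipeline, outside LangChain:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tok,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.3,
    pad_token_id=tok.eos_token_id,
    return_full_text=False,  # output excludes the prompt text
)

out = pipe("Question: What is retrieval-augmented generation?\nAnswer:")
print(out[0]["generated_text"])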
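With chain_type="stuff", RetrievalQA concatenates the retrieved documents into the {context} slot and drops the user's query into {question}. The exact string the model receives can be previewed by formatting the template directly; an abbreviated version of the commit's template is used here for illustration.

from langchain.prompts import PromptTemplate

template = (
    "You're Juma's Assistant. Use the following context to answer the "
    "user's question.\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
)
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

print(prompt.format(
    context="Juma is a data scientist who builds RAG demos.",  # made-up context
    question="What does Juma do?",
))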
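A quick sanity check of clean_response on a fabricated repetitive completion shows both behaviors at once: the boilerplate "Answer:" preamble is stripped, and the second sentence is dropped because it shares more than 70% of its words with the previous one.

from rag_components import clean_response

raw = (
    "Answer: Let me help you with that based on the information provided. "
    "Juma builds RAG demos with LangChain. "
    "Juma builds RAG demos with LangChain and Streamlit."
)
print(clean_response(raw))
# -> "Juma builds RAG demos with LangChain."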