alx-d committed on
Commit cc814aa · verified · 1 Parent(s): 43a5cae

Upload folder using huggingface_hub

Files changed (1)
  1. advanced_rag.py +150 -133
advanced_rag.py CHANGED
@@ -19,14 +19,13 @@ from langchain_community.retrievers import BM25Retriever
  from langchain.retrievers import EnsembleRetriever
  from langchain.prompts import ChatPromptTemplate
  from langchain.schema import StrOutputParser, Document
- from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
+ from langchain_core.runnables import RunnableParallel, RunnableLambda
  from transformers.quantizers.auto import AutoQuantizationConfig
  import gradio as gr
  import requests

  # Add Mistral imports with fallback handling
  try:
- # Try importing from the latest package structure
  from mistralai import Mistral
  MISTRAL_AVAILABLE = True
  debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
@@ -36,14 +35,13 @@ except ImportError:
  debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
  debug_print("Mistral client library not found. Install with: pip install mistralai")

- # Debug print function (already defined above in the try block)
  def debug_print(message: str):
  print(f"[{datetime.datetime.now().isoformat()}] {message}")

  def word_count(text: str) -> int:
  return len(text.split())

- # Initialize tokenizer for counting
+ # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
  def initialize_tokenizer():
  try:
  return AutoTokenizer.from_pretrained("gpt2")
@@ -61,7 +59,20 @@ def count_tokens(text: str) -> int:
  return len(text.split())
  return len(text.split())

- # Updated prompt template to include conversation history
+ def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
+ if global_tokenizer:
+ try:
+ tokens = global_tokenizer.encode(prompt)
+ if len(tokens) > max_tokens:
+ tokens = tokens[-max_tokens:]  # keep the last max_tokens tokens
+ return global_tokenizer.decode(tokens)
+ except Exception as e:
+ debug_print("Truncation error: " + str(e))
+ words = prompt.split()
+ if len(words) > max_tokens:
+ return " ".join(words[-max_tokens:])
+ return prompt
+
  default_prompt = """\
  {conversation_history}
  Use the following context to provide a detailed technical answer to the user's question.
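The truncate_prompt helper added here keeps only the tail of an oversized prompt so the most recent context survives. A minimal, self-contained sketch of that tail-truncation idea, assuming the same gpt2 fallback tokenizer that initialize_tokenizer() uses (the keep_tail name and the sample text are illustrative only, not part of the commit):

```python
# Illustrative sketch of tail-truncation; keep_tail is a stand-in, not the app's function.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # same fallback as initialize_tokenizer()

def keep_tail(prompt: str, max_tokens: int = 4096) -> str:
    tokens = tokenizer.encode(prompt)
    if len(tokens) > max_tokens:
        tokens = tokens[-max_tokens:]  # drop the oldest tokens first
    return tokenizer.decode(tokens)

print(len(tokenizer.encode(keep_tail("lorem ipsum " * 5000))))  # roughly max_tokens at most
```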
@@ -75,7 +86,6 @@ User's question:
  {question}
  """

- # Helper function to load TXT files from URL with error checking
  def load_txt_from_url(url: str) -> Document:
  response = requests.get(url)
  if response.status_code == 200:
@@ -86,18 +96,10 @@ def load_txt_from_url(url: str) -> Document:
  else:
  raise Exception(f"Failed to load {url} with status {response.status_code}")

-
  class ElevatedRagChain:
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
  bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
  debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
-
- # Check for required API keys based on model choice
- if "mistral-api" in llm_choice.lower() and not os.environ.get("MISTRAL_API_KEY"):
- debug_print("WARNING: Mistral API selected but MISTRAL_API_KEY environment variable not set")
- if not MISTRAL_AVAILABLE:
- debug_print("WARNING: Mistral API package not installed. Install with: pip install mistralai")
-
  self.embed_func = HuggingFaceEmbeddings(
  model_name="sentence-transformers/all-MiniLM-L6-v2",
  model_kwargs={"device": "cpu"}
@@ -110,29 +112,45 @@ class ElevatedRagChain:
  self.top_p = top_p
  self.prompt_template = prompt_template
  self.context = ""
- self.conversation_history: List[Dict[str, str]] = [] # List of dicts with keys "query" and "response"
+ self.conversation_history: List[Dict[str, str]] = []
+ self.raw_data = None
+ self.split_data = None
+ self.elevated_rag_chain = None
+
+ # Instance method to capture context and conversation history
+ def capture_context(self, result):
+ self.context = "\n".join([str(doc) for doc in result["context"]])
+ result["context"] = self.context
+ history_text = (
+ "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
+ if self.conversation_history else ""
+ )
+ result["conversation_history"] = history_text
+ return result
+
+ # Instance method to extract question from input data
+ def extract_question(self, input_data):
+ return input_data["question"]

  def create_llm_pipeline(self):
- if "remote" in self.llm_choice.lower():
+ normalized = self.llm_choice.lower()
+ if "remote" in normalized:
  debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
  from huggingface_hub import InferenceClient
  repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
  hf_api_token = os.environ.get("HF_API_TOKEN")
  if not hf_api_token:
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
- client = InferenceClient(token=hf_api_token)
-
+ client = InferenceClient(token=hf_api_token, timeout=180)
  def remote_generate(prompt: str) -> str:
  response = client.text_generation(
  prompt,
  model=repo_id,
- # max_new_tokens=512,
  temperature=self.temperature,
  top_p=self.top_p,
  repetition_penalty=1.1
  )
  return response
-
  from langchain.llms.base import LLM
  class RemoteLLM(LLM):
  @property
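capture_context flattens the retrieved Documents from the RunnableParallel output into a single context string and attaches the running conversation history before the prompt template sees it. A small, self-contained sketch of that reshaping, with illustrative sample data rather than the app's real documents:

```python
# Illustrative reshaping of a RunnableParallel result, mirroring capture_context.
from langchain.schema import Document

conversation_history = [{"query": "What is RAG?", "response": "Retrieval-augmented generation."}]

def capture_context(result: dict) -> dict:
    result["context"] = "\n".join(str(doc) for doc in result["context"])  # Documents -> one string
    result["conversation_history"] = "\n".join(
        f"Q: {turn['query']}\nA: {turn['response']}" for turn in conversation_history
    )
    return result

parallel_output = {
    "context": [Document(page_content="BM25 hit"), Document(page_content="FAISS hit")],
    "question": "How are the retrievers combined?",
}
print(capture_context(parallel_output))
```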
@@ -145,76 +163,94 @@ class ElevatedRagChain:
  return {"model": repo_id}
  debug_print("Remote Meta-Llama-3 pipeline created successfully.")
  return RemoteLLM()
- elif "mistral-api" in self.llm_choice.lower():
+ elif "mistral-api" in normalized:
  debug_print("Creating Mistral API pipeline...")
-
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
  if not mistral_api_key:
  raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
-
  if not MISTRAL_AVAILABLE:
  raise ImportError("Mistral client library not installed. Install with: pip install mistralai")
-
- # Initialize the Mistral client with latest API
- mistral_client = Mistral(api_key=mistral_api_key)
-
- # Define the model to use - updated to match current model names
- mistral_model = "mistral-small-latest"
-
  from langchain.llms.base import LLM
  class MistralLLM(LLM):
  temperature: float = 0.7
  top_p: float = 0.95
-
+ _client: Any = None
  def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95):
- super().__init__() # Important to call the parent constructor
- self.client = Mistral(api_key=api_key)
+ super().__init__()
+ self._client = Mistral(api_key=api_key)
  self.temperature = temperature
  self.top_p = top_p
-
  @property
  def _llm_type(self) -> str:
  return "mistral_llm"
-
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
- response = self.client.chat.complete(
- model="mistral-small-latest", # Replace with the actual model name if different
+ response = self._client.chat.complete(
+ model="mistral-small-latest",
  messages=[{"role": "user", "content": prompt}],
  temperature=self.temperature,
  top_p=self.top_p,
  max_tokens=512
  )
  return response.choices[0].message.content
-
  @property
  def _identifying_params(self) -> dict:
  return {"model": "mistral-small-latest"}
-
- # Initialize and return the MistralLLM instance
  mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
  debug_print("Mistral API pipeline created successfully.")
  return mistral_llm
-
  else:
+ # Default branch: assume Llama
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
- if "deepseek" in self.llm_choice.lower():
- model_id = "deepseek-ai/DeepSeek-R1"
- elif "gemini" in self.llm_choice.lower():
- model_id = "gemini/flash-1.5"
- elif "mistralai" in self.llm_choice.lower():
- model_id = "mistralai/Mistral-Small-24B-Instruct-2501"
-
+ extra_kwargs = {}
+ if "llama" in normalized or model_id.startswith("meta-llama"):
+ extra_kwargs["max_length"] = 4096
  pipe = pipeline(
  "text-generation",
  model=model_id,
  model_kwargs={"torch_dtype": torch.bfloat16},
- max_length=4096,
  do_sample=True,
  temperature=self.temperature,
  top_p=self.top_p,
- device=-1
+ device=-1,
+ **extra_kwargs
  )
- return HuggingFacePipeline(pipeline=pipe)
+ from langchain.llms.base import LLM
+ class LocalLLM(LLM):
+ @property
+ def _llm_type(self) -> str:
+ return "local_llm"
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+ return pipe(prompt)[0]["generated_text"]
+ @property
+ def _identifying_params(self) -> dict:
+ return {"model": model_id, "max_length": extra_kwargs.get("max_length")}
+ debug_print("Local Llama pipeline created successfully with max_length=4096.")
+ return LocalLLM()
+
+ def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
+ debug_print(f"Updating chain with new model: {new_model_choice}")
+ self.llm_choice = new_model_choice
+ self.temperature = temperature
+ self.top_p = top_p
+ self.prompt_template = prompt_template
+ self.bm25_weight = bm25_weight
+ self.faiss_weight = 1.0 - bm25_weight
+ self.llm = self.create_llm_pipeline()
+ def format_response(response: str) -> str:
+ input_tokens = count_tokens(self.context + self.prompt_template)
+ output_tokens = count_tokens(response)
+ formatted = f"### Response\n\n{response}\n\n---\n"
+ formatted += f"- **Input tokens:** {input_tokens}\n"
+ formatted += f"- **Output tokens:** {output_tokens}\n"
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
+ return formatted
+ base_runnable = RunnableParallel({
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
+ "question": RunnableLambda(self.extract_question)
+ }) | self.capture_context
+ self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
+ debug_print("Chain updated successfully with new LLM pipeline.")

  def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
  debug_print(f"Processing files using {self.llm_choice}")
@@ -222,7 +258,6 @@ class ElevatedRagChain:
  for link in file_links:
  if link.lower().endswith(".pdf"):
  debug_print(f"Loading PDF: {link}")
- # Ensure that the PDF loader returns a non-empty list.
  loaded_docs = OnlinePDFLoader(link).load()
  if loaded_docs:
  self.raw_data.append(loaded_docs[0])
@@ -236,79 +271,49 @@ class ElevatedRagChain:
  debug_print(f"Error loading TXT file {link}: {e}")
  else:
  debug_print(f"File type not supported for URL: {link}")
-
  if not self.raw_data:
  raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
-
  debug_print("Files loaded successfully.")
-
  debug_print("Starting text splitting...")
  self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
  self.split_data = self.text_splitter.split_documents(self.raw_data)
  if not self.split_data:
  raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
  debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
-
  debug_print("Creating BM25 retriever...")
  self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
  self.bm25_retriever.k = self.top_k
  debug_print("BM25 retriever created.")
-
  debug_print("Embedding chunks and creating FAISS vector store...")
  self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
  self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
  debug_print("FAISS vector store created successfully.")
-
- ensemble = EnsembleRetriever(
+ self.ensemble_retriever = EnsembleRetriever(
  retrievers=[self.bm25_retriever, self.faiss_retriever],
  weights=[self.bm25_weight, self.faiss_weight]
  )
-
- def capture_context(result):
- # Convert each Document to a string and update the context.
- self.context = "\n".join([str(doc) for doc in result["context"]])
- result["context"] = self.context
- # Add conversation_history from self.conversation_history (if any) as a string.
- history_text = (
- "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
- if self.conversation_history else ""
- )
- result["conversation_history"] = history_text
- return result
-
- def extract_question(input_data):
- # Expecting input_data to be a dict with a key "question"
- return input_data["question"]
-
- # Build the chain so that the ensemble (BM25 + FAISS) gets only the question string.
  base_runnable = RunnableParallel({
- "context": RunnableLambda(extract_question) | ensemble,
- "question": RunnableLambda(extract_question)
- }) | capture_context
-
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
+ "question": RunnableLambda(self.extract_question)
+ }) | self.capture_context
  self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
  self.str_output_parser = StrOutputParser()
  debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
  self.llm = self.create_llm_pipeline()
-
  def format_response(response: str) -> str:
  input_tokens = count_tokens(self.context + self.prompt_template)
  output_tokens = count_tokens(response)
- # Format the response as Markdown for better visual rendering
  formatted = f"### Response\n\n{response}\n\n---\n"
  formatted += f"- **Input tokens:** {input_tokens}\n"
  formatted += f"- **Output tokens:** {output_tokens}\n"
  formatted += f"- **Generated using:** {self.llm_choice}\n"
- # Append conversation history summary
  formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
  return formatted
-
  self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
  debug_print("Elevated RAG chain successfully built and ready to use.")
-
+
  def get_current_context(self) -> str:
- # Show a sample of the document context along with a summary of conversation history.
- base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if hasattr(self, "split_data") and self.split_data else "No context available."
+ base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
  history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
  recent = self.conversation_history[-3:]
  if recent:
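The chain now keeps the hybrid retriever on the instance (self.ensemble_retriever) so update_llm_pipeline can rebuild the chain without re-indexing. A self-contained sketch of the BM25 + FAISS weighting under the same default (bm25_weight=0.6), using two toy documents; it assumes the app's own dependencies (rank_bm25, faiss, sentence-transformers) are installed:

```python
# Toy hybrid retriever: lexical BM25 fused with semantic FAISS, weights 0.6 / 0.4.
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

docs = [
    Document(page_content="FAISS builds dense vector indexes for semantic search."),
    Document(page_content="BM25 ranks documents by term frequency and document length."),
]

bm25 = BM25Retriever.from_documents(docs)
bm25.k = 2
faiss = FAISS.from_documents(
    docs, HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
).as_retriever(search_kwargs={"k": 2})

hybrid = EnsembleRetriever(retrievers=[bm25, faiss], weights=[0.6, 0.4])
print(hybrid.invoke("How does dense retrieval work?"))
```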
@@ -332,23 +337,33 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
  try:
  links = [link.strip() for link in file_links.split("\n") if link.strip()]
  global rag_chain
- rag_chain = ElevatedRagChain(
- llm_choice=model_choice,
- prompt_template=prompt_template,
- bm25_weight=bm25_weight,
- temperature=temperature,
- top_p=top_p
- )
- rag_chain.add_pdfs_to_vectore_store(links)
- context_display = rag_chain.get_current_context()
- response_msg = f"Files loaded successfully. Using model: {model_choice}"
- debug_print(response_msg)
- return (
- response_msg,
- f"Word count: {word_count(rag_chain.context)}",
- f"Model used: {rag_chain.llm_choice}",
- f"Context:\n{context_display}"
- )
+ if rag_chain.raw_data:
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
+ context_display = rag_chain.get_current_context()
+ response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
+ return (
+ response_msg,
+ f"Word count: {word_count(rag_chain.context)}",
+ f"Model used: {rag_chain.llm_choice}",
+ f"Context:\n{context_display}"
+ )
+ else:
+ rag_chain = ElevatedRagChain(
+ llm_choice=model_choice,
+ prompt_template=prompt_template,
+ bm25_weight=bm25_weight,
+ temperature=temperature,
+ top_p=top_p
+ )
+ rag_chain.add_pdfs_to_vectore_store(links)
+ context_display = rag_chain.get_current_context()
+ response_msg = f"Files loaded successfully. Using model: {model_choice}"
+ return (
+ response_msg,
+ f"Word count: {word_count(rag_chain.context)}",
+ f"Model used: {rag_chain.llm_choice}",
+ f"Context:\n{context_display}"
+ )
  except Exception as e:
  error_msg = traceback.format_exc()
  debug_print("Could not load files. Error: " + error_msg)
@@ -359,6 +374,16 @@ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, te
  "Context: N/A"
  )

+ def update_model(new_model: str):
+ global rag_chain
+ if rag_chain and rag_chain.raw_data:
+ rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
+ rag_chain.prompt_template, rag_chain.bm25_weight)
+ debug_print(f"Model updated to {rag_chain.llm_choice}")
+ return f"Model updated to: {rag_chain.llm_choice}"
+ else:
+ return "No files loaded; please load files first."
+
  def submit_query_updated(query):
  debug_print("Inside submit_query function.")
  if not query:
@@ -366,20 +391,15 @@ def submit_query_updated(query):
  return "Please enter a non-empty query", "Word count: 0", f"Model used: {rag_chain.llm_choice}", ""
  if hasattr(rag_chain, 'elevated_rag_chain'):
  try:
- # Incorporate conversation history by joining previous Q&A pairs.
- history_text = ""
- if rag_chain.conversation_history:
- history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in rag_chain.conversation_history])
-
- # Build the prompt variables dictionary for the chain.
+ history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in rag_chain.conversation_history]) if rag_chain.conversation_history else ""
  prompt_variables = {
  "conversation_history": history_text,
  "context": rag_chain.context,
  "question": query
  }
-
+ if "llama" in rag_chain.llm_choice.lower():
+ prompt_variables["context"] = truncate_prompt(prompt_variables["context"], max_tokens=4096)
  response = rag_chain.elevated_rag_chain.invoke(prompt_variables)
- # Save the current conversation to history
  rag_chain.conversation_history.append({"query": query, "response": response})
  input_token_count = count_tokens(query)
  output_token_count = count_tokens(response)
@@ -419,11 +439,9 @@ def reset_app_updated():
  # Gradio Interface Setup
  # ----------------------------
  custom_css = """
- button {
- background-color: grey !important;
- font-family: Arial !important;
- font-weight: bold !important;
- color: blue !important;
+ textarea {
+ overflow-y: scroll !important;
+ max-height: 200px;
  }
  """

@@ -435,31 +453,24 @@ with gr.Blocks(css=custom_css) as app:
  - 🇺🇸 Remote Meta-Llama-3
  - 🇪🇺 Mistral-API

- **🔥 Randomness (Temperature):** Temperature adjusts how predictable or varied the output is. A low temperature makes the model choose very predictable words (which can be repetitive), while a high temperature introduces more randomness for diverse, creative text.
+ **🔥 Randomness (Temperature):** Adjusts output predictability.

- **🎯 Word Variety (Top‑p):** Top‑p limits the model's word choices to those that make up a set percentage (p) of the total probability. Lower values yield focused outputs; higher values increase variety and creativity.
+ **🎯 Word Variety (Top‑p):** Limits word choices to a set probability percentage.

- **✏️ Prompt Template:** Edit the prompt template if desired.
+ **✏️ Prompt Template:** Edit as desired.

- **🔗 File URLs:** Enter one or more file URLs (PDF or TXT, one per line).
+ **🔗 File URLs:** Enter one URL per line (.pdf or .txt).

- **⚖️ Weight Controls:** Adjust Lexical vs Semantics (BM25 Weight).
+ **⚖️ BM25 Weight:** Adjust Lexical vs Semantics.

  **🔍 Query:** Enter your query below.

- The response displays the model used, word count, and the current context (including conversation history).
- """
- ''')
+ The response displays the model used, word count, and current context (with conversation history).
+ ''')
  with gr.Row():
  with gr.Column():
  model_dropdown = gr.Dropdown(
- choices=[
- "🇺🇸 Remote Meta-Llama-3",
- "🇪🇺 Mistral-API"
- # "DeepSeek-R1", # Option commented out
- # "Gemini Flash 1.5", # Option commented out
- # "Mistralai/Mistral-Small-24B-Instruct-2501" # Option commented out
- ],
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
  value="🇺🇸 Remote Meta-Llama-3",
  label="Select Model"
  )
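The shortened Temperature and Top-p descriptions map directly onto the sampling arguments the app already forwards; a small sketch of that mapping through the Hugging Face InferenceClient used by the remote pipeline (the token and prompt are placeholders, and access to the gated Llama repo is assumed):

```python
# Placeholder call showing where the two sliders end up; requires HF_API_TOKEN.
import os
from huggingface_hub import InferenceClient

client = InferenceClient(token=os.environ.get("HF_API_TOKEN"))
text = client.text_generation(
    "Explain hybrid retrieval in one sentence.",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    temperature=0.5,  # lower -> more predictable wording
    top_p=0.95,       # sample only from the smallest set covering 95% probability mass
)
print(text)
```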
@@ -535,6 +546,12 @@ The response displays the model used, word count, and the current context (inclu
  inputs=[],
  outputs=[response_output, context_output, model_output]
  )
+
+ model_dropdown.change(
+ fn=update_model,
+ inputs=model_dropdown,
+ outputs=model_output
+ )

  if __name__ == "__main__":
  debug_print("Launching Gradio interface.")
 