alx-d committed
Commit 8c804eb · verified · 1 Parent(s): ded9c47

Upload folder using huggingface_hub

Files changed (1)
  1. advanced_rag.py +53 -41
advanced_rag.py CHANGED
@@ -196,20 +196,19 @@ class ElevatedRagChain:
     class MistralLLM(LLM):
         temperature: float = 0.7
         top_p: float = 0.95
-        _client: Any = PrivateAttr()  # Remove the default=None here
+        client: Any = None  # Changed from _client PrivateAttr
 
         def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
-            # Initialize the private attributes before calling super().__init__
-            self._client = Mistral(api_key=api_key)
-            # Now call super().__init__
             super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+            # Initialize the client as a regular attribute instead of PrivateAttr
+            self.client = Mistral(api_key=api_key)
 
         @property
         def _llm_type(self) -> str:
             return "mistral_llm"
 
         def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-            response = self._client.chat.complete(
+            response = self.client.chat.complete(  # Use self.client instead of self._client
                 model="mistral-small-latest",
                 messages=[{"role": "user", "content": prompt}],
                 temperature=self.temperature,
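The first hunk swaps the Pydantic PrivateAttr for a plain declared field (client: Any = None), so the Mistral client can be assigned after super().__init__() like any other model field; the old version assigned self._client before the base constructor ran, which typically fails on LangChain's Pydantic-based LLM class because the model is not initialized yet. A minimal usage sketch of the revised class, assuming the mistralai and langchain packages are installed; the import path and environment variable below are illustrative, not part of this commit:

    import os
    from advanced_rag import ElevatedRagChain  # module uploaded in this commit

    # MistralLLM is nested inside ElevatedRagChain, so it is reached via the outer class.
    llm = ElevatedRagChain.MistralLLM(api_key=os.environ["MISTRAL_API_KEY"], temperature=0.2)

    # LangChain LLM subclasses are Runnables, so invoke() accepts a plain prompt string.
    print(llm.invoke("Reply with one word: ping"))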
@@ -421,42 +420,55 @@ def submit_query_updated(query):
     debug_print("Inside submit_query function.")
     if not query:
         debug_print("Please enter a non-empty query")
-        return "Please enter a non-empty query", "Word count: 0", f"Model used: {rag_chain.llm_choice}", ""
-    if hasattr(rag_chain, 'elevated_rag_chain'):
-        try:
-            history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in rag_chain.conversation_history]) if rag_chain.conversation_history else ""
-            prompt_variables = {
-                "conversation_history": history_text,
-                "context": rag_chain.context,
-                "question": query
-            }
-            if "llama" in rag_chain.llm_choice.lower():
-                prompt_variables["context"] = truncate_prompt(prompt_variables["context"], max_tokens=4092)
-            response = rag_chain.elevated_rag_chain.invoke(prompt_variables)
-            rag_chain.conversation_history.append({"query": query, "response": response})
-            input_token_count = count_tokens(query)
-            output_token_count = count_tokens(response)
-            return (
-                response,
-                rag_chain.get_current_context(),
-                f"Input tokens: {input_token_count}",
-                f"Output tokens: {output_token_count}"
-            )
-        except Exception as e:
-            error_msg = traceback.format_exc()
-            debug_print("LLM error. Error: " + error_msg)
-            return (
-                "Query error: " + str(e),
-                "",
-                "Input tokens: 0",
-                "Output tokens: 0"
-            )
-    return (
-        "Please load files first.",
-        "",
-        "Input tokens: 0",
-        "Output tokens: 0"
-    )
+        return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
+
+    if not hasattr(rag_chain, 'elevated_rag_chain'):
+        return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
+
+    try:
+        # Collect and truncate conversation history if needed
+        history_text = ""
+        if rag_chain.conversation_history:
+            # Only keep the last 3 conversations to limit context size
+            recent_history = rag_chain.conversation_history[-3:]
+            history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}"
+                                      for conv in recent_history])
+
+        # Determine max context size based on model
+        max_context_tokens = 32000 if "mistral" in rag_chain.llm_choice.lower() else 4096
+        # Reserve 1000 tokens for the question and generation
+        max_context_tokens -= 1000
+
+        # Truncate context if needed
+        context = truncate_prompt(rag_chain.context, max_tokens=max_context_tokens)
+
+        prompt_variables = {
+            "conversation_history": history_text,
+            "context": context,
+            "question": query
+        }
+
+        response = rag_chain.elevated_rag_chain.invoke({"question": query})
+        rag_chain.conversation_history.append({"query": query, "response": response})
+
+        input_token_count = count_tokens(query)
+        output_token_count = count_tokens(response)
+
+        return (
+            response,
+            rag_chain.get_current_context(),
+            f"Input tokens: {input_token_count}",
+            f"Output tokens: {output_token_count}"
+        )
+    except Exception as e:
+        error_msg = traceback.format_exc()
+        debug_print("LLM error. Error: " + error_msg)
+        return (
+            f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
+            "",
+            "Input tokens: 0",
+            "Output tokens: 0"
+        )
 
 def reset_app_updated():
     global rag_chain
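The rewritten submit_query_updated trims conversation history to the last three turns and budgets the context window explicitly: 32000 tokens for Mistral models, 4096 otherwise, minus a 1000-token reserve for the question and generation. Note that the assembled prompt_variables dict is no longer what reaches invoke(); the chain now receives only {"question": query}, so history and context are presumably injected inside elevated_rag_chain itself. The hunk calls count_tokens and truncate_prompt without showing them; below is a hypothetical sketch of helpers with matching signatures, using tiktoken as a stand-in tokenizer (an assumption; the repo may count tokens differently):

    import tiktoken

    def count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
        # Approximate token count with a generic BPE encoding.
        return len(tiktoken.get_encoding(encoding_name).encode(text))

    def truncate_prompt(text: str, max_tokens: int, encoding_name: str = "cl100k_base") -> str:
        # Keep the first max_tokens tokens and decode back to text
        # (the real helper may truncate from the other end).
        enc = tiktoken.get_encoding(encoding_name)
        tokens = enc.encode(text)
        return enc.decode(tokens[:max_tokens]) if len(tokens) > max_tokens else text

Under this budget a Mistral query keeps up to 32000 - 1000 = 31000 context tokens, while the non-Mistral fallback keeps 4096 - 1000 = 3096.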
 
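Both the error and success paths now return the same four-element shape (response text, current context, input-token label, output-token label), so whatever UI consumes this function never unpacks a mismatched tuple. A hypothetical Gradio wiring consistent with that shape; every component name below is illustrative, and the actual app layout is not shown in this commit:

    import gradio as gr
    from advanced_rag import submit_query_updated  # assumed import

    with gr.Blocks() as demo:
        query_box = gr.Textbox(label="Query")
        answer = gr.Textbox(label="Response")
        context_view = gr.Textbox(label="Current context")
        in_tokens = gr.Textbox(label="Input tokens")
        out_tokens = gr.Textbox(label="Output tokens")
        # One input, four outputs, matching the function's return tuple.
        query_box.submit(submit_query_updated, inputs=query_box,
                         outputs=[answer, context_view, in_tokens, out_tokens])

    demo.launch()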