Upload folder using huggingface_hub
advanced_rag.py  CHANGED  (+53, -41)

@@ -196,20 +196,19 @@ class ElevatedRagChain:
 class MistralLLM(LLM):
     temperature: float = 0.7
     top_p: float = 0.95
-
+    client: Any = None  # Changed from _client PrivateAttr

     def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
-        # Initialize the private attributes before calling super().__init__
-        self._client = Mistral(api_key=api_key)
-        # Now call super().__init__
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+        # Initialize the client as a regular attribute instead of PrivateAttr
+        self.client = Mistral(api_key=api_key)

     @property
     def _llm_type(self) -> str:
         return "mistral_llm"

     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        response = self._client.chat.complete(
+        response = self.client.chat.complete(  # Use self.client instead of self._client
             model="mistral-small-latest",
             messages=[{"role": "user", "content": prompt}],
             temperature=self.temperature,
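The change above turns the Mistral SDK client from a pydantic private attribute into an ordinary declared field (`client: Any = None`) that is assigned after `super().__init__()`. Since LangChain's `LLM` base class is a pydantic model, instance attributes generally have to be declared on the class before they can be set, and typing the field as `Any` keeps pydantic from trying to validate the SDK object. A minimal usage sketch, assuming a `MISTRAL_API_KEY` environment variable and a LangChain version whose `LLM` exposes the Runnable `invoke()` method (neither appears in this diff):

```python
# Minimal usage sketch for the revised MistralLLM wrapper.
# Assumes MISTRAL_API_KEY is set and that the mistralai / langchain packages
# imported elsewhere in advanced_rag.py are installed.
import os

llm = MistralLLM(api_key=os.environ["MISTRAL_API_KEY"], temperature=0.5, top_p=0.9)

print(llm._llm_type)   # -> "mistral_llm"
print(llm.client)      # the Mistral SDK client, now a regular declared field

# On recent LangChain versions, invoke() routes the prompt through _call(),
# which in turn calls self.client.chat.complete(...).
print(llm.invoke("Summarize retrieval-augmented generation in one sentence."))
```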

@@ -421,42 +420,55 @@ def submit_query_updated(query):
     debug_print("Inside submit_query function.")
     if not query:
         debug_print("Please enter a non-empty query")
-        return "Please enter a non-empty query", "
 [… 35 more removed lines (old 425-459) not captured in this view …]
+        return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
+
+    if not hasattr(rag_chain, 'elevated_rag_chain'):
+        return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
+
+    try:
+        # Collect and truncate conversation history if needed
+        history_text = ""
+        if rag_chain.conversation_history:
+            # Only keep the last 3 conversations to limit context size
+            recent_history = rag_chain.conversation_history[-3:]
+            history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}"
+                                      for conv in recent_history])
+
+        # Determine max context size based on model
+        max_context_tokens = 32000 if "mistral" in rag_chain.llm_choice.lower() else 4096
+        # Reserve 1000 tokens for the question and generation
+        max_context_tokens -= 1000
+
+        # Truncate context if needed
+        context = truncate_prompt(rag_chain.context, max_tokens=max_context_tokens)
+
+        prompt_variables = {
+            "conversation_history": history_text,
+            "context": context,
+            "question": query
+        }
+
+        response = rag_chain.elevated_rag_chain.invoke({"question": query})
+        rag_chain.conversation_history.append({"query": query, "response": response})
+
+        input_token_count = count_tokens(query)
+        output_token_count = count_tokens(response)
+
+        return (
+            response,
+            rag_chain.get_current_context(),
+            f"Input tokens: {input_token_count}",
+            f"Output tokens: {output_token_count}"
+        )
+    except Exception as e:
+        error_msg = traceback.format_exc()
+        debug_print("LLM error. Error: " + error_msg)
+        return (
+            f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
+            "",
+            "Input tokens: 0",
+            "Output tokens: 0"
+        )

 def reset_app_updated():
     global rag_chain
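The rewritten `submit_query_updated` leans on two helpers, `count_tokens` and `truncate_prompt`, that are defined elsewhere in `advanced_rag.py` and are not part of this diff; only their call signatures are visible here. A rough stand-in with the same signatures, using a whitespace word count instead of whatever tokenizer the real helpers use, might look like this:

```python
# Illustrative stand-ins only: the real count_tokens / truncate_prompt helpers
# live elsewhere in advanced_rag.py and are not shown in this diff.
# These approximate "tokens" as whitespace-separated words.

def count_tokens(text: str) -> int:
    # Rough proxy: one token per whitespace-separated word.
    return len(text.split())

def truncate_prompt(text: str, max_tokens: int = 4096) -> str:
    # Keep the most recent max_tokens words so the newest context survives.
    words = text.split()
    if len(words) <= max_tokens:
        return text
    return " ".join(words[-max_tokens:])
```

The 32000 / 4096 budgets in the hunk should be read against whatever unit the real `count_tokens` measures; with a word-based proxy like the one above they are only an approximation.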
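Both the early returns and the exception handler in `submit_query_updated` now produce the same four-element shape: response text, current context, and two token-count labels. Assuming the surrounding app is a Gradio UI (the interface code is not part of this diff, and every component name below is hypothetical), the wiring would look roughly like this:

```python
# Hypothetical Gradio wiring; the actual layout is defined elsewhere in advanced_rag.py.
import gradio as gr

with gr.Blocks() as demo:
    query_box = gr.Textbox(label="Query")
    answer_box = gr.Textbox(label="Response")
    context_box = gr.Textbox(label="Current context")
    input_tokens = gr.Markdown("Input tokens: 0")
    output_tokens = gr.Markdown("Output tokens: 0")

    # The four return values map onto these four outputs, so the early-return
    # and error branches still populate every component.
    gr.Button("Submit").click(
        submit_query_updated,
        inputs=query_box,
        outputs=[answer_box, context_box, input_tokens, output_tokens],
    )

demo.launch()
```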