Upload folder using huggingface_hub
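For reference, a commit with this title is typically produced by the `huggingface_hub` client's folder-upload call. The snippet below is a minimal sketch, not code from this commit; the repo id, local path, and repo type are placeholder assumptions.

from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path=".",                    # local folder containing advanced_rag.py (assumed)
    repo_id="user/advanced-rag-space",  # placeholder repo id
    repo_type="space",                  # assumed: this looks like a Gradio Space
    commit_message="Upload folder using huggingface_hub",
)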
advanced_rag.py
CHANGED: +130 -138
@@ -81,55 +81,27 @@ jobs = {} # Stores job status and results
 results_queue = queue.Queue()  # Thread-safe queue for completed jobs
 processing_lock = threading.Lock()  # Prevent simultaneous processing of the same job

-#
-def get_job_list():
-    job_list_md = "### Submitted Jobs\n\n"
-
-    if not jobs:
-        return "No jobs found. Submit a query or load files to create jobs."
-
-    # Sort jobs by start time (newest first)
-    sorted_jobs = sorted(
-        [(job_id, job_info) for job_id, job_info in jobs.items()],
-        key=lambda x: x[1].get("start_time", 0),
-        reverse=True
-    )
-
-    for job_id, job_info in sorted_jobs:
-        status = job_info.get("status", "unknown")
-        job_type = job_info.get("type", "unknown")
-        query = job_info.get("query", "")
-        start_time = job_info.get("start_time", 0)
-        time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-
-        # Create a shortened query preview
-        query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
-
-        # Create clickable links using Markdown
-        if job_type == "query":
-            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
-        else:
-            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
-
-    return job_list_md
+# Add these missing async processing functions

-
-
+def process_in_background(job_id, function, args):
+    """Process a function in the background and store results"""
     try:
+        debug_print(f"Processing job {job_id} in background")
         result = function(*args)
         results_queue.put((job_id, result))
+        debug_print(f"Job {job_id} completed and added to results queue")
     except Exception as e:
-
-
-        results_queue.put((job_id,
+        debug_print(f"Error in background job {job_id}: {str(e)}")
+        error_result = (f"Error processing job: {str(e)}", "", "", "")
+        results_queue.put((job_id, error_result))

-# Async version of load_pdfs_updated
 def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
+    """Asynchronous version of load_pdfs_updated to prevent timeouts"""
     if not file_links:
-        return "Please enter non-empty URLs", "", "Model used: N/A"
+        return "Please enter non-empty URLs", "", "Model used: N/A"

     job_id = str(uuid.uuid4())
-    debug_print(f"Starting async job {job_id} for loading
+    debug_print(f"Starting async job {job_id} for file loading")

     # Start background thread
     threading.Thread(
@@ -138,41 +110,33 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
     ).start()

     jobs[job_id] = {
-        "status": "processing",
+        "status": "processing",
         "type": "load_files",
-        "start_time": time.time()
+        "start_time": time.time(),
+        "query": f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
     }

     return (
-        f"Files
-        f"Use 'Check Job Status' with this ID to get results.",
+        f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
+        f"Use 'Check Job Status' tab with this ID to get results.",
         f"Job ID: {job_id}",
-        f"Model
+        f"Model requested: {model_choice}"
     )

-# Async version of submit_query_updated
 def submit_query_async(query, model_choice=None):
+    """Asynchronous version of submit_query_updated to prevent timeouts"""
     if not query:
         return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"

-    if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
-        return "Please load files first", "", "Input tokens: 0", "Output tokens: 0"
-
-    # Use the provided model if specified, otherwise use the current model
-    if model_choice and model_choice != "":
-        # Update the model temporarily for this query
-        current_model = rag_chain.llm_choice
-        rag_chain.update_llm_pipeline(
-            model_choice,
-            rag_chain.temperature,
-            rag_chain.top_p,
-            rag_chain.prompt_template,
-            rag_chain.bm25_weight
-        )
-
     job_id = str(uuid.uuid4())
     debug_print(f"Starting async job {job_id} for query: {query}")

+    # Update model if specified
+    if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
+        debug_print(f"Updating model to {model_choice} for this query")
+        rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
+                                      rag_chain.prompt_template, rag_chain.bm25_weight)
+
     # Start background thread
     threading.Thread(
         target=process_in_background,
@@ -184,16 +148,48 @@ def submit_query_async(query, model_choice=None):
         "type": "query",
         "start_time": time.time(),
         "query": query,
-        "model": rag_chain.llm_choice
+        "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
     }

     return (
         f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
-        f"Use 'Check Job Status' with this ID to get results.",
+        f"Use 'Check Job Status' tab with this ID to get results.",
         f"Job ID: {job_id}",
         f"Input tokens: {count_tokens(query)}",
         "Output tokens: pending"
     )
+
+# Function to display all jobs as a clickable list
+def get_job_list():
+    job_list_md = "### Submitted Jobs\n\n"
+
+    if not jobs:
+        return "No jobs found. Submit a query or load files to create jobs."
+
+    # Sort jobs by start time (newest first)
+    sorted_jobs = sorted(
+        [(job_id, job_info) for job_id, job_info in jobs.items()],
+        key=lambda x: x[1].get("start_time", 0),
+        reverse=True
+    )
+
+    for job_id, job_info in sorted_jobs:
+        status = job_info.get("status", "unknown")
+        job_type = job_info.get("type", "unknown")
+        query = job_info.get("query", "")
+        start_time = job_info.get("start_time", 0)
+        time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
+
+        # Create a shortened query preview
+        query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
+
+        # Create clickable links using Markdown
+        if job_type == "query":
+            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
+        else:
+            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
+
+    return job_list_md

 # Function to handle job list clicks
 def job_selected(job_id):
@@ -394,6 +390,7 @@ class ElevatedRagChain:

     # Improve error handling in the ElevatedRagChain class
     def create_llm_pipeline(self):
+        from langchain.llms.base import LLM  # Import LLM here so it's always defined
         normalized = self.llm_choice.lower()
         try:
             if "remote" in normalized:
@@ -406,7 +403,7 @@ class ElevatedRagChain:

                 client = InferenceClient(token=hf_api_token, timeout=120)

-
+                # We no longer use wait_for_model because it's unsupported
                 def remote_generate(prompt: str) -> str:
                     max_retries = 3
                     backoff = 2  # start with 2 seconds
@@ -434,7 +431,7 @@ class ElevatedRagChain:
                     def _llm_type(self) -> str:
                         return "remote_llm"

-                    def _call(self, prompt: str, stop:
+                    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
                         return remote_generate(prompt)

                     @property
@@ -444,68 +441,57 @@ class ElevatedRagChain:
                 debug_print("Remote Meta-Llama-3 pipeline created successfully.")
                 return RemoteLLM()

-            elif "mistral" in normalized:
+            elif "mistral-api" in normalized:
                 debug_print("Creating Mistral API pipeline...")
                 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
                 if not mistral_api_key:
                     raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
-
-                # Import Mistral library with proper error handling
                 try:
                     from mistralai import Mistral
                     from mistralai.exceptions import MistralException
                     debug_print("Mistral library imported successfully")
                 except ImportError:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                debug_print("Creating Mistral LLM instance")
-                mistral_llm = MistralLLM(
-                    api_key=mistral_api_key,
-                    temperature=self.temperature,
-                    top_p=self.top_p
-                )
-                debug_print("Mistral API pipeline created successfully.")
-                return mistral_llm
+                    debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
+                    normalized = "llama"
+                if normalized != "llama":
+                    class MistralLLM(LLM):
+                        temperature: float = 0.7
+                        top_p: float = 0.95
+                        _client: Any = PrivateAttr(default=None)
+                        def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
+                            super().__init__(**kwargs)
+                            self._client = Mistral(api_key=api_key)
+                            self.temperature = temperature
+                            self.top_p = top_p
+                        @property
+                        def _llm_type(self) -> str:
+                            return "mistral_llm"
+                        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+                            try:
+                                debug_print("Calling Mistral API...")
+                                response = self._client.chat.complete(
+                                    model="mistral-small-latest",
+                                    messages=[{"role": "user", "content": prompt}],
+                                    temperature=self.temperature,
+                                    top_p=self.top_p,
+                                    max_tokens=32000
+                                )
+                                return response.choices[0].message.content
+                            except Exception as e:
+                                debug_print(f"Mistral API error: {str(e)}")
+                                return f"Error generating response: {str(e)}"
+                        @property
+                        def _identifying_params(self) -> dict:
+                            return {"model": "mistral-small-latest"}
+                    debug_print("Creating Mistral LLM instance")
+                    mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
+                    debug_print("Mistral API pipeline created successfully.")
+                    return mistral_llm

             else:
-                # Default case -
+                # Default case - using a fallback model (or Llama)
                 debug_print("Using local/fallback model pipeline")
-                model_id = "facebook/opt-350m" #
-
+                model_id = "facebook/opt-350m"  # Use a smaller model as fallback
                 pipe = pipeline(
                     "text-generation",
                     model=model_id,
@@ -517,27 +503,21 @@ class ElevatedRagChain:
                     @property
                     def _llm_type(self) -> str:
                         return "local_llm"
-
                     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-                        #
-
-
-
-
-
-
-                            return generated
-                    except Exception as e:
-                        debug_print(f"Generation error: {str(e)}")
-                        return f"Error generating response: {str(e)}"
-
+                        # For this fallback, truncate prompt if it exceeds limits
+                        reserved_gen = 128
+                        max_total = 1024
+                        max_prompt_tokens = max_total - reserved_gen
+                        truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
+                        generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
+                        return generated
                     @property
                     def _identifying_params(self) -> dict:
-                        return {"model": model_id}
+                        return {"model": model_id, "max_length": 1024}

                 debug_print("Local fallback pipeline created.")
                 return LocalLLM()
-
+
         except Exception as e:
             debug_print(f"Error creating LLM pipeline: {str(e)}")
             # Return a dummy LLM that explains the error
@@ -546,7 +526,7 @@ class ElevatedRagChain:
                 def _llm_type(self) -> str:
                     return "error_llm"

-                def _call(self, prompt: str, stop:
+                def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
                     return f"Error initializing LLM: \n\nPlease check your environment variables and try again."

                 @property
@@ -555,6 +535,7 @@ class ElevatedRagChain:

             return ErrorLLM()

+
     def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
         debug_print(f"Updating chain with new model: {new_model_choice}")
         self.llm_choice = new_model_choice
@@ -624,7 +605,9 @@ class ElevatedRagChain:
             "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
             "question": RunnableLambda(self.extract_question)
         }) | self.capture_context
+        # Wrap the prompt template in a RunnableLambda
         self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
+        prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
         self.str_output_parser = StrOutputParser()
         debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
         self.llm = self.create_llm_pipeline()
@@ -637,9 +620,10 @@ class ElevatedRagChain:
             formatted += f"- **Generated using:** {self.llm_choice}\n"
             formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
             return formatted
-        self.elevated_rag_chain = base_runnable |
+        self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
         debug_print("Elevated RAG chain successfully built and ready to use.")

+
     def get_current_context(self) -> str:
         base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
         history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
@@ -917,6 +901,13 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8

         with gr.TabItem("Submit Query"):
             with gr.Row():
+                # Add this line to define the query_model_dropdown
+                query_model_dropdown = gr.Dropdown(
+                    choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
+                    value="🇺🇸 Remote Meta-Llama-3",
+                    label="Query Model"
+                )
+
                 query_input = gr.Textbox(
                     label="Enter your query here",
                     placeholder="Type your query",
@@ -1007,6 +998,13 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=[load_response, load_context, model_output]
     )

+    # Also sync in the other direction
+    query_model_dropdown.change(
+        fn=sync_model_dropdown,
+        inputs=query_model_dropdown,
+        outputs=model_dropdown
+    )
+
     submit_button.click(
         submit_query_async,
         inputs=[query_input, query_model_dropdown],
@@ -1044,19 +1042,13 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
         outputs=[reset_response, reset_context, reset_model]
    )

+
     model_dropdown.change(
         fn=sync_model_dropdown,
         inputs=model_dropdown,
         outputs=query_model_dropdown
     )

-    # Also sync in the other direction
-    query_model_dropdown.change(
-        fn=sync_model_dropdown,
-        inputs=query_model_dropdown,
-        outputs=model_dropdown
-    )
-
     # Add an event to refresh the job list on page load
     app.load(
         fn=refresh_job_list,