alx-d committed
Commit 23b48b8 · verified · 1 Parent(s): 4a725a5

Upload folder using huggingface_hub
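A commit message like this one comes from the hub's standard folder-upload flow. A minimal sketch of that call, assuming a placeholder local path and repo id (these are not values recorded on this page) and assuming the target repo is a Gradio Space:

from huggingface_hub import HfApi

api = HfApi()  # picks up the HF token from the environment or a prior login
api.upload_folder(
    folder_path=".",               # local folder to push (placeholder)
    repo_id="alx-d/advanced-rag",  # placeholder repo id, not from this commit
    repo_type="space",             # assumption: a Gradio Space hosts advanced_rag.py
    commit_message="Upload folder using huggingface_hub",
)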

Files changed (1)
advanced_rag.py +130 -138
advanced_rag.py CHANGED
@@ -81,55 +81,27 @@ jobs = {} # Stores job status and results
 results_queue = queue.Queue()  # Thread-safe queue for completed jobs
 processing_lock = threading.Lock()  # Prevent simultaneous processing of the same job
 
-# Function to display all jobs as a clickable list
-def get_job_list():
-    job_list_md = "### Submitted Jobs\n\n"
-
-    if not jobs:
-        return "No jobs found. Submit a query or load files to create jobs."
-
-    # Sort jobs by start time (newest first)
-    sorted_jobs = sorted(
-        [(job_id, job_info) for job_id, job_info in jobs.items()],
-        key=lambda x: x[1].get("start_time", 0),
-        reverse=True
-    )
-
-    for job_id, job_info in sorted_jobs:
-        status = job_info.get("status", "unknown")
-        job_type = job_info.get("type", "unknown")
-        query = job_info.get("query", "")
-        start_time = job_info.get("start_time", 0)
-        time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-
-        # Create a shortened query preview
-        query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
-
-        # Create clickable links using Markdown
-        if job_type == "query":
-            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
-        else:
-            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
-
-    return job_list_md
+# Add these missing async processing functions
 
-# Function to process tasks in background
-def process_in_background(job_id: str, function, args):
+def process_in_background(job_id, function, args):
+    """Process a function in the background and store results"""
     try:
+        debug_print(f"Processing job {job_id} in background")
         result = function(*args)
         results_queue.put((job_id, result))
+        debug_print(f"Job {job_id} completed and added to results queue")
     except Exception as e:
-        error_msg = f"Error: {str(e)}\n\nTraceback: {traceback.format_exc()}"
-        debug_print(f"Job {job_id} failed: {error_msg}")
-        results_queue.put((job_id, (error_msg, "", "", "")))
+        debug_print(f"Error in background job {job_id}: {str(e)}")
+        error_result = (f"Error processing job: {str(e)}", "", "", "")
+        results_queue.put((job_id, error_result))
 
-# Async version of load_pdfs_updated
 def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
+    """Asynchronous version of load_pdfs_updated to prevent timeouts"""
    if not file_links:
-        return "Please enter non-empty URLs", "", "Model used: N/A", "Context: N/A"
+        return "Please enter non-empty URLs", "", "Model used: N/A"
 
     job_id = str(uuid.uuid4())
-    debug_print(f"Starting async job {job_id} for loading files")
+    debug_print(f"Starting async job {job_id} for file loading")
 
     # Start background thread
     threading.Thread(
@@ -138,41 +110,33 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp
     ).start()
 
     jobs[job_id] = {
-        "status": "processing",
+        "status": "processing",
         "type": "load_files",
-        "start_time": time.time()
+        "start_time": time.time(),
+        "query": f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
     }
 
     return (
-        f"Files are being processed in the background (Job ID: {job_id}).\n\n"
-        f"Use 'Check Job Status' with this ID to get results.",
+        f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
+        f"Use 'Check Job Status' tab with this ID to get results.",
         f"Job ID: {job_id}",
-        f"Model selected: {model_choice}"
+        f"Model requested: {model_choice}"
     )
 
-# Async version of submit_query_updated
 def submit_query_async(query, model_choice=None):
+    """Asynchronous version of submit_query_updated to prevent timeouts"""
     if not query:
         return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
 
-    if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
-        return "Please load files first", "", "Input tokens: 0", "Output tokens: 0"
-
-    # Use the provided model if specified, otherwise use the current model
-    if model_choice and model_choice != "":
-        # Update the model temporarily for this query
-        current_model = rag_chain.llm_choice
-        rag_chain.update_llm_pipeline(
-            model_choice,
-            rag_chain.temperature,
-            rag_chain.top_p,
-            rag_chain.prompt_template,
-            rag_chain.bm25_weight
-        )
-
     job_id = str(uuid.uuid4())
     debug_print(f"Starting async job {job_id} for query: {query}")
 
+    # Update model if specified
+    if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
+        debug_print(f"Updating model to {model_choice} for this query")
+        rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
+                                      rag_chain.prompt_template, rag_chain.bm25_weight)
+
     # Start background thread
     threading.Thread(
         target=process_in_background,
@@ -184,16 +148,48 @@ def submit_query_async(query, model_choice=None):
         "type": "query",
         "start_time": time.time(),
         "query": query,
-        "model": rag_chain.llm_choice  # Store which model is being used
+        "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
     }
 
     return (
         f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
-        f"Use 'Check Job Status' with this ID to get results.",
+        f"Use 'Check Job Status' tab with this ID to get results.",
         f"Job ID: {job_id}",
         f"Input tokens: {count_tokens(query)}",
         "Output tokens: pending"
     )
+
+# Function to display all jobs as a clickable list
+def get_job_list():
+    job_list_md = "### Submitted Jobs\n\n"
+
+    if not jobs:
+        return "No jobs found. Submit a query or load files to create jobs."
+
+    # Sort jobs by start time (newest first)
+    sorted_jobs = sorted(
+        [(job_id, job_info) for job_id, job_info in jobs.items()],
+        key=lambda x: x[1].get("start_time", 0),
+        reverse=True
+    )
+
+    for job_id, job_info in sorted_jobs:
+        status = job_info.get("status", "unknown")
+        job_type = job_info.get("type", "unknown")
+        query = job_info.get("query", "")
+        start_time = job_info.get("start_time", 0)
+        time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
+
+        # Create a shortened query preview
+        query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
+
+        # Create clickable links using Markdown
+        if job_type == "query":
+            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
+        else:
+            job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
+
+    return job_list_md
 
 # Function to handle job list clicks
 def job_selected(job_id):
@@ -394,6 +390,7 @@ class ElevatedRagChain:
 
     # Improve error handling in the ElevatedRagChain class
     def create_llm_pipeline(self):
+        from langchain.llms.base import LLM  # Import LLM here so it's always defined
         normalized = self.llm_choice.lower()
         try:
             if "remote" in normalized:
@@ -406,7 +403,7 @@ class ElevatedRagChain:
 
                 client = InferenceClient(token=hf_api_token, timeout=120)
 
-                from huggingface_hub.utils._errors import HfHubHTTPError
+                # We no longer use wait_for_model because it's unsupported
                 def remote_generate(prompt: str) -> str:
                     max_retries = 3
                     backoff = 2  # start with 2 seconds
@@ -434,7 +431,7 @@ class ElevatedRagChain:
                     def _llm_type(self) -> str:
                         return "remote_llm"
 
-                    def _call(self, prompt: str, stop: typing.Optional[List[str]] = None) -> str:
+                    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
                         return remote_generate(prompt)
 
                     @property
@@ -444,68 +441,57 @@ class ElevatedRagChain:
                 debug_print("Remote Meta-Llama-3 pipeline created successfully.")
                 return RemoteLLM()
 
-            elif "mistral" in normalized:
+            elif "mistral-api" in normalized:
                 debug_print("Creating Mistral API pipeline...")
                 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
                 if not mistral_api_key:
                     raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
-
-                # Import Mistral library with proper error handling
                 try:
                     from mistralai import Mistral
                     from mistralai.exceptions import MistralException
                     debug_print("Mistral library imported successfully")
                 except ImportError:
-                    raise ImportError("Mistral client library not found. Install with: pip install mistralai")
-
-                # Fixed MistralLLM implementation that works with Pydantic v1
-                class MistralLLM(LLM):
-                    client: Optional[Any] = None
-                    temperature: float = 0.7
-                    top_p: float = 0.95
-
-                    def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
-                        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
-                        self.client = Mistral(api_key=api_key)
-                        debug_print("Mistral client initialized")
-
-                    @property
-                    def _llm_type(self) -> str:
-                        return "mistral_llm"
-
-                    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-                        try:
-                            debug_print("Calling Mistral API...")
-                            response = self.client.chat.complete(
-                                model="mistral-small-latest",
-                                messages=[{"role": "user", "content": prompt}],
-                                temperature=self.temperature,
-                                top_p=self.top_p,
-                                max_tokens=1024  # Limit token count for faster response
-                            )
-                            return response.choices[0].message.content
-                        except Exception as e:
-                            debug_print(f"Mistral API error: {str(e)}")
-                            return f"Error generating response: {str(e)}"
-
-                    @property
-                    def _identifying_params(self) -> dict:
-                        return {"model": "mistral-small-latest"}
-
-                debug_print("Creating Mistral LLM instance")
-                mistral_llm = MistralLLM(
-                    api_key=mistral_api_key,
-                    temperature=self.temperature,
-                    top_p=self.top_p
-                )
-                debug_print("Mistral API pipeline created successfully.")
-                return mistral_llm
+                    debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
+                    normalized = "llama"
+                if normalized != "llama":
+                    class MistralLLM(LLM):
+                        temperature: float = 0.7
+                        top_p: float = 0.95
+                        _client: Any = PrivateAttr(default=None)
+                        def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
+                            super().__init__(**kwargs)
+                            self._client = Mistral(api_key=api_key)
+                            self.temperature = temperature
+                            self.top_p = top_p
+                        @property
+                        def _llm_type(self) -> str:
+                            return "mistral_llm"
+                        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+                            try:
+                                debug_print("Calling Mistral API...")
+                                response = self._client.chat.complete(
+                                    model="mistral-small-latest",
+                                    messages=[{"role": "user", "content": prompt}],
+                                    temperature=self.temperature,
+                                    top_p=self.top_p,
+                                    max_tokens=32000
+                                )
+                                return response.choices[0].message.content
+                            except Exception as e:
+                                debug_print(f"Mistral API error: {str(e)}")
+                                return f"Error generating response: {str(e)}"
+                        @property
+                        def _identifying_params(self) -> dict:
+                            return {"model": "mistral-small-latest"}
+                    debug_print("Creating Mistral LLM instance")
+                    mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
+                    debug_print("Mistral API pipeline created successfully.")
+                    return mistral_llm
 
             else:
-                # Default case - use a smaller model that's more likely to work within constraints
+                # Default case - using a fallback model (or Llama)
                 debug_print("Using local/fallback model pipeline")
-                model_id = "facebook/opt-350m"  # Much smaller model
-
+                model_id = "facebook/opt-350m"  # Use a smaller model as fallback
                 pipe = pipeline(
                     "text-generation",
                     model=model_id,
@@ -517,27 +503,21 @@ class ElevatedRagChain:
                    @property
                    def _llm_type(self) -> str:
                        return "local_llm"
-
                    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-                        # Aggressively truncate prompt
-                        truncated_prompt = truncate_prompt(prompt, max_tokens=512)
-                        try:
-                            generated = pipe(truncated_prompt, max_new_tokens=256)[0]["generated_text"]
-                            # Only return the newly generated part
-                            if generated.startswith(truncated_prompt):
-                                return generated[len(truncated_prompt):].strip()
-                            return generated
-                        except Exception as e:
-                            debug_print(f"Generation error: {str(e)}")
-                            return f"Error generating response: {str(e)}"
-
+                        # For this fallback, truncate prompt if it exceeds limits
+                        reserved_gen = 128
+                        max_total = 1024
+                        max_prompt_tokens = max_total - reserved_gen
+                        truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
+                        generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
+                        return generated
                    @property
                    def _identifying_params(self) -> dict:
-                        return {"model": model_id}
+                        return {"model": model_id, "max_length": 1024}
 
                debug_print("Local fallback pipeline created.")
                return LocalLLM()
-
+
         except Exception as e:
             debug_print(f"Error creating LLM pipeline: {str(e)}")
             # Return a dummy LLM that explains the error
@@ -546,7 +526,7 @@ class ElevatedRagChain:
                 def _llm_type(self) -> str:
                     return "error_llm"
 
-                def _call(self, prompt: str, stop: typing.Optional[List[str]] = None) -> str:
+                def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
                     return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
 
                 @property
@@ -555,6 +535,7 @@ class ElevatedRagChain:
 
             return ErrorLLM()
 
+
    def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
        debug_print(f"Updating chain with new model: {new_model_choice}")
        self.llm_choice = new_model_choice
@@ -624,7 +605,9 @@ class ElevatedRagChain:
            "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
            "question": RunnableLambda(self.extract_question)
        }) | self.capture_context
+        # Wrap the prompt template in a RunnableLambda
        self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
+        prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
        self.str_output_parser = StrOutputParser()
        debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
        self.llm = self.create_llm_pipeline()
@@ -637,9 +620,10 @@ class ElevatedRagChain:
            formatted += f"- **Generated using:** {self.llm_choice}\n"
            formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
            return formatted
-        self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
+        self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
        debug_print("Elevated RAG chain successfully built and ready to use.")
 
+
    def get_current_context(self) -> str:
        base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
        history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
@@ -917,6 +901,13 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
 
    with gr.TabItem("Submit Query"):
        with gr.Row():
+            # Add this line to define the query_model_dropdown
+            query_model_dropdown = gr.Dropdown(
+                choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
+                value="🇺🇸 Remote Meta-Llama-3",
+                label="Query Model"
+            )
+
            query_input = gr.Textbox(
                label="Enter your query here",
                placeholder="Type your query",
@@ -1007,6 +998,13 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
        outputs=[load_response, load_context, model_output]
    )
 
+    # Also sync in the other direction
+    query_model_dropdown.change(
+        fn=sync_model_dropdown,
+        inputs=query_model_dropdown,
+        outputs=model_dropdown
+    )
+
    submit_button.click(
        submit_query_async,
        inputs=[query_input, query_model_dropdown],
@@ -1044,19 +1042,13 @@ https://www.gutenberg.org/ebooks/8438.txt.utf-8
        outputs=[reset_response, reset_context, reset_model]
    )
 
+
    model_dropdown.change(
        fn=sync_model_dropdown,
        inputs=model_dropdown,
        outputs=query_model_dropdown
    )
 
-    # Also sync in the other direction
-    query_model_dropdown.change(
-        fn=sync_model_dropdown,
-        inputs=query_model_dropdown,
-        outputs=model_dropdown
-    )
-
    # Add an event to refresh the job list on page load
    app.load(
        fn=refresh_job_list,
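Taken together, the new side of this diff follows one pattern: submit work on a thread, have the worker push (job_id, result) onto results_queue, and let the UI poll for completion. A minimal self-contained sketch of that pattern, assuming a hypothetical poll_job helper (the jobs/results_queue/process_in_background names mirror the diff; nothing else here is taken from advanced_rag.py):

import queue
import threading
import time
import uuid

jobs = {}                      # job_id -> status/result dict, as in the diff
results_queue = queue.Queue()  # worker threads push (job_id, result) here

def process_in_background(job_id, function, args):
    # Same shape as the diff's worker: never raise, always enqueue something.
    try:
        results_queue.put((job_id, function(*args)))
    except Exception as e:
        results_queue.put((job_id, (f"Error processing job: {e}", "", "", "")))

def poll_job(job_id, timeout=10.0):
    # Hypothetical helper: drain finished jobs, then report this job's state.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            done_id, result = results_queue.get(timeout=0.25)
            jobs[done_id].update(status="done", result=result)
        except queue.Empty:
            pass
        if jobs.get(job_id, {}).get("status") == "done":
            return jobs[job_id]["result"]
    return "still processing"

job_id = str(uuid.uuid4())
jobs[job_id] = {"status": "processing", "start_time": time.time()}
threading.Thread(
    target=process_in_background,
    args=(job_id, lambda q: f"answer to {q!r}", ("what is RAG?",)),
).start()
print(poll_job(job_id))  # -> "answer to 'what is RAG?'"

Because the Gradio handlers only enqueue work and return a job ID immediately, no HTTP request ever blocks on model loading or generation, which is what the "prevent timeouts" docstrings in the diff refer to.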