aadya1762 commited on
Commit
e4ef2eb
·
1 Parent(s): 4ffa911
Files changed (2) hide show
  1. app.py +1 -2
  2. gemmademo/_chat.py +18 -6
app.py CHANGED
@@ -6,8 +6,7 @@ def main():
6
  model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
7
  task_options = ["Question Answering", "Text Generation", "Code Completion"]
8
 
9
- prompt_manager = PromptManager(task="Question Answering")
10
- chat = GradioChat(prompt_manager=prompt_manager, model_options=model_options, task_options=task_options)
11
  chat.run()
12
 
13
  if __name__ == "__main__":
 
6
  model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
7
  task_options = ["Question Answering", "Text Generation", "Code Completion"]
8
 
9
+ chat = GradioChat(model_options=model_options, task_options=task_options)
 
10
  chat.run()
11
 
12
  if __name__ == "__main__":
gemmademo/_chat.py CHANGED
@@ -10,28 +10,40 @@ class GradioChat:
10
  Features:
11
  - A Gradio-based chatbot UI.
12
  - Dynamically loads models based on user selection.
 
13
  - Uses Gemma (llama.cpp) for generating responses.
14
- - Formats user inputs before sending them to the model.
15
  """
16
 
17
- def __init__(self, prompt_manager: PromptManager, model_options: list[str], task_options: list[str]):
18
- self.prompt_manager = prompt_manager
19
  self.model_options = model_options
20
  self.task_options = task_options
21
  self.current_model_name = "gemma-2b-it" # Default model
22
  self.model = self._load_model(self.current_model_name)
 
 
23
 
24
  def _load_model(self, model_name: str):
25
  """Loads the model dynamically when switching models."""
26
  return LlamaCppGemmaModel(name=model_name).load_model()
27
 
 
 
 
 
28
  def _chat(self):
29
  def chat_fn(message, history, selected_model, selected_task):
 
30
  if selected_model != self.current_model_name:
31
  self.current_model_name = selected_model
32
- self.model = self._load_model(selected_model) # Reload model when changed
 
 
 
 
 
33
 
34
- prompt = self.prompt_manager.get_prompt(user_input=message, task=selected_task)
 
35
  response = self.model.generate_response(prompt)
36
  return response
37
 
@@ -40,7 +52,7 @@ class GradioChat:
40
  textbox=gr.Textbox(placeholder="Ask me something...", container=False),
41
  additional_inputs=[
42
  gr.Dropdown(choices=self.model_options, value=self.current_model_name, label="Select Gemma Model"),
43
- gr.Dropdown(choices=self.task_options, value="Question Answering", label="Select Task"),
44
  ],
45
  )
46
  chat_interface.launch()
 
10
  Features:
11
  - A Gradio-based chatbot UI.
12
  - Dynamically loads models based on user selection.
13
+ - Dynamically updates tasks using PromptManager.
14
  - Uses Gemma (llama.cpp) for generating responses.
 
15
  """
16
 
17
+ def __init__(self, model_options: list[str], task_options: list[str]):
 
18
  self.model_options = model_options
19
  self.task_options = task_options
20
  self.current_model_name = "gemma-2b-it" # Default model
21
  self.model = self._load_model(self.current_model_name)
22
+ self.current_task_name = "Question Answering" # Default task
23
+ self.prompt_manager = self._load_task(self.current_task_name)
24
 
25
  def _load_model(self, model_name: str):
26
  """Loads the model dynamically when switching models."""
27
  return LlamaCppGemmaModel(name=model_name).load_model()
28
 
29
+ def _load_task(self, task_name: str):
30
+ """Loads the task dynamically when switching tasks."""
31
+ return PromptManager(task=task_name)
32
+
33
  def _chat(self):
34
  def chat_fn(message, history, selected_model, selected_task):
35
+ # Reload model if changed
36
  if selected_model != self.current_model_name:
37
  self.current_model_name = selected_model
38
+ self.model = self._load_model(selected_model)
39
+
40
+ # Reload task if changed
41
+ if selected_task != self.current_task_name:
42
+ self.current_task_name = selected_task
43
+ self.prompt_manager = self._load_task(selected_task)
44
 
45
+ # Generate response using updated model & prompt manager
46
+ prompt = self.prompt_manager.get_prompt(user_input=message)
47
  response = self.model.generate_response(prompt)
48
  return response
49
 
 
52
  textbox=gr.Textbox(placeholder="Ask me something...", container=False),
53
  additional_inputs=[
54
  gr.Dropdown(choices=self.model_options, value=self.current_model_name, label="Select Gemma Model"),
55
+ gr.Dropdown(choices=self.task_options, value=self.current_task_name, label="Select Task"),
56
  ],
57
  )
58
  chat_interface.launch()