aadya1762 committed on
Commit
1719ee5
·
1 Parent(s): 6aec7fd
Files changed (2) hide show
  1. app.py +1 -3
  2. gemmademo/_chat.py +17 -8
app.py CHANGED
@@ -6,10 +6,8 @@ def main():
6
  model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
7
  task_options = ["Question Answering", "Text Generation", "Code Completion"]
8
 
9
- model = LlamaCppGemmaModel(name="gemma-2b-it")
10
- model.load_model()
11
  prompt_manager = PromptManager(task="Question Answering")
12
- chat = GradioChat(model=model, prompt_manager=prompt_manager, model_options=model_options, task_options=task_options)
13
  chat.run()
14
 
15
  if __name__ == "__main__":
 
6
  model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
7
  task_options = ["Question Answering", "Text Generation", "Code Completion"]
8
 
 
 
9
  prompt_manager = PromptManager(task="Question Answering")
10
+ chat = GradioChat(prompt_manager=prompt_manager, model_options=model_options, task_options=task_options)
11
  chat.run()
12
 
13
  if __name__ == "__main__":
gemmademo/_chat.py CHANGED
@@ -9,22 +9,28 @@ class GradioChat:
9
 
10
  Features:
11
  - A Gradio-based chatbot UI.
12
- - Maintains chat history automatically.
13
- - Uses Gemma (Hugging Face) model for generating responses.
14
  - Formats user inputs before sending them to the model.
15
  """
16
 
17
- def __init__(self, model: LlamaCppGemmaModel, prompt_manager: PromptManager, model_options: list[str], task_options: list[str]):
18
- self.model = model
19
  self.prompt_manager = prompt_manager
20
  self.model_options = model_options
21
  self.task_options = task_options
 
 
22
 
23
- def run(self):
24
- self._chat()
 
25
 
26
  def _chat(self):
27
- def chat_fn(history, message):
 
 
 
 
28
  prompt = self.prompt_manager.get_prompt(user_input=message)
29
  response = self.model.generate_response(prompt)
30
  return response
@@ -33,8 +39,11 @@ class GradioChat:
33
  chat_fn,
34
  textbox=gr.Textbox(placeholder="What is up?", container=False),
35
  additional_inputs=[
36
- gr.Dropdown(choices=self.model_options, value="gemma-2b-it", label="Select Gemma Model"),
37
  gr.Dropdown(choices=self.task_options, value="Question Answering", label="Select Task"),
38
  ],
39
  )
40
  chat_interface.launch()
 
 
 
 
9
 
10
  Features:
11
  - A Gradio-based chatbot UI.
12
+ - Dynamically loads models based on user selection.
13
+ - Uses Gemma (llama.cpp) for generating responses.
14
  - Formats user inputs before sending them to the model.
15
  """
16
 
17
+ def __init__(self, prompt_manager: PromptManager, model_options: list[str], task_options: list[str]):
 
18
  self.prompt_manager = prompt_manager
19
  self.model_options = model_options
20
  self.task_options = task_options
21
+ self.current_model_name = "gemma-2b-it" # Default model
22
+ self.model = self._load_model(self.current_model_name)
23
 
24
+ def _load_model(self, model_name: str):
25
+ """Loads the model dynamically when switching models."""
26
+ return LlamaCppGemmaModel(name=model_name).load_model()
27
 
28
  def _chat(self):
29
+ def chat_fn(history, message, selected_model):
30
+ if selected_model != self.current_model_name:
31
+ self.current_model_name = selected_model
32
+ self.model = self._load_model(selected_model) # Reload model when changed
33
+
34
  prompt = self.prompt_manager.get_prompt(user_input=message)
35
  response = self.model.generate_response(prompt)
36
  return response
 
39
  chat_fn,
40
  textbox=gr.Textbox(placeholder="What is up?", container=False),
41
  additional_inputs=[
42
+ gr.Dropdown(choices=self.model_options, value=self.current_model_name, label="Select Gemma Model"),
43
  gr.Dropdown(choices=self.task_options, value="Question Answering", label="Select Task"),
44
  ],
45
  )
46
  chat_interface.launch()
47
+
48
+ def run(self):
49
+ self._chat()