Spaces:
Sleeping
Sleeping
fix
Browse files
- app.py +1 -2
- gemmademo/_chat.py +18 -6
app.py
CHANGED
|
@@ -6,8 +6,7 @@ def main():
|
|
| 6 |
model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
|
| 7 |
task_options = ["Question Answering", "Text Generation", "Code Completion"]
|
| 8 |
|
| 9 |
-
|
| 10 |
- chat = GradioChat(prompt_manager=prompt_manager, model_options=model_options, task_options=task_options)
|
| 11 |
chat.run()
|
| 12 |
|
| 13 |
if __name__ == "__main__":
|
|
|
|
| 6 |
model_options = list(LlamaCppGemmaModel.AVAILABLE_MODELS.keys())
|
| 7 |
task_options = ["Question Answering", "Text Generation", "Code Completion"]
|
| 8 |
|
| 9 |
+ chat = GradioChat(model_options=model_options, task_options=task_options)
|
|
|
|
| 10 |
chat.run()
|
| 11 |
|
| 12 |
if __name__ == "__main__":
|
gemmademo/_chat.py
CHANGED
|
@@ -10,28 +10,40 @@ class GradioChat:
|
|
| 10 |
Features:
|
| 11 |
- A Gradio-based chatbot UI.
|
| 12 |
- Dynamically loads models based on user selection.
|
|
|
|
| 13 |
- Uses Gemma (llama.cpp) for generating responses.
|
| 14 |
- - Formats user inputs before sending them to the model.
|
| 15 |
"""
|
| 16 |
|
| 17 |
- def __init__(self,
|
| 18 |
- self.prompt_manager = prompt_manager
|
| 19 |
self.model_options = model_options
|
| 20 |
self.task_options = task_options
|
| 21 |
self.current_model_name = "gemma-2b-it" # Default model
|
| 22 |
self.model = self._load_model(self.current_model_name)
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def _load_model(self, model_name: str):
|
| 25 |
"""Loads the model dynamically when switching models."""
|
| 26 |
return LlamaCppGemmaModel(name=model_name).load_model()
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def _chat(self):
|
| 29 |
def chat_fn(message, history, selected_model, selected_task):
|
|
|
|
| 30 |
if selected_model != self.current_model_name:
|
| 31 |
self.current_model_name = selected_model
|
| 32 |
- self.model = self._load_model(selected_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
|
|
|
| 35 |
response = self.model.generate_response(prompt)
|
| 36 |
return response
|
| 37 |
|
|
@@ -40,7 +52,7 @@ class GradioChat:
|
|
| 40 |
textbox=gr.Textbox(placeholder="Ask me something...", container=False),
|
| 41 |
additional_inputs=[
|
| 42 |
gr.Dropdown(choices=self.model_options, value=self.current_model_name, label="Select Gemma Model"),
|
| 43 |
- gr.Dropdown(choices=self.task_options, value=
|
| 44 |
],
|
| 45 |
)
|
| 46 |
chat_interface.launch()
|
|
|
|
| 10 |
Features:
|
| 11 |
- A Gradio-based chatbot UI.
|
| 12 |
- Dynamically loads models based on user selection.
|
| 13 |
+ - Dynamically updates tasks using PromptManager.
|
| 14 |
- Uses Gemma (llama.cpp) for generating responses.
|
|
|
|
| 15 |
"""
|
| 16 |
|
| 17 |
+ def __init__(self, model_options: list[str], task_options: list[str]):
|
|
|
|
| 18 |
self.model_options = model_options
|
| 19 |
self.task_options = task_options
|
| 20 |
self.current_model_name = "gemma-2b-it" # Default model
|
| 21 |
self.model = self._load_model(self.current_model_name)
|
| 22 |
+ self.current_task_name = "Question Answering" # Default task
|
| 23 |
+ self.prompt_manager = self._load_task(self.current_task_name)
|
| 24 |
|
| 25 |
def _load_model(self, model_name: str):
|
| 26 |
"""Loads the model dynamically when switching models."""
|
| 27 |
return LlamaCppGemmaModel(name=model_name).load_model()
|
| 28 |
|
| 29 |
+ def _load_task(self, task_name: str):
|
| 30 |
+ """Loads the task dynamically when switching tasks."""
|
| 31 |
+ return PromptManager(task=task_name)
|
| 32 |
+
|
| 33 |
def _chat(self):
|
| 34 |
def chat_fn(message, history, selected_model, selected_task):
|
| 35 |
+ # Reload model if changed
|
| 36 |
if selected_model != self.current_model_name:
|
| 37 |
self.current_model_name = selected_model
|
| 38 |
+ self.model = self._load_model(selected_model)
|
| 39 |
+
|
| 40 |
+ # Reload task if changed
|
| 41 |
+ if selected_task != self.current_task_name:
|
| 42 |
+ self.current_task_name = selected_task
|
| 43 |
+ self.prompt_manager = self._load_task(selected_task)
|
| 44 |
|
| 45 |
+ # Generate response using updated model & prompt manager
|
| 46 |
+ prompt = self.prompt_manager.get_prompt(user_input=message)
|
| 47 |
response = self.model.generate_response(prompt)
|
| 48 |
return response
|
| 49 |
|
|
|
|
| 52 |
textbox=gr.Textbox(placeholder="Ask me something...", container=False),
|
| 53 |
additional_inputs=[
|
| 54 |
gr.Dropdown(choices=self.model_options, value=self.current_model_name, label="Select Gemma Model"),
|
| 55 |
+ gr.Dropdown(choices=self.task_options, value=self.current_task_name, label="Select Task"),
|
| 56 |
],
|
| 57 |
)
|
| 58 |
chat_interface.launch()
|