Spaces:
Sleeping
Sleeping
bug fixes
Browse files- gemmademo/_chat.py +11 -3
- gemmademo/_model.py +0 -1
gemmademo/_chat.py
CHANGED
|
@@ -17,13 +17,16 @@ class GradioChat:
|
|
| 17 |
def __init__(self, model_options: list[str], task_options: list[str]):
|
| 18 |
self.model_options = model_options
|
| 19 |
self.task_options = task_options
|
| 20 |
-
|
| 21 |
-
self.
|
|
|
|
| 22 |
|
| 23 |
# Load model lazily on first use instead of at initialization
|
| 24 |
self.model = None
|
| 25 |
self.prompt_manager = self._load_task(self.current_task_name)
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def _load_model(self, model_name: str):
|
| 29 |
"""Loads the model dynamically when switching models, with caching."""
|
|
@@ -48,11 +51,16 @@ class GradioChat:
|
|
| 48 |
if selected_model != self.current_model_name:
|
| 49 |
self.current_model_name = selected_model
|
| 50 |
self.model = self._load_model(selected_model)
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Reload task if changed
|
| 53 |
if selected_task != self.current_task_name:
|
| 54 |
self.current_task_name = selected_task
|
| 55 |
self.prompt_manager = self._load_task(selected_task)
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Generate response using updated model & prompt manager
|
| 58 |
prompt = self.prompt_manager.get_prompt(user_input=message)
|
|
|
|
| 17 |
def __init__(self, model_options: list[str], task_options: list[str]):
|
| 18 |
self.model_options = model_options
|
| 19 |
self.task_options = task_options
|
| 20 |
+
|
| 21 |
+
self.current_model_name = "gemma-3b"
|
| 22 |
+
self.current_task_name = "Question Answering"
|
| 23 |
|
| 24 |
# Load model lazily on first use instead of at initialization
|
| 25 |
self.model = None
|
| 26 |
self.prompt_manager = self._load_task(self.current_task_name)
|
| 27 |
+
|
| 28 |
+
# Cache.
|
| 29 |
+
self.models_cache = {}
|
| 30 |
|
| 31 |
def _load_model(self, model_name: str):
|
| 32 |
"""Loads the model dynamically when switching models, with caching."""
|
|
|
|
| 51 |
if selected_model != self.current_model_name:
|
| 52 |
self.current_model_name = selected_model
|
| 53 |
self.model = self._load_model(selected_model)
|
| 54 |
+
# Clear message history when model changes
|
| 55 |
+
self.model.messages = []
|
| 56 |
|
| 57 |
# Reload task if changed
|
| 58 |
if selected_task != self.current_task_name:
|
| 59 |
self.current_task_name = selected_task
|
| 60 |
self.prompt_manager = self._load_task(selected_task)
|
| 61 |
+
# Clear message history when task changes
|
| 62 |
+
if self.model:
|
| 63 |
+
self.model.messages = []
|
| 64 |
|
| 65 |
# Generate response using updated model & prompt manager
|
| 66 |
prompt = self.prompt_manager.get_prompt(user_input=message)
|
gemmademo/_model.py
CHANGED
|
@@ -120,7 +120,6 @@ class LlamaCppGemmaModel:
|
|
| 120 |
if downloaded_path != model_path:
|
| 121 |
os.rename(downloaded_path, model_path)
|
| 122 |
|
| 123 |
-
# Use optimized thread settings (fewer threads often works better)
|
| 124 |
_threads = min(2, os.cpu_count() or 1)
|
| 125 |
|
| 126 |
self.model = Llama(
|
|
|
|
| 120 |
if downloaded_path != model_path:
|
| 121 |
os.rename(downloaded_path, model_path)
|
| 122 |
|
|
|
|
| 123 |
_threads = min(2, os.cpu_count() or 1)
|
| 124 |
|
| 125 |
self.model = Llama(
|