aadya1762 committed on
Commit
581c860
·
1 Parent(s): 5f43529

bug fixes

Browse files
Files changed (2) hide show
  1. gemmademo/_chat.py +11 -3
  2. gemmademo/_model.py +0 -1
gemmademo/_chat.py CHANGED
@@ -17,13 +17,16 @@ class GradioChat:
17
  def __init__(self, model_options: list[str], task_options: list[str]):
18
  self.model_options = model_options
19
  self.task_options = task_options
20
- self.current_model_name = "gemma-3b" # Default model
21
- self.current_task_name = "Question Answering" # Default task
 
22
 
23
  # Load model lazily on first use instead of at initialization
24
  self.model = None
25
  self.prompt_manager = self._load_task(self.current_task_name)
26
- self.models_cache = {} # Cache for loaded models
 
 
27
 
28
  def _load_model(self, model_name: str):
29
  """Loads the model dynamically when switching models, with caching."""
@@ -48,11 +51,16 @@ class GradioChat:
48
  if selected_model != self.current_model_name:
49
  self.current_model_name = selected_model
50
  self.model = self._load_model(selected_model)
 
 
51
 
52
  # Reload task if changed
53
  if selected_task != self.current_task_name:
54
  self.current_task_name = selected_task
55
  self.prompt_manager = self._load_task(selected_task)
 
 
 
56
 
57
  # Generate response using updated model & prompt manager
58
  prompt = self.prompt_manager.get_prompt(user_input=message)
 
17
  def __init__(self, model_options: list[str], task_options: list[str]):
18
  self.model_options = model_options
19
  self.task_options = task_options
20
+
21
+ self.current_model_name = "gemma-3b"
22
+ self.current_task_name = "Question Answering"
23
 
24
  # Load model lazily on first use instead of at initialization
25
  self.model = None
26
  self.prompt_manager = self._load_task(self.current_task_name)
27
+
28
+ # Cache.
29
+ self.models_cache = {}
30
 
31
  def _load_model(self, model_name: str):
32
  """Loads the model dynamically when switching models, with caching."""
 
51
  if selected_model != self.current_model_name:
52
  self.current_model_name = selected_model
53
  self.model = self._load_model(selected_model)
54
+ # Clear message history when model changes
55
+ self.model.messages = []
56
 
57
  # Reload task if changed
58
  if selected_task != self.current_task_name:
59
  self.current_task_name = selected_task
60
  self.prompt_manager = self._load_task(selected_task)
61
+ # Clear message history when task changes
62
+ if self.model:
63
+ self.model.messages = []
64
 
65
  # Generate response using updated model & prompt manager
66
  prompt = self.prompt_manager.get_prompt(user_input=message)
gemmademo/_model.py CHANGED
@@ -120,7 +120,6 @@ class LlamaCppGemmaModel:
120
  if downloaded_path != model_path:
121
  os.rename(downloaded_path, model_path)
122
 
123
- # Use optimized thread settings (fewer threads often works better)
124
  _threads = min(2, os.cpu_count() or 1)
125
 
126
  self.model = Llama(
 
120
  if downloaded_path != model_path:
121
  os.rename(downloaded_path, model_path)
122
 
 
123
  _threads = min(2, os.cpu_count() or 1)
124
 
125
  self.model = Llama(