aadya1762 committed on
Commit
5f43529
·
1 Parent(s): 0304bfe
Files changed (2) hide show
  1. gemmademo/_chat.py +15 -4
  2. gemmademo/_model.py +42 -8
gemmademo/_chat.py CHANGED
@@ -20,12 +20,19 @@ class GradioChat:
20
  self.current_model_name = "gemma-3b" # Default model
21
  self.current_task_name = "Question Answering" # Default task
22
 
23
- self.model = self._load_model(self.current_model_name)
 
24
  self.prompt_manager = self._load_task(self.current_task_name)
 
25
 
26
  def _load_model(self, model_name: str):
27
- """Loads the model dynamically when switching models."""
28
- return LlamaCppGemmaModel(name=model_name).load_model()
 
 
 
 
 
29
 
30
  def _load_task(self, task_name: str):
31
  """Loads the task dynamically when switching tasks."""
@@ -33,7 +40,11 @@ class GradioChat:
33
 
34
  def _chat(self):
35
  def chat_fn(message, history, selected_model, selected_task):
36
- # Reload model if changed
 
 
 
 
37
  if selected_model != self.current_model_name:
38
  self.current_model_name = selected_model
39
  self.model = self._load_model(selected_model)
 
20
  self.current_model_name = "gemma-3b" # Default model
21
  self.current_task_name = "Question Answering" # Default task
22
 
23
+ # Load model lazily on first use instead of at initialization
24
+ self.model = None
25
  self.prompt_manager = self._load_task(self.current_task_name)
26
+ self.models_cache = {} # Cache for loaded models
27
 
28
  def _load_model(self, model_name: str):
29
+ """Loads the model dynamically when switching models, with caching."""
30
+ if model_name in self.models_cache:
31
+ return self.models_cache[model_name]
32
+
33
+ model = LlamaCppGemmaModel(name=model_name).load_model()
34
+ self.models_cache[model_name] = model
35
+ return model
36
 
37
  def _load_task(self, task_name: str):
38
  """Loads the task dynamically when switching tasks."""
 
40
 
41
  def _chat(self):
42
  def chat_fn(message, history, selected_model, selected_task):
43
+ # Lazy load model on first use
44
+ if self.model is None:
45
+ self.model = self._load_model(self.current_model_name)
46
+
47
+ # Reload model if changed, using cache when possible
48
  if selected_model != self.current_model_name:
49
  self.current_model_name = selected_model
50
  self.model = self._load_model(selected_model)
gemmademo/_model.py CHANGED
@@ -19,13 +19,23 @@ class LlamaCppGemmaModel:
19
  All models will be stored in the "models/" directory.
20
  """
21
 
 
 
 
22
  AVAILABLE_MODELS: Dict[str, Dict] = {
23
  "gemma-3b": {
24
- "model_path": "models/gemma3.gguf",
25
- "repo_id": "unsloth/gemma-3-1b-it-GGUF", # update to the actual repo id
26
- "filename": "gemma-3-1b-it-Q3_K_M.gguf",
27
- "description": "3B parameters, base model",
28
- "type": "base",
 
 
 
 
 
 
 
29
  },
30
  "gemma-2b": {
31
  "model_path": "models/gemma-2b.gguf",
@@ -71,11 +81,18 @@ class LlamaCppGemmaModel:
71
  def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
72
  """
73
  Load the model. If the model file does not exist, it will be downloaded.
 
74
 
75
  Args:
76
  n_ctx (int): Context window size.
77
  n_gpu_layers (int): Number of layers to offload to GPU (if supported; 0 for CPU-only).
78
  """
 
 
 
 
 
 
79
  model_info = self.AVAILABLE_MODELS.get(self.name)
80
  if not model_info:
81
  raise ValueError(f"Model {self.name} is not available.")
@@ -103,7 +120,8 @@ class LlamaCppGemmaModel:
103
  if downloaded_path != model_path:
104
  os.rename(downloaded_path, model_path)
105
 
106
- _threads = os.cpu_count()
 
107
 
108
  self.model = Llama(
109
  model_path=model_path,
@@ -112,19 +130,32 @@ class LlamaCppGemmaModel:
112
  n_ctx=n_ctx,
113
  n_gpu_layers=n_gpu_layers,
114
  n_batch=8,
 
115
  )
 
 
 
116
  return self
117
 
118
  def generate_response(
119
- self, prompt: str, max_tokens: int = 512, temperature: float = 0.1
 
 
 
 
 
 
120
  ):
121
  """
122
- Generate a response using the llama.cpp model.
123
 
124
  Args:
125
  prompt (str): Input prompt text.
126
  max_tokens (int): Maximum number of tokens to generate.
127
  temperature (float): Sampling temperature (higher = more creative).
 
 
 
128
 
129
  Yields:
130
  str: Generated response text as a stream.
@@ -138,6 +169,9 @@ class LlamaCppGemmaModel:
138
  messages=self.messages,
139
  max_tokens=max_tokens,
140
  temperature=temperature,
 
 
 
141
  stream=True,
142
  )
143
  self.messages.append({"role": "assistant", "content": ""})
 
19
  All models will be stored in the "models/" directory.
20
  """
21
 
22
+ # Class variable to cache loaded models
23
+ _model_cache = {}
24
+
25
  AVAILABLE_MODELS: Dict[str, Dict] = {
26
  "gemma-3b": {
27
+ "model_path": "models/gemma-3-1b-it-Q5_K_M.gguf",
28
+ "repo_id": "bartowski/google_gemma-3-1b-it-GGUF", # Updated repo
29
+ "filename": "google_gemma-3-1b-it-Q5_K_M.gguf", # Better quantization
30
+ "description": "3B parameters, instruction-tuned (Q5_K_M)",
31
+ "type": "instruct",
32
+ },
33
+ "gemma-3b-q6": {
34
+ "model_path": "models/gemma-3-1b-it-Q6_K.gguf",
35
+ "repo_id": "bartowski/google_gemma-3-1b-it-GGUF", # Updated repo
36
+ "filename": "google_gemma-3-1b-it-Q6_K.gguf", # Higher quality quantization
37
+ "description": "3B parameters, instruction-tuned (Q6_K)",
38
+ "type": "instruct",
39
  },
40
  "gemma-2b": {
41
  "model_path": "models/gemma-2b.gguf",
 
81
  def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
82
  """
83
  Load the model. If the model file does not exist, it will be downloaded.
84
+ Uses caching to avoid reloading models unnecessarily.
85
 
86
  Args:
87
  n_ctx (int): Context window size.
88
  n_gpu_layers (int): Number of layers to offload to GPU (if supported; 0 for CPU-only).
89
  """
90
+ # Check if model is already loaded in cache
91
+ cache_key = f"{self.name}_{n_ctx}_{n_gpu_layers}"
92
+ if cache_key in LlamaCppGemmaModel._model_cache:
93
+ self.model = LlamaCppGemmaModel._model_cache[cache_key]
94
+ return self
95
+
96
  model_info = self.AVAILABLE_MODELS.get(self.name)
97
  if not model_info:
98
  raise ValueError(f"Model {self.name} is not available.")
 
120
  if downloaded_path != model_path:
121
  os.rename(downloaded_path, model_path)
122
 
123
+ # Use optimized thread settings (fewer threads often works better)
124
+ _threads = min(2, os.cpu_count() or 1)
125
 
126
  self.model = Llama(
127
  model_path=model_path,
 
130
  n_ctx=n_ctx,
131
  n_gpu_layers=n_gpu_layers,
132
  n_batch=8,
133
+ verbose=False, # Disable verbose output for better performance
134
  )
135
+
136
+ # Cache the model for future use
137
+ LlamaCppGemmaModel._model_cache[cache_key] = self.model
138
  return self
139
 
140
  def generate_response(
141
+ self,
142
+ prompt: str,
143
+ max_tokens: int = 512,
144
+ temperature: float = 0.7,
145
+ top_p: float = 0.95,
146
+ top_k: int = 40,
147
+ repeat_penalty: float = 1.1,
148
  ):
149
  """
150
+ Generate a response using the llama.cpp model with optimized parameters.
151
 
152
  Args:
153
  prompt (str): Input prompt text.
154
  max_tokens (int): Maximum number of tokens to generate.
155
  temperature (float): Sampling temperature (higher = more creative).
156
+ top_p (float): Nucleus sampling threshold.
157
+ top_k (int): Limit vocabulary choices to top K tokens.
158
+ repeat_penalty (float): Penalize repeated words.
159
 
160
  Yields:
161
  str: Generated response text as a stream.
 
169
  messages=self.messages,
170
  max_tokens=max_tokens,
171
  temperature=temperature,
172
+ top_p=top_p,
173
+ top_k=top_k,
174
+ repeat_penalty=repeat_penalty,
175
  stream=True,
176
  )
177
  self.messages.append({"role": "assistant", "content": ""})