aadya1762 committed on
Commit
e9e9e0c
·
1 Parent(s): 9842827

add sliders

Browse files
Files changed (3) hide show
  1. gemmademo/_chat.py +72 -5
  2. gemmademo/_model.py +18 -11
  3. gemmademo/_prompts.py +8 -38
gemmademo/_chat.py CHANGED
@@ -33,13 +33,18 @@ class GradioChat:
33
  if model_name in self.models_cache:
34
  return self.models_cache[model_name]
35
 
36
- model = LlamaCppGemmaModel(name=model_name).load_model()
 
 
37
  self.models_cache[model_name] = model
 
38
  return model
39
 
40
  def _load_task(self, task_name: str):
41
  """Loads the task dynamically when switching tasks."""
42
- return PromptManager(task=task_name)
 
 
43
 
44
  def _chat(self):
45
  def chat_fn(message, history, selected_model, selected_task):
@@ -49,18 +54,22 @@ class GradioChat:
49
 
50
  # Reload model if changed, using cache when possible
51
  if selected_model != self.current_model_name:
52
- self.current_model_name = selected_model
53
  self.model = self._load_model(selected_model)
54
  # Clear message history when model changes
55
  self.model.messages = []
56
 
57
  # Reload task if changed
58
  if selected_task != self.current_task_name:
59
- self.current_task_name = selected_task
60
- self.prompt_manager = self._load_task(selected_task)
61
  # Clear message history when task changes
62
  if self.model:
63
  self.model.messages = []
 
 
 
 
 
 
64
 
65
  # Generate response using updated model & prompt manager
66
  prompt = self.prompt_manager.get_prompt(user_input=message)
@@ -137,6 +146,64 @@ class GradioChat:
137
  task_dropdown.change(
138
  _update_examples, task_dropdown, examples_list.dataset
139
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  demo.launch()
142
 
 
33
  if model_name in self.models_cache:
34
  return self.models_cache[model_name]
35
 
36
+ model = LlamaCppGemmaModel(name=model_name).load_model(
37
+ system_prompt=self.prompt_manager.get_system_prompt()
38
+ )
39
  self.models_cache[model_name] = model
40
+ self.current_model_name = model_name
41
  return model
42
 
43
  def _load_task(self, task_name: str):
44
  """Loads the task dynamically when switching tasks."""
45
+ self.current_task_name = task_name
46
+ self.prompt_manager = PromptManager(task=task_name)
47
+ return
48
 
49
  def _chat(self):
50
  def chat_fn(message, history, selected_model, selected_task):
 
54
 
55
  # Reload model if changed, using cache when possible
56
  if selected_model != self.current_model_name:
 
57
  self.model = self._load_model(selected_model)
58
  # Clear message history when model changes
59
  self.model.messages = []
60
 
61
  # Reload task if changed
62
  if selected_task != self.current_task_name:
63
+ self._load_task(selected_task)
 
64
  # Clear message history when task changes
65
  if self.model:
66
  self.model.messages = []
67
+ self.model.messages = [
68
+ {
69
+ "role": "system",
70
+ "content": self.prompt_manager.get_system_prompt(),
71
+ }
72
+ ]
73
 
74
  # Generate response using updated model & prompt manager
75
  prompt = self.prompt_manager.get_prompt(user_input=message)
 
146
  task_dropdown.change(
147
  _update_examples, task_dropdown, examples_list.dataset
148
  )
149
+ temperature_slider = gr.Slider(
150
+ minimum=0.1, maximum=2, value=1.0, label="Temperature"
151
+ )
152
+ gr.Markdown(
153
+ "**Temperature:** Controls the randomness of the model's output. Lower values make the output more deterministic."
154
+ )
155
+ temperature_slider.change(
156
+ fn=lambda temp: setattr(self.model, "temperature", temp),
157
+ inputs=temperature_slider,
158
+ )
159
+
160
+ top_p_slider = gr.Slider(
161
+ minimum=0.1, maximum=1.0, value=0.9, label="Top P"
162
+ )
163
+ gr.Markdown(
164
+ "**Top P:** Limits the sampling to a subset of the most probable tokens. Lower values make the output more focused."
165
+ )
166
+ top_p_slider.change(
167
+ fn=lambda top_p: setattr(self.model, "top_p", top_p),
168
+ inputs=top_p_slider,
169
+ )
170
+
171
+ top_k_slider = gr.Slider(
172
+ minimum=1, maximum=100, value=50, label="Top K"
173
+ )
174
+ gr.Markdown(
175
+ "**Top K:** Limits the sampling to the top K most probable tokens. Lower values make the output more focused."
176
+ )
177
+ top_k_slider.change(
178
+ fn=lambda top_k: setattr(self.model, "top_k", top_k),
179
+ inputs=top_k_slider,
180
+ )
181
+
182
+ repetition_penalty_slider = gr.Slider(
183
+ minimum=1.0, maximum=2.0, value=1.0, label="Repetition Penalty"
184
+ )
185
+ gr.Markdown(
186
+ "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output."
187
+ )
188
+ repetition_penalty_slider.change(
189
+ fn=lambda penalty: setattr(
190
+ self.model, "repetition_penalty", penalty
191
+ ),
192
+ inputs=repetition_penalty_slider,
193
+ )
194
+
195
+ max_tokens_slider = gr.Slider(
196
+ minimum=512, maximum=2048, value=1024, label="Max Tokens"
197
+ )
198
+ gr.Markdown(
199
+ "**Max Tokens:** Sets the maximum number of tokens the model can generate in a single response."
200
+ )
201
+ max_tokens_slider.change(
202
+ fn=lambda max_tokens: setattr(
203
+ self.model, "max_tokens", max_tokens
204
+ ),
205
+ inputs=max_tokens_slider,
206
+ )
207
 
208
  demo.launch()
209
 
gemmademo/_model.py CHANGED
@@ -50,7 +50,14 @@ class LlamaCppGemmaModel:
50
  self.model = None # Instance of Llama from llama.cpp
51
  self.messages = []
52
 
53
- def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
 
 
 
 
 
 
 
54
  """
55
  Load the model. If the model file does not exist, it will be downloaded.
56
  Uses caching to avoid reloading models unnecessarily.
@@ -94,6 +101,8 @@ class LlamaCppGemmaModel:
94
 
95
  _threads = min(2, os.cpu_count() or 1)
96
 
 
 
97
  self.model = Llama(
98
  model_path=model_path,
99
  n_threads=_threads,
@@ -102,8 +111,11 @@ class LlamaCppGemmaModel:
102
  n_gpu_layers=n_gpu_layers,
103
  n_batch=8,
104
  verbose=False,
 
105
  )
106
 
 
 
107
  # Cache the model for future use
108
  LlamaCppGemmaModel._model_cache[cache_key] = self.model
109
  return self
@@ -111,11 +123,6 @@ class LlamaCppGemmaModel:
111
  def generate_response(
112
  self,
113
  prompt: str,
114
- max_tokens: int = 512,
115
- temperature: float = 0.7,
116
- top_p: float = 0.95,
117
- top_k: int = 40,
118
- repeat_penalty: float = 1.1,
119
  ):
120
  """
121
  Generate a response using the llama.cpp model with optimized parameters.
@@ -138,11 +145,11 @@ class LlamaCppGemmaModel:
138
 
139
  response_stream = self.model.create_chat_completion(
140
  messages=self.messages,
141
- max_tokens=max_tokens,
142
- temperature=temperature,
143
- top_p=top_p,
144
- top_k=top_k,
145
- repeat_penalty=repeat_penalty,
146
  stream=True,
147
  )
148
  self.messages.append({"role": "assistant", "content": ""})
 
50
  self.model = None # Instance of Llama from llama.cpp
51
  self.messages = []
52
 
53
+ # Model response generation attributes
54
+ self.max_tokens = (512,)
55
+ self.temperature = (0.7,)
56
+ self.top_p = (0.95,)
57
+ self.top_k = (40,)
58
+ self.repeat_penalty = (1.1,)
59
+
60
+ def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0, system_prompt=""):
61
  """
62
  Load the model. If the model file does not exist, it will be downloaded.
63
  Uses caching to avoid reloading models unnecessarily.
 
101
 
102
  _threads = min(2, os.cpu_count() or 1)
103
 
104
+ _sys_prompt = {"role": "system", "content": system_prompt}
105
+
106
  self.model = Llama(
107
  model_path=model_path,
108
  n_threads=_threads,
 
111
  n_gpu_layers=n_gpu_layers,
112
  n_batch=8,
113
  verbose=False,
114
+ chat_format="chatml",
115
  )
116
 
117
+ self.messages.append(_sys_prompt)
118
+
119
  # Cache the model for future use
120
  LlamaCppGemmaModel._model_cache[cache_key] = self.model
121
  return self
 
123
  def generate_response(
124
  self,
125
  prompt: str,
 
 
 
 
 
126
  ):
127
  """
128
  Generate a response using the llama.cpp model with optimized parameters.
 
145
 
146
  response_stream = self.model.create_chat_completion(
147
  messages=self.messages,
148
+ max_tokens=self.max_tokens,
149
+ temperature=self.temperature,
150
+ top_p=self.top_p,
151
+ top_k=self.top_k,
152
+ repeat_penalty=self.repeat_penalty,
153
  stream=True,
154
  )
155
  self.messages.append({"role": "assistant", "content": ""})
gemmademo/_prompts.py CHANGED
@@ -6,48 +6,18 @@ class PromptManager:
6
  various tasks such as Question Answering, Text Generation, and Code Completion.
7
  It raises a ValueError if an unsupported task is specified.
8
  """
 
9
  def __init__(self, task):
10
  self.task = task
11
 
12
  def get_prompt(self, user_input):
 
 
 
 
13
  if self.task == "Question Answering":
14
- return self.get_question_answering_prompt(user_input)
15
  elif self.task == "Text Generation":
16
- return self.get_text_generation_prompt(user_input)
17
  elif self.task == "Code Completion":
18
- return self.get_code_completion_prompt(user_input)
19
- else:
20
- raise ValueError(f"Task {self.task} not supported")
21
-
22
- def get_question_answering_prompt(self, user_input):
23
- """
24
- Format user input for question answering task
25
- """
26
- prompt = f"""You are a helpful AI assistant. Answer the following question accurately and concisely.
27
- Only answer the question, do not provide any other information.
28
-
29
- Question: {user_input}
30
-
31
- Answer:"""
32
- return prompt
33
-
34
- def get_text_generation_prompt(self, user_input):
35
- """
36
- Format user input for text generation task
37
- """
38
- prompt = f"""Continue the following text in a coherent and engaging way:
39
- Only continue the text, do not provide any other information.
40
-
41
- {user_input}
42
-
43
- Continuation:"""
44
- return prompt
45
-
46
- def get_code_completion_prompt(self, user_input):
47
- """
48
- Format user input for code completion task
49
- """
50
- prompt = f"""Complete the following code snippet directly with proper syntax and
51
- without explanations or extra text:
52
- {user_input}"""
53
- return prompt
 
6
  various tasks such as Question Answering, Text Generation, and Code Completion.
7
  It raises a ValueError if an unsupported task is specified.
8
  """
9
+
10
  def __init__(self, task):
11
  self.task = task
12
 
13
  def get_prompt(self, user_input):
14
+ return user_input
15
+
16
+ def get_system_prompt(self):
17
+ """Returns the system prompt based on the specified task."""
18
  if self.task == "Question Answering":
19
+ return "You are a helpful AI assistant. Answer questions concisely and accurately."
20
  elif self.task == "Text Generation":
21
+ return "You are a creative AI writer. Generate engaging and coherent text based on the input."
22
  elif self.task == "Code Completion":
23
+ return "You are a coding assistant. Complete code snippets correctly without explanations."