aadya1762 committed on
Commit
c1e7456
·
1 Parent(s): 296ef11

Add model config sliders

Browse files
Files changed (2) hide show
  1. app.py +34 -4
  2. gemmademo/_model.py +9 -9
app.py CHANGED
@@ -1,9 +1,6 @@
1
  # Interface all the functions from gemmademo.
2
- # Implement login functionality in the side bar.
3
  # Implement a task selector in the side bar.
4
- # Interface all the functions from gemmademo.
5
  # Add a button to clear the chat history.
6
-
7
  import streamlit as st
8
  from gemmademo import (
9
  LlamaCppGemmaModel,
@@ -25,6 +22,10 @@ def main():
25
  st.session_state.selected_model = "gemma-2b-it"
26
  if "selected_task" not in st.session_state:
27
  st.session_state.selected_task = "Question Answering"
 
 
 
 
28
 
29
  # Sidebar for login and configuration
30
  with st.sidebar:
@@ -56,6 +57,31 @@ def main():
56
  st.session_state.selected_task = selected_task
57
  st.rerun()
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Clear chat history button
60
  if st.button("Clear Chat History"):
61
  if "chat_instance" in st.session_state:
@@ -65,7 +91,11 @@ def main():
65
  # Main content area
66
  # Initialize model with the selected configuration
67
  model_name = st.session_state.selected_model
68
- model = LlamaCppGemmaModel(name=model_name)
 
 
 
 
69
 
70
  # Load model (will use cached version if available)
71
  with st.spinner(f"Loading {model_name}..."):
 
1
  # Interface all the functions from gemmademo.
 
2
  # Implement a task selector in the side bar.
 
3
  # Add a button to clear the chat history.
 
4
  import streamlit as st
5
  from gemmademo import (
6
  LlamaCppGemmaModel,
 
22
  st.session_state.selected_model = "gemma-2b-it"
23
  if "selected_task" not in st.session_state:
24
  st.session_state.selected_task = "Question Answering"
25
+ if "max_tokens" not in st.session_state:
26
+ st.session_state.max_tokens = 512
27
+ if "temperature" not in st.session_state:
28
+ st.session_state.temperature = 0.7
29
 
30
  # Sidebar for login and configuration
31
  with st.sidebar:
 
57
  st.session_state.selected_task = selected_task
58
  st.rerun()
59
 
60
+ # Model Config Selection
61
+ new_max_tokens_value = st.slider(
62
+ "Max Tokens",
63
+ min_value=1,
64
+ max_value=4096,
65
+ value=st.session_state.max_tokens,
66
+ step=1,
67
+ )
68
+ # After setting the slider values
69
+ if st.session_state.max_tokens != new_max_tokens_value:
70
+ st.session_state.max_tokens = new_max_tokens_value
71
+ st.rerun()
72
+
73
+ new_temperature_value = st.slider(
74
+ "Temperature",
75
+ min_value=0.0,
76
+ max_value=1.0,
77
+ value=st.session_state.temperature,
78
+ step=0.01,
79
+ )
80
+ # After setting the slider values
81
+ if st.session_state.temperature != new_temperature_value:
82
+ st.session_state.temperature = new_temperature_value
83
+ st.rerun()
84
+
85
  # Clear chat history button
86
  if st.button("Clear Chat History"):
87
  if "chat_instance" in st.session_state:
 
91
  # Main content area
92
  # Initialize model with the selected configuration
93
  model_name = st.session_state.selected_model
94
+ model = LlamaCppGemmaModel(
95
+ name=model_name,
96
+ max_tokens=st.session_state.max_tokens,
97
+ temperature=st.session_state.temperature,
98
+ )
99
 
100
  # Load model (will use cached version if available)
101
  with st.spinner(f"Loading {model_name}..."):
gemmademo/_model.py CHANGED
@@ -51,7 +51,9 @@ class LlamaCppGemmaModel:
51
  },
52
  }
53
 
54
- def __init__(self, name: str = "gemma-2b"):
 
 
55
  """
56
  Initialize the model instance.
57
 
@@ -60,8 +62,10 @@ class LlamaCppGemmaModel:
60
  """
61
  self.name = name
62
  self.model = None # Instance of Llama from llama.cpp
 
 
63
 
64
- def load_model(self, n_threads: int = 2, n_ctx: int = 2048, n_gpu_layers: int = 0):
65
  """
66
  Load the model and cache it in Streamlit's session state.
67
  If the model file does not exist, it will be downloaded to the models/ directory.
@@ -104,7 +108,7 @@ class LlamaCppGemmaModel:
104
  with st.spinner(f"Loading {self.name}..."):
105
  st.session_state[model_key] = Llama(
106
  model_path=model_path,
107
- n_threads=n_threads,
108
  n_ctx=n_ctx,
109
  n_gpu_layers=n_gpu_layers,
110
  )
@@ -114,9 +118,6 @@ class LlamaCppGemmaModel:
114
  def generate_response(
115
  self,
116
  prompt: str,
117
- max_tokens: int = 512,
118
- temperature: float = 0.7,
119
- **kwargs,
120
  ) -> str:
121
  """
122
  Generate a response using the llama.cpp model.
@@ -136,9 +137,8 @@ class LlamaCppGemmaModel:
136
  # Call the llama.cpp model with the provided parameters.
137
  response = self.model(
138
  prompt,
139
- max_tokens=max_tokens,
140
- temperature=temperature,
141
- **kwargs,
142
  )
143
  generated_text = response["choices"][0]["text"]
144
  return generated_text.strip()
 
51
  },
52
  }
53
 
54
+ def __init__(
55
+ self, name: str = "gemma-2b", max_tokens: int = 512, temperature: float = 0.7
56
+ ):
57
  """
58
  Initialize the model instance.
59
 
 
62
  """
63
  self.name = name
64
  self.model = None # Instance of Llama from llama.cpp
65
+ self.max_tokens = max_tokens
66
+ self.temperature = temperature
67
 
68
+ def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
69
  """
70
  Load the model and cache it in Streamlit's session state.
71
  If the model file does not exist, it will be downloaded to the models/ directory.
 
108
  with st.spinner(f"Loading {self.name}..."):
109
  st.session_state[model_key] = Llama(
110
  model_path=model_path,
111
+ n_threads=os.cpu_count(),
112
  n_ctx=n_ctx,
113
  n_gpu_layers=n_gpu_layers,
114
  )
 
118
  def generate_response(
119
  self,
120
  prompt: str,
 
 
 
121
  ) -> str:
122
  """
123
  Generate a response using the llama.cpp model.
 
137
  # Call the llama.cpp model with the provided parameters.
138
  response = self.model(
139
  prompt,
140
+ max_tokens=self.max_tokens,
141
+ temperature=self.temperature,
 
142
  )
143
  generated_text = response["choices"][0]["text"]
144
  return generated_text.strip()