Spaces:
Sleeping
Sleeping
Add model config sliders
Browse files
- app.py +34 -4
- gemmademo/_model.py +9 -9
app.py
CHANGED
|
@@ -1,9 +1,6 @@
|
|
| 1 |
# Interface all the functions from gemmademo.
|
| 2 |
-
# Implement login functionality in the side bar.
|
| 3 |
# Implement a task selector in the side bar.
|
| 4 |
-
# Interface all the functions from gemmademo.
|
| 5 |
# Add a button to clear the chat history.
|
| 6 |
-
|
| 7 |
import streamlit as st
|
| 8 |
from gemmademo import (
|
| 9 |
LlamaCppGemmaModel,
|
|
@@ -25,6 +22,10 @@ def main():
|
|
| 25 |
st.session_state.selected_model = "gemma-2b-it"
|
| 26 |
if "selected_task" not in st.session_state:
|
| 27 |
st.session_state.selected_task = "Question Answering"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Sidebar for login and configuration
|
| 30 |
with st.sidebar:
|
|
@@ -56,6 +57,31 @@ def main():
|
|
| 56 |
st.session_state.selected_task = selected_task
|
| 57 |
st.rerun()
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# Clear chat history button
|
| 60 |
if st.button("Clear Chat History"):
|
| 61 |
if "chat_instance" in st.session_state:
|
|
@@ -65,7 +91,11 @@ def main():
|
|
| 65 |
# Main content area
|
| 66 |
# Initialize model with the selected configuration
|
| 67 |
model_name = st.session_state.selected_model
|
| 68 |
-
model = LlamaCppGemmaModel(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
# Load model (will use cached version if available)
|
| 71 |
with st.spinner(f"Loading {model_name}..."):
|
|
|
|
| 1 |
# Interface all the functions from gemmademo.
|
|
|
|
| 2 |
# Implement a task selector in the side bar.
|
|
|
|
| 3 |
# Add a button to clear the chat history.
|
|
|
|
| 4 |
import streamlit as st
|
| 5 |
from gemmademo import (
|
| 6 |
LlamaCppGemmaModel,
|
|
|
|
| 22 |
st.session_state.selected_model = "gemma-2b-it"
|
| 23 |
if "selected_task" not in st.session_state:
|
| 24 |
st.session_state.selected_task = "Question Answering"
|
| 25 |
+
if "max_tokens" not in st.session_state:
|
| 26 |
+
st.session_state.max_tokens = 512
|
| 27 |
+
if "temperature" not in st.session_state:
|
| 28 |
+
st.session_state.temperature = 0.7
|
| 29 |
|
| 30 |
# Sidebar for login and configuration
|
| 31 |
with st.sidebar:
|
|
|
|
| 57 |
st.session_state.selected_task = selected_task
|
| 58 |
st.rerun()
|
| 59 |
|
| 60 |
+
# Model Config Selection
|
| 61 |
+
new_max_tokens_value = st.slider(
|
| 62 |
+
"Max Tokens",
|
| 63 |
+
min_value=1,
|
| 64 |
+
max_value=4096,
|
| 65 |
+
value=st.session_state.max_tokens,
|
| 66 |
+
step=1,
|
| 67 |
+
)
|
| 68 |
+
# After setting the slider values
|
| 69 |
+
if st.session_state.max_tokens != new_max_tokens_value:
|
| 70 |
+
st.session_state.max_tokens = new_max_tokens_value
|
| 71 |
+
st.rerun()
|
| 72 |
+
|
| 73 |
+
new_temperature_value = st.slider(
|
| 74 |
+
"Temperature",
|
| 75 |
+
min_value=0.0,
|
| 76 |
+
max_value=1.0,
|
| 77 |
+
value=st.session_state.temperature,
|
| 78 |
+
step=0.01,
|
| 79 |
+
)
|
| 80 |
+
# After setting the slider values
|
| 81 |
+
if st.session_state.temperature != new_temperature_value:
|
| 82 |
+
st.session_state.temperature = new_temperature_value
|
| 83 |
+
st.rerun()
|
| 84 |
+
|
| 85 |
# Clear chat history button
|
| 86 |
if st.button("Clear Chat History"):
|
| 87 |
if "chat_instance" in st.session_state:
|
|
|
|
| 91 |
# Main content area
|
| 92 |
# Initialize model with the selected configuration
|
| 93 |
model_name = st.session_state.selected_model
|
| 94 |
+
model = LlamaCppGemmaModel(
|
| 95 |
+
name=model_name,
|
| 96 |
+
max_tokens=st.session_state.max_tokens,
|
| 97 |
+
temperature=st.session_state.temperature,
|
| 98 |
+
)
|
| 99 |
|
| 100 |
# Load model (will use cached version if available)
|
| 101 |
with st.spinner(f"Loading {model_name}..."):
|
gemmademo/_model.py
CHANGED
|
@@ -51,7 +51,9 @@ class LlamaCppGemmaModel:
|
|
| 51 |
},
|
| 52 |
}
|
| 53 |
|
| 54 |
-
def __init__(
|
|
|
|
|
|
|
| 55 |
"""
|
| 56 |
Initialize the model instance.
|
| 57 |
|
|
@@ -60,8 +62,10 @@ class LlamaCppGemmaModel:
|
|
| 60 |
"""
|
| 61 |
self.name = name
|
| 62 |
self.model = None # Instance of Llama from llama.cpp
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
def load_model(self,
|
| 65 |
"""
|
| 66 |
Load the model and cache it in Streamlit's session state.
|
| 67 |
If the model file does not exist, it will be downloaded to the models/ directory.
|
|
@@ -104,7 +108,7 @@ class LlamaCppGemmaModel:
|
|
| 104 |
with st.spinner(f"Loading {self.name}..."):
|
| 105 |
st.session_state[model_key] = Llama(
|
| 106 |
model_path=model_path,
|
| 107 |
-
n_threads=
|
| 108 |
n_ctx=n_ctx,
|
| 109 |
n_gpu_layers=n_gpu_layers,
|
| 110 |
)
|
|
@@ -114,9 +118,6 @@ class LlamaCppGemmaModel:
|
|
| 114 |
def generate_response(
|
| 115 |
self,
|
| 116 |
prompt: str,
|
| 117 |
-
max_tokens: int = 512,
|
| 118 |
-
temperature: float = 0.7,
|
| 119 |
-
**kwargs,
|
| 120 |
) -> str:
|
| 121 |
"""
|
| 122 |
Generate a response using the llama.cpp model.
|
|
@@ -136,9 +137,8 @@ class LlamaCppGemmaModel:
|
|
| 136 |
# Call the llama.cpp model with the provided parameters.
|
| 137 |
response = self.model(
|
| 138 |
prompt,
|
| 139 |
-
max_tokens=max_tokens,
|
| 140 |
-
temperature=temperature,
|
| 141 |
-
**kwargs,
|
| 142 |
)
|
| 143 |
generated_text = response["choices"][0]["text"]
|
| 144 |
return generated_text.strip()
|
|
|
|
| 51 |
},
|
| 52 |
}
|
| 53 |
|
| 54 |
+
def __init__(
|
| 55 |
+
self, name: str = "gemma-2b", max_tokens: int = 512, temperature: float = 0.7
|
| 56 |
+
):
|
| 57 |
"""
|
| 58 |
Initialize the model instance.
|
| 59 |
|
|
|
|
| 62 |
"""
|
| 63 |
self.name = name
|
| 64 |
self.model = None # Instance of Llama from llama.cpp
|
| 65 |
+
self.max_tokens = max_tokens
|
| 66 |
+
self.temperature = temperature
|
| 67 |
|
| 68 |
+
def load_model(self, n_ctx: int = 2048, n_gpu_layers: int = 0):
|
| 69 |
"""
|
| 70 |
Load the model and cache it in Streamlit's session state.
|
| 71 |
If the model file does not exist, it will be downloaded to the models/ directory.
|
|
|
|
| 108 |
with st.spinner(f"Loading {self.name}..."):
|
| 109 |
st.session_state[model_key] = Llama(
|
| 110 |
model_path=model_path,
|
| 111 |
+
n_threads=os.cpu_count(),
|
| 112 |
n_ctx=n_ctx,
|
| 113 |
n_gpu_layers=n_gpu_layers,
|
| 114 |
)
|
|
|
|
| 118 |
def generate_response(
|
| 119 |
self,
|
| 120 |
prompt: str,
|
|
|
|
|
|
|
|
|
|
| 121 |
) -> str:
|
| 122 |
"""
|
| 123 |
Generate a response using the llama.cpp model.
|
|
|
|
| 137 |
# Call the llama.cpp model with the provided parameters.
|
| 138 |
response = self.model(
|
| 139 |
prompt,
|
| 140 |
+
max_tokens=self.max_tokens,
|
| 141 |
+
temperature=self.temperature,
|
|
|
|
| 142 |
)
|
| 143 |
generated_text = response["choices"][0]["text"]
|
| 144 |
return generated_text.strip()
|