import gradio as gr
from ._model import LlamaCppGemmaModel
from ._prompts import PromptManager
class GradioChat:
    """
    A class that handles the chat interface for the Gemma model.

    Features:
    - A Gradio-based chatbot UI.
    - Dynamically loads models based on user selection, with caching.
    - Dynamically updates tasks using PromptManager.
    - Uses Gemma (llama.cpp) for generating responses.
    """

    def __init__(self, model_options: list[str], task_options: list[str]):
        """Initialize default model/task state and pre-load the default model.

        Args:
            model_options: Model names offered in the model dropdown.
            task_options: Task names offered in the task dropdown.
        """
        self.model_options = model_options
        self.task_options = task_options
        self.current_model_name = "gemma-3b"
        self.current_task_name = "Question Answering"
        self.prompt_manager = self._load_task(self.current_task_name)
        # Cache of already-loaded models keyed by name so that switching
        # back to a previously used model does not reload it.
        self.models_cache = {}
        self.model = self._load_model(self.current_model_name)

    def _load_model(self, model_name: str):
        """Load the model (or return it from cache) and mark it current.

        Args:
            model_name: Name of the Gemma model to load.

        Returns:
            The loaded model instance.
        """
        if model_name in self.models_cache:
            # BUGFIX: the current-model marker must be updated on a cache
            # hit too; otherwise chat_fn sees a stale name on every turn
            # and clears the chat history after each message.
            self.current_model_name = model_name
            return self.models_cache[model_name]
        model = LlamaCppGemmaModel(name=model_name).load_model(
            system_prompt=self.prompt_manager.get_system_prompt()
        )
        self.models_cache[model_name] = model
        self.current_model_name = model_name
        return model

    def _load_task(self, task_name: str):
        """Load the prompt manager for ``task_name`` and mark it current."""
        self.current_task_name = task_name
        return PromptManager(task=task_name)

    def _chat(self):
        """Build the Gradio UI and launch it (blocks until the app closes)."""

        def _system_message() -> list[dict]:
            """Fresh message history seeded with the current system prompt."""
            return [
                {
                    "role": "system",
                    "content": self.prompt_manager.get_system_prompt(),
                }
            ]

        def chat_fn(message, history, selected_model, selected_task):
            # Lazy load model on first use.
            if self.model is None:
                self.model = self._load_model(self.current_model_name)
            # Reload model if changed, using cache when possible.
            if selected_model != self.current_model_name:
                self.model = self._load_model(selected_model)
                # Clear message history when the model changes, but keep the
                # system prompt for the active task (previously the system
                # prompt was dropped entirely on a model switch).
                self.model.messages = _system_message()
            # Reload task if changed.
            if selected_task != self.current_task_name:
                self.prompt_manager = self._load_task(selected_task)
                # Clear message history and re-seed the new task's
                # system prompt.
                if self.model:
                    self.model.messages = _system_message()
            # Generate response using updated model & prompt manager.
            prompt = self.prompt_manager.get_prompt(user_input=message)
            response_stream = self.model.generate_response(prompt)
            yield from response_stream

        def _get_examples(task):
            """Return example prompts for ``task`` ([] for unknown tasks)."""
            examples = {
                "Question Answering": [
                    "What is quantum computing?",
                    "How do neural networks work?",
                    "Explain climate change in simple terms.",
                ],
                "Text Generation": [
                    "Once upon a time in a distant galaxy...",
                    "The abandoned house at the end of the street had...",
                    "In the year 2150, humanity discovered...",
                ],
                "Code Completion": [
                    "def fibonacci(n):",
                    "class BinarySearchInAList:",
                    "async def fetch_data(url):",
                ],
            }
            # BUGFIX: default to [] so _update_examples never iterates None
            # when a task has no example set.
            return examples.get(task, [])

        def _update_examples(task):
            """Updates the examples based on the selected task."""
            examples = _get_examples(task)
            return gr.Dataset(samples=[[example] for example in examples])

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=3):  # Sidebar column
                    with gr.Accordion(
                        "Basic Settings ⚙️", open=False
                    ):  # Make the sidebar foldable
                        gr.Markdown(
                            "## Google Gemma Models: lightweight, state-of-the-art open models from Google"
                        )
                        task_dropdown = gr.Dropdown(
                            choices=self.task_options,
                            value=self.current_task_name,
                            label="Select Task",
                        )
                        model_dropdown = gr.Dropdown(
                            choices=self.model_options,
                            value=self.current_model_name,
                            label="Select Gemma Model",
                        )
                    chat_interface = gr.ChatInterface(
                        chat_fn,
                        additional_inputs=[model_dropdown, task_dropdown],
                        textbox=gr.Textbox(
                            placeholder="Ask me something...", container=False
                        ),
                    )
                    gr.Markdown(
                        "Medium Blog Post: [Gemma Chat Interface Blog](https://medium.com/@aadyachinubhai/introducing-the-gemma-chat-interface-your-ai-powered-chat-companion-a77fc609e51a)"
                    )
                with gr.Column(scale=1):
                    with gr.Accordion("Important Pointers", open=False):
                        gr.Markdown(
                            """
                            ## Pointers
                            - First response after model change will be slower (model loading lazily).
                            - Switching models clears chat history.
                            - Larger models need more memory but give better results.
                            """
                        )
                    # Task-specific example prompts; refreshed whenever the
                    # task dropdown changes.
                    examples_list = gr.Examples(
                        examples=[
                            [example]
                            for example in _get_examples(self.current_task_name)
                        ],
                        inputs=chat_interface.textbox,
                    )
                    task_dropdown.change(
                        _update_examples, task_dropdown, examples_list.dataset
                    )
                    # Sliders write straight onto the live model object; the
                    # lambdas close over `self`, so they always target the
                    # currently selected model.
                    with gr.Accordion("Model Configuration ⚙️", open=False):
                        temperature_slider = gr.Slider(
                            minimum=0.1,
                            maximum=2,
                            value=self.model.temperature,
                            label="Temperature",
                        )
                        gr.Markdown(
                            "**Temperature:** Lower values make the output more deterministic."
                        )
                        temperature_slider.change(
                            fn=lambda temp: setattr(
                                self.model, "temperature", temp
                            ),
                            inputs=temperature_slider,
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=self.model.top_p,
                            label="Top P",
                        )
                        gr.Markdown(
                            "**Top P:** Lower values make the output more focused."
                        )
                        top_p_slider.change(
                            fn=lambda top_p: setattr(self.model, "top_p", top_p),
                            inputs=top_p_slider,
                        )
                        top_k_slider = gr.Slider(
                            minimum=1,
                            maximum=100,
                            value=self.model.top_k,
                            label="Top K",
                        )
                        gr.Markdown(
                            "**Top K:** Lower values make the output more focused."
                        )
                        top_k_slider.change(
                            fn=lambda top_k: setattr(self.model, "top_k", top_k),
                            inputs=top_k_slider,
                        )
                        repetition_penalty_slider = gr.Slider(
                            minimum=1.0,
                            maximum=2.0,
                            value=self.model.repeat_penalty,
                            label="Repetition Penalty",
                        )
                        gr.Markdown(
                            "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output."
                        )
                        repetition_penalty_slider.change(
                            fn=lambda penalty: setattr(
                                self.model, "repeat_penalty", penalty
                            ),
                            inputs=repetition_penalty_slider,
                        )
                        max_tokens_slider = gr.Slider(
                            minimum=512,
                            maximum=2048,
                            value=self.model.max_tokens,
                            label="Max Tokens",
                        )
                        gr.Markdown(
                            "**Max Tokens:** Sets the maximum number of tokens the model can generate in one response."
                        )
                        max_tokens_slider.change(
                            fn=lambda max_tokens: setattr(
                                self.model, "max_tokens", max_tokens
                            ),
                            inputs=max_tokens_slider,
                        )
        demo.launch()

    def run(self):
        """Public entry point: build and launch the chat UI."""
        self._chat()