import gradio as gr

from ._model import LlamaCppGemmaModel
from ._prompts import PromptManager
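

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the interface that
# GradioChat assumes from the objects returned by ._model and ._prompts,
# inferred from the calls made below. The real classes may differ; these
# Protocols only document the attributes and methods used in this file.
# ``LlamaCppGemmaModel(name=...).load_model(system_prompt=...)`` is expected
# to return an object matching ``GemmaModelLike``.
# ---------------------------------------------------------------------------
from typing import Iterator, Protocol


class GemmaModelLike(Protocol):
    temperature: float
    top_p: float
    top_k: int
    repeat_penalty: float
    max_tokens: int
    messages: list[dict]

    def generate_response(self, prompt: str) -> Iterator[str]: ...


class PromptManagerLike(Protocol):
    def get_system_prompt(self) -> str: ...

    def get_prompt(self, user_input: str) -> str: ...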


class GradioChat:
    """
    A class that handles the chat interface for the Gemma model.

    Features:
    - A Gradio-based chatbot UI.
    - Dynamically loads models based on user selection.
    - Dynamically updates tasks using PromptManager.
    - Uses Gemma (llama.cpp) for generating responses.
    """

    def __init__(self, model_options: list[str], task_options: list[str]):
        self.model_options = model_options
        self.task_options = task_options
        self.current_model_name = "gemma-3b"
        self.current_task_name = "Question Answering"
        self.prompt_manager = self._load_task(self.current_task_name)
        # Cache of loaded models, keyed by model name.
        self.models_cache = {}
        self.model = self._load_model(self.current_model_name)

    def _load_model(self, model_name: str):
        """Loads the model dynamically when switching models, with caching."""
        # Track the selected model even on a cache hit, so chat_fn does not
        # keep treating it as "changed" and wiping the history every turn.
        self.current_model_name = model_name
        if model_name in self.models_cache:
            return self.models_cache[model_name]
        model = LlamaCppGemmaModel(name=model_name).load_model(
            system_prompt=self.prompt_manager.get_system_prompt()
        )
        self.models_cache[model_name] = model
        return model

    def _load_task(self, task_name: str):
        """Loads the task dynamically when switching tasks."""
        self.current_task_name = task_name
        return PromptManager(task=task_name)

    def _chat(self):
        def chat_fn(message, history, selected_model, selected_task):
            # Lazily load the model on first use.
            if self.model is None:
                self.model = self._load_model(self.current_model_name)

            # Reload the model if it changed, using the cache when possible.
            if selected_model != self.current_model_name:
                self.model = self._load_model(selected_model)
                # Clear the message history when the model changes.
                self.model.messages = []

            # Reload the task if it changed.
            if selected_task != self.current_task_name:
                self.prompt_manager = self._load_task(selected_task)
                # Reset the message history to the new task's system prompt.
                if self.model:
                    self.model.messages = [
                        {
                            "role": "system",
                            "content": self.prompt_manager.get_system_prompt(),
                        }
                    ]

            # Generate a response using the updated model and prompt manager.
            prompt = self.prompt_manager.get_prompt(user_input=message)
            response_stream = self.model.generate_response(prompt)
            yield from response_stream

        def _get_examples(task):
            # Example prompts for each task type.
            examples = {
                "Question Answering": [
                    "What is quantum computing?",
                    "How do neural networks work?",
                    "Explain climate change in simple terms.",
                ],
                "Text Generation": [
                    "Once upon a time in a distant galaxy...",
                    "The abandoned house at the end of the street had...",
                    "In the year 2150, humanity discovered...",
                ],
                "Code Completion": [
                    "def fibonacci(n):",
                    "class BinarySearchInAList:",
                    "async def fetch_data(url):",
                ],
            }
            # Fall back to an empty list for unknown tasks.
            return examples.get(task, [])

        def _update_examples(task):
            """Updates the examples based on the selected task."""
            examples = _get_examples(task)
            return gr.Dataset(samples=[[example] for example in examples])

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=3):  # Main column: settings, chat, and link.
                    with gr.Accordion(
                        "Basic Settings ⚙️", open=False
                    ):  # Foldable settings panel.
                        gr.Markdown(
                            "## Google Gemma Models: lightweight, state-of-the-art open models from Google"
                        )
                        task_dropdown = gr.Dropdown(
                            choices=self.task_options,
                            value=self.current_task_name,
                            label="Select Task",
                        )
                        model_dropdown = gr.Dropdown(
                            choices=self.model_options,
                            value=self.current_model_name,
                            label="Select Gemma Model",
                        )
                    chat_interface = gr.ChatInterface(
                        chat_fn,
                        additional_inputs=[model_dropdown, task_dropdown],
                        textbox=gr.Textbox(
                            placeholder="Ask me something...", container=False
                        ),
                    )
                    gr.Markdown(
                        "Medium Blog Post: [Gemma Chat Interface Blog](https://medium.com/@aadyachinubhai/introducing-the-gemma-chat-interface-your-ai-powered-chat-companion-a77fc609e51a)"
                    )
                with gr.Column(scale=1):  # Sidebar: pointers, examples, model settings.
                    with gr.Accordion("Important Pointers", open=False):
                        gr.Markdown(
                            """
                            ## Pointers
                            - The first response after a model change is slower (models load lazily).
                            - Switching models clears the chat history.
                            - Larger models need more memory but give better results.
                            """
                        )
                    examples_list = gr.Examples(
                        examples=[
                            [example]
                            for example in _get_examples(self.current_task_name)
                        ],
                        inputs=chat_interface.textbox,
                    )
                    task_dropdown.change(
                        _update_examples, task_dropdown, examples_list.dataset
                    )
| with gr.Accordion("Model Configuration ⚙️", open=False): | |
| temperature_slider = gr.Slider( | |
| minimum=0.1, | |
| maximum=2, | |
| value=self.model.temperature, | |
| label="Temperature", | |
| ) | |
| gr.Markdown( | |
| "**Temperature:** Lower values make the output more deterministic." | |
| ) | |
| temperature_slider.change( | |
| fn=lambda temp: setattr( | |
| self.model, "temperature", temp | |
| ), | |
| inputs=temperature_slider, | |
| ) | |
| top_p_slider = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=self.model.top_p, | |
| label="Top P", | |
| ) | |
| gr.Markdown( | |
| "**Top P:** Lower values make the output more focused." | |
| ) | |
| top_p_slider.change( | |
| fn=lambda top_p: setattr(self.model, "top_p", top_p), | |
| inputs=top_p_slider, | |
| ) | |
| top_k_slider = gr.Slider( | |
| minimum=1, | |
| maximum=100, | |
| value=self.model.top_k, | |
| label="Top K", | |
| ) | |
| gr.Markdown( | |
| "**Top K:** Lower values make the output more focused." | |
| ) | |
| top_k_slider.change( | |
| fn=lambda top_k: setattr(self.model, "top_k", top_k), | |
| inputs=top_k_slider, | |
| ) | |
| repetition_penalty_slider = gr.Slider( | |
| minimum=1.0, | |
| maximum=2.0, | |
| value=self.model.repeat_penalty, | |
| label="Repetition Penalty", | |
| ) | |
| gr.Markdown( | |
| "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output." | |
| ) | |
| repetition_penalty_slider.change( | |
| fn=lambda penalty: setattr( | |
| self.model, "repeat_penalty", penalty | |
| ), | |
| inputs=repetition_penalty_slider, | |
| ) | |
| max_tokens_slider = gr.Slider( | |
| minimum=512, | |
| maximum=2048, | |
| value=self.model.max_tokens, | |
| label="Max Tokens", | |
| ) | |
| gr.Markdown( | |
| "**Max Tokens:** Sets the maximum number of tokens the model can generate in one response." | |
| ) | |
| max_tokens_slider.change( | |
| fn=lambda max_tokens: setattr( | |
| self.model, "max_tokens", max_tokens | |
| ), | |
| inputs=max_tokens_slider, | |
| ) | |
| demo.launch() | |

    def run(self):
        self._chat()
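

# Illustrative usage (not part of the original module): one way GradioChat might
# be launched. The model and task names below are assumptions for the sketch;
# pass whatever options your ._model and ._prompts implementations support.
# Because this module uses relative imports, run it as part of its package,
# e.g. ``python -m <your_package>.<this_module>``.
if __name__ == "__main__":
    chat = GradioChat(
        model_options=["gemma-3b"],  # assumed model identifier (the class default)
        task_options=["Question Answering", "Text Generation", "Code Completion"],
    )
    chat.run()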