import gradio as gr

from ._model import LlamaCppGemmaModel
from ._prompts import PromptManager


class GradioChat:
    """
    A class that handles the chat interface for the Gemma model.

    Features:
    - A Gradio-based chatbot UI.
    - Dynamically loads models based on user selection.
    - Dynamically updates tasks using PromptManager.
    - Uses Gemma (llama.cpp) for generating responses.
    """

    def __init__(self, model_options: list[str], task_options: list[str]):
        """
        Args:
            model_options: Model names the user may pick from the dropdown.
            task_options: Task names the user may pick from the dropdown.
        """
        self.model_options = model_options
        self.task_options = task_options
        self.current_model_name = "gemma-3b"
        self.current_task_name = "Question Answering"
        # The task must be loaded first: _load_model reads
        # self.prompt_manager to fetch the system prompt.
        self.prompt_manager = self._load_task(self.current_task_name)
        # Cache of already-loaded models, keyed by model name.
        self.models_cache: dict = {}
        self.model = self._load_model("gemma-3b")

    def _load_model(self, model_name: str):
        """Loads the model dynamically when switching models, with caching.

        Returns the cached instance when the model was loaded before;
        otherwise loads it with the current task's system prompt and
        caches it. Also updates ``self.current_model_name``.
        """
        if model_name in self.models_cache:
            return self.models_cache[model_name]
        model = LlamaCppGemmaModel(name=model_name).load_model(
            system_prompt=self.prompt_manager.get_system_prompt()
        )
        self.models_cache[model_name] = model
        self.current_model_name = model_name
        return model

    def _load_task(self, task_name: str):
        """Loads the task dynamically when switching tasks.

        Updates ``self.current_task_name`` and returns a fresh
        PromptManager for the task.
        """
        self.current_task_name = task_name
        return PromptManager(task=task_name)

    def _chat(self):
        """Builds the Gradio Blocks UI and launches it (blocking)."""

        def chat_fn(message, history, selected_model, selected_task):
            # One chat turn: sync model/task with the dropdowns, then
            # stream the model's response.
            # Defensive lazy-load; __init__ normally loads eagerly.
            if self.model is None:
                self.model = self._load_model(self.current_model_name)
            # Reload the model if changed, using the cache when possible.
            if selected_model != self.current_model_name:
                self.model = self._load_model(selected_model)
                # Switching models clears the message history.
                self.model.messages = []
            # Reload the task if changed.
            if selected_task != self.current_task_name:
                self.prompt_manager = self._load_task(selected_task)
                if self.model:
                    # Switching tasks resets the history to just the new
                    # task's system prompt.
                    self.model.messages = [
                        {
                            "role": "system",
                            "content": self.prompt_manager.get_system_prompt(),
                        }
                    ]
            # Generate a response using the updated model & prompt manager.
            prompt = self.prompt_manager.get_prompt(user_input=message)
            response_stream = self.model.generate_response(prompt)
            yield from response_stream

        def _get_examples(task):
            """Returns example prompts for the given task.

            Falls back to an empty list for unknown task names so the
            examples widget never receives None.
            """
            examples = {
                "Question Answering": [
                    "What is quantum computing?",
                    "How do neural networks work?",
                    "Explain climate change in simple terms.",
                ],
                "Text Generation": [
                    "Once upon a time in a distant galaxy...",
                    "The abandoned house at the end of the street had...",
                    "In the year 2150, humanity discovered...",
                ],
                "Code Completion": [
                    "def fibonacci(n):",
                    "class BinarySearchInAList:",
                    "async def fetch_data(url):",
                ],
            }
            # NOTE: default to [] — examples.get(task) alone returns None
            # for unknown tasks and crashes the Dataset update below.
            return examples.get(task, [])

        def _update_examples(task):
            """Updates the examples based on the selected task."""
            examples = _get_examples(task)
            return gr.Dataset(samples=[[example] for example in examples])

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=3):
                    # Sidebar column; the accordion makes it foldable.
                    with gr.Accordion("Basic Settings ⚙️", open=False):
                        gr.Markdown(
                            "## Google Gemma Models: lightweight, state-of-the-art open models from Google"
                        )
                        task_dropdown = gr.Dropdown(
                            choices=self.task_options,
                            value=self.current_task_name,
                            label="Select Task",
                        )
                        model_dropdown = gr.Dropdown(
                            choices=self.model_options,
                            value=self.current_model_name,
                            label="Select Gemma Model",
                        )
                    chat_interface = gr.ChatInterface(
                        chat_fn,
                        additional_inputs=[model_dropdown, task_dropdown],
                        textbox=gr.Textbox(
                            placeholder="Ask me something...", container=False
                        ),
                    )
                    gr.Markdown(
                        "Medium Blog Post: [Gemma Chat Interface Blog](https://medium.com/@aadyachinubhai/introducing-the-gemma-chat-interface-your-ai-powered-chat-companion-a77fc609e51a)"
                    )
                with gr.Column(scale=1):
                    with gr.Accordion("Important Pointers", open=False):
                        gr.Markdown(
                            """
                            ## Pointers
                            - First response after model change will be slower (model loading lazily).
                            - Switching models clears chat history.
                            - Larger models need more memory but give better results.
                            """
                        )
                    examples_list = gr.Examples(
                        examples=[
                            [example]
                            for example in _get_examples(self.current_task_name)
                        ],
                        inputs=chat_interface.textbox,
                    )
                    # Refresh the example prompts whenever the task changes.
                    task_dropdown.change(
                        _update_examples, task_dropdown, examples_list.dataset
                    )
                    with gr.Accordion("Model Configuration ⚙️", open=False):
                        temperature_slider = gr.Slider(
                            minimum=0.1,
                            maximum=2,
                            value=self.model.temperature,
                            label="Temperature",
                        )
                        gr.Markdown(
                            "**Temperature:** Lower values make the output more deterministic."
                        )
                        # Each slider writes straight onto the live model
                        # instance; no output component is needed.
                        temperature_slider.change(
                            fn=lambda temp: setattr(
                                self.model, "temperature", temp
                            ),
                            inputs=temperature_slider,
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=self.model.top_p,
                            label="Top P",
                        )
                        gr.Markdown(
                            "**Top P:** Lower values make the output more focused."
                        )
                        top_p_slider.change(
                            fn=lambda top_p: setattr(self.model, "top_p", top_p),
                            inputs=top_p_slider,
                        )
                        top_k_slider = gr.Slider(
                            minimum=1,
                            maximum=100,
                            value=self.model.top_k,
                            label="Top K",
                        )
                        gr.Markdown(
                            "**Top K:** Lower values make the output more focused."
                        )
                        top_k_slider.change(
                            fn=lambda top_k: setattr(self.model, "top_k", top_k),
                            inputs=top_k_slider,
                        )
                        repetition_penalty_slider = gr.Slider(
                            minimum=1.0,
                            maximum=2.0,
                            value=self.model.repeat_penalty,
                            label="Repetition Penalty",
                        )
                        gr.Markdown(
                            "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output."
                        )
                        repetition_penalty_slider.change(
                            fn=lambda penalty: setattr(
                                self.model, "repeat_penalty", penalty
                            ),
                            inputs=repetition_penalty_slider,
                        )
                        max_tokens_slider = gr.Slider(
                            minimum=512,
                            maximum=2048,
                            value=self.model.max_tokens,
                            label="Max Tokens",
                        )
                        gr.Markdown(
                            "**Max Tokens:** Sets the maximum number of tokens the model can generate in one response."
                        )
                        max_tokens_slider.change(
                            fn=lambda max_tokens: setattr(
                                self.model, "max_tokens", max_tokens
                            ),
                            inputs=max_tokens_slider,
                        )
        demo.launch()

    def run(self):
        """Public entry point: build and launch the chat UI."""
        self._chat()