File size: 9,951 Bytes
8cc5c82
bdca525
b4ecb60
 
5160420
8cc5c82
b4ecb60
 
5160420
b4ecb60
8cc5c82
1719ee5
e4ef2eb
1719ee5
b4ecb60
5160420
e4ef2eb
6aec7fd
 
581c860
 
 
04d7d7f
d24a753
581c860
 
b4ecb60
f788c15
 
1719ee5
5f43529
 
 
 
e9e9e0c
 
 
5f43529
e9e9e0c
5f43529
5160420
e4ef2eb
 
e9e9e0c
94b5c59
e4ef2eb
b4ecb60
4ffa911
5f43529
 
 
 
 
1719ee5
e4ef2eb
581c860
 
e4ef2eb
 
 
94b5c59
581c860
 
 
e9e9e0c
 
 
 
 
 
1719ee5
e4ef2eb
 
d24a753
 
8cc5c82
0248731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a14fb3
 
bb7bcb3
 
e4c419d
 
bb7bcb3
3a14fb3
8a9dfc3
 
 
 
 
da328b0
8a9dfc3
da328b0
8a9dfc3
 
 
 
2a738ad
8a9dfc3
 
 
 
2a738ad
8a9dfc3
 
 
 
 
 
cdc2c51
23cd50b
 
 
cdc2c51
8a9dfc3
cd05ce1
59c37d4
 
 
 
 
 
 
8a9dfc3
59c37d4
4db8e12
 
 
 
 
 
 
 
 
 
59c37d4
8a9dfc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87122c2
8a9dfc3
 
 
 
 
 
 
 
 
 
 
 
 
cdc2c51
8a9dfc3
 
 
 
 
 
 
 
 
 
 
 
 
cdc2c51
8a9dfc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc2c51
8a9dfc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd6e2c9
3a14fb3
1719ee5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import gradio as gr
from ._model import LlamaCppGemmaModel
from ._prompts import PromptManager


class GradioChat:
    """
    A class that handles the chat interface for the Gemma model.

    Features:
    - A Gradio-based chatbot UI.
    - Dynamically loads models based on user selection; loaded models are
      cached in ``models_cache`` so switching back to one does not reload it.
    - Dynamically updates tasks using PromptManager.
    - Uses Gemma (llama.cpp) for generating responses.
    - Generation settings chosen with the UI sliders are remembered and
      applied to whichever model becomes active.
    """

    # Example prompts offered under the chat box, keyed by task name.
    TASK_EXAMPLES = {
        "Question Answering": [
            "What is quantum computing?",
            "How do neural networks work?",
            "Explain climate change in simple terms.",
        ],
        "Text Generation": [
            "Once upon a time in a distant galaxy...",
            "The abandoned house at the end of the street had...",
            "In the year 2150, humanity discovered...",
        ],
        "Code Completion": [
            "def fibonacci(n):",
            "class BinarySearchInAList:",
            "async def fetch_data(url):",
        ],
    }

    # Model attributes that are counts; slider values for these are rounded
    # to int before being applied to the model.
    _INTEGER_PARAMS = frozenset({"top_k", "max_tokens"})

    # One entry per generation slider:
    # (model attribute, label, minimum, maximum, step, help text).
    # A step of None keeps Gradio's default step for that range.
    _GENERATION_SLIDERS = (
        (
            "temperature",
            "Temperature",
            0.1,
            2,
            None,
            "**Temperature:** Lower values make the output more deterministic.",
        ),
        (
            "top_p",
            "Top P",
            0.1,
            1.0,
            None,
            "**Top P:** Lower values make the output more focused.",
        ),
        (
            "top_k",
            "Top K",
            1,
            100,
            1,
            "**Top K:** Lower values make the output more focused.",
        ),
        (
            "repeat_penalty",
            "Repetition Penalty",
            1.0,
            2.0,
            None,
            "**Repetition Penalty:** Penalizes repeated tokens to reduce repetition in the output.",
        ),
        (
            "max_tokens",
            "Max Tokens",
            512,
            2048,
            1,
            "**Max Tokens:** Sets the maximum number of tokens the model can generate in one response.",
        ),
    )

    def __init__(self, model_options: list[str], task_options: list[str]):
        """Set up the default task and eagerly load the default model.

        Args:
            model_options: Model names offered in the model dropdown.
            task_options: Task names offered in the task dropdown.
        """
        self.model_options = model_options
        self.task_options = task_options

        self.current_model_name = "gemma-3b"
        self.current_task_name = "Question Answering"
        self.prompt_manager = self._load_task(self.current_task_name)

        # Loaded models keyed by name, so switching back does not reload.
        self.models_cache = {}
        # Generation settings changed via the UI sliders (attribute -> value);
        # re-applied to every model that becomes active.
        self.generation_overrides = {}

        self.model = self._load_model(self.current_model_name)

    def _load_model(self, model_name: str):
        """Return the model called ``model_name``, loading it on first use.

        Freshly loaded models are stored in ``self.models_cache``. Whether the
        model came from the cache or was just loaded, ``current_model_name``
        is updated so later selection checks see which model is active.
        """
        if model_name not in self.models_cache:
            self.models_cache[model_name] = LlamaCppGemmaModel(
                name=model_name
            ).load_model(system_prompt=self.prompt_manager.get_system_prompt())
        # Set only after a successful load so a failed load leaves the
        # previous model marked as active.
        self.current_model_name = model_name
        return self.models_cache[model_name]

    def _load_task(self, task_name: str):
        """Loads the task dynamically when switching tasks."""
        self.current_task_name = task_name
        return PromptManager(task=task_name)

    @classmethod
    def _get_examples(cls, task):
        """Return a fresh list of example prompts for ``task`` ([] if unknown)."""
        return list(cls.TASK_EXAMPLES.get(task, []))

    def _reset_history(self):
        """Clear the active model's conversation, keeping the task's system prompt."""
        if self.model is not None:
            self.model.messages = [
                {
                    "role": "system",
                    "content": self.prompt_manager.get_system_prompt(),
                }
            ]

    def _set_generation_param(self, name: str, value):
        """Remember a generation setting and apply it to the active model."""
        if name in self._INTEGER_PARAMS:
            value = int(round(value))
        self.generation_overrides[name] = value
        if self.model is not None:
            setattr(self.model, name, value)

    def _make_param_setter(self, name: str):
        """Return a one-argument callback that sets generation parameter ``name``."""

        def setter(value):
            self._set_generation_param(name, value)

        return setter

    def _activate_model(self, model_name: str):
        """Make ``model_name`` active with the UI's settings and a fresh history."""
        self.model = self._load_model(model_name)
        # A cached or newly loaded model carries its own settings; bring it in
        # line with what the sliders currently show.
        for name, value in self.generation_overrides.items():
            setattr(self.model, name, value)
        self._reset_history()

    def _sync_selection(self, selected_model, selected_task):
        """Bring the active model and task in line with the dropdown values.

        Switching either the model or the task clears the conversation (the
        current task's system prompt is kept).
        """
        if self.model is None:
            self.model = self._load_model(self.current_model_name)

        if selected_model != self.current_model_name:
            self._activate_model(selected_model)

        if selected_task != self.current_task_name:
            self.prompt_manager = self._load_task(selected_task)
            self._reset_history()

    def _build_generation_controls(self):
        """Create one slider + help text per generation parameter.

        Must be called inside an open Gradio layout context; each slider is
        initialised from the active model's attribute.
        """
        for attr, label, minimum, maximum, step, help_text in self._GENERATION_SLIDERS:
            slider = gr.Slider(
                minimum=minimum,
                maximum=maximum,
                step=step,
                value=getattr(self.model, attr),
                label=label,
            )
            gr.Markdown(help_text)
            slider.change(fn=self._make_param_setter(attr), inputs=slider)

    def _chat(self):
        """Build the Gradio chat UI and launch it with ``demo.launch()``."""

        def chat_fn(message, history, selected_model, selected_task):
            # ``history`` is supplied by gr.ChatInterface but unused here; this
            # class resets ``self.model.messages`` itself on model/task change.
            self._sync_selection(selected_model, selected_task)
            prompt = self.prompt_manager.get_prompt(user_input=message)
            yield from self.model.generate_response(prompt)

        def update_examples(task):
            """Return a Dataset update holding the examples for ``task``."""
            return gr.Dataset(
                samples=[[example] for example in self._get_examples(task)]
            )

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column(scale=3):  # Sidebar column
                    with gr.Accordion(
                        "Basic Settings ⚙️", open=False
                    ):  # Make the sidebar foldable
                        gr.Markdown(
                            "## Google Gemma Models: lightweight, state-of-the-art open models from Google"
                        )
                        task_dropdown = gr.Dropdown(
                            choices=self.task_options,
                            value=self.current_task_name,
                            label="Select Task",
                        )
                        model_dropdown = gr.Dropdown(
                            choices=self.model_options,
                            value=self.current_model_name,
                            label="Select Gemma Model",
                        )
                    chat_interface = gr.ChatInterface(
                        chat_fn,
                        additional_inputs=[model_dropdown, task_dropdown],
                        textbox=gr.Textbox(
                            placeholder="Ask me something...", container=False
                        ),
                    )
                    gr.Markdown(
                        "Medium Blog Post: [Gemma Chat Interface Blog](https://medium.com/@aadyachinubhai/introducing-the-gemma-chat-interface-your-ai-powered-chat-companion-a77fc609e51a)"
                    )

                with gr.Column(scale=1):
                    with gr.Accordion("Important Pointers", open=False):
                        gr.Markdown(
                            """
                        ## Pointers
                        
                        - First response after model change will be slower (model loading lazily).
                        - Switching models clears chat history.
                        - Larger models need more memory but give better results.
                        """
                        )
                    examples_list = gr.Examples(
                        examples=[
                            [example]
                            for example in self._get_examples(self.current_task_name)
                        ],
                        inputs=chat_interface.textbox,
                    )
                    task_dropdown.change(
                        update_examples, task_dropdown, examples_list.dataset
                    )

                    with gr.Accordion("Model Configuration ⚙️", open=False):
                        self._build_generation_controls()

        demo.launch()

    def run(self):
        """Build and launch the chat UI."""
        self._chat()