File size: 12,932 Bytes
fdf7bd6
21257b4
fdf7bd6
9aee162
fdf7bd6
21257b4
7e41311
21257b4
 
9aee162
21257b4
 
9aee162
 
 
 
21257b4
 
 
 
9aee162
 
21257b4
 
 
7e41311
 
 
21257b4
7e41311
9aee162
7e41311
21257b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aee162
 
 
 
 
 
 
 
 
 
 
 
 
 
21257b4
fdf7bd6
21257b4
9aee162
21257b4
 
 
 
9aee162
 
 
 
 
fdf7bd6
 
 
 
 
 
 
 
 
2cfd82c
fdf7bd6
 
 
 
 
 
 
 
 
2cfd82c
 
fdf7bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
21257b4
7e41311
21257b4
7e41311
 
9aee162
7e41311
 
fdf7bd6
 
 
 
7e41311
fdf7bd6
 
 
 
 
 
 
7e41311
fdf7bd6
 
 
 
7e41311
fdf7bd6
 
 
21257b4
fdf7bd6
7e41311
fdf7bd6
 
 
 
 
21257b4
fdf7bd6
 
 
 
 
 
 
 
 
 
21257b4
 
 
9aee162
21257b4
 
 
9aee162
 
21257b4
9aee162
 
 
 
 
 
 
 
21257b4
 
fdf7bd6
 
9aee162
 
 
 
 
 
 
 
 
fdf7bd6
21257b4
9aee162
 
21257b4
 
9aee162
 
 
 
 
 
 
21257b4
 
fdf7bd6
21257b4
 
fdf7bd6
 
 
 
21257b4
 
fdf7bd6
 
9aee162
fdf7bd6
 
21257b4
 
9aee162
21257b4
fdf7bd6
9aee162
fdf7bd6
7e41311
21257b4
fdf7bd6
 
 
21257b4
 
9aee162
21257b4
fdf7bd6
9aee162
 
 
 
 
fdf7bd6
 
 
21257b4
 
fdf7bd6
 
 
 
9aee162
 
 
 
7e41311
21257b4
fdf7bd6
9aee162
 
 
 
 
 
fdf7bd6
 
 
9aee162
 
21257b4
 
fdf7bd6
9aee162
 
fdf7bd6
21257b4
9aee162
 
 
 
 
 
 
 
 
 
 
21257b4
9aee162
21257b4
9aee162
 
 
 
 
 
21257b4
fdf7bd6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import gradio as gr
from config import AppConfig
from engine import FunctionGemmaEngine
from typing import Optional

def build_interface() -> gr.Blocks:
    """Build the FunctionGemma Modkit Gradio app.

    The UI has three tabs (dataset preparation, training, export). Each
    browser session keeps its own ``FunctionGemmaEngine`` in a ``gr.State``;
    it is created by ``init_session`` when the page loads.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` application.
    """

    # Returned by handlers whose triggers are never disabled during page load,
    # so they can fire before ``init_session`` has populated the engine state.
    not_ready_msg = "⚠️ Session is still initializing. Please try again in a moment."

    # --- State Management Wrappers ---

    def init_session(profile: Optional[gr.OAuthProfile] = None):
        """Create a fresh engine for this browser session.

        ``profile`` is filled in by Gradio's OAuth support (``None`` when the
        user is not signed in); it is not listed in the event inputs.

        Returns a tuple matching the outputs wired in ``demo.load``: engine
        state, tool JSON, model name, status text, repo-name textbox update,
        push-button update, and the username (or ``None``).
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)

        username = profile.username if profile else None
        repo_update, push_update = update_hub_interactive(new_engine, username)

        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update, push_update, username
        )

    def apply_model_name(engine, model_name):
        """Copy the dropdown's model ID onto the engine config.

        A ``None`` or blank value (e.g. a cleared dropdown) keeps the engine's
        current model instead of crashing on ``None.strip()`` or setting an
        empty model ID.
        """
        cleaned = (model_name or "").strip()
        if cleaned:
            engine.config.MODEL_NAME = cleaned

    def run_training_wrapper(engine, epochs, lr, test_size, shuffle, model_name):
        """Apply the chosen base model, then stream the engine's training output.

        Re-yields every item from ``engine.run_training_pipeline``; the event
        wiring maps each item onto ``(output_display, loss_plot)``.
        """
        apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    def handle_reset(engine, model_name):
        """Apply the chosen base model and ask the engine to refresh it.

        Returns whatever ``engine.refresh_model()`` returns, which is shown
        in ``output_display``.
        """
        apply_model_name(engine, model_name)
        return engine.refresh_model()

    def update_tools_wrapper(engine, json_val):
        """Forward the edited tool-schema JSON text to the engine."""
        if engine is None:
            return not_ready_msg
        return engine.update_tools(json_val)

    def import_file_wrapper(engine, file_obj):
        """Forward an uploaded CSV file to the engine's loader."""
        if engine is None:
            return not_ready_msg
        return engine.load_csv(file_obj)

    def stop_wrapper(engine):
        """Ask the engine to stop the running training loop.

        The click event is wired with ``outputs=None``, so the returned text
        is not displayed.
        """
        engine.trigger_stop()
        return "Stopping..."

    def zip_wrapper(engine):
        """Show the model archive for download, or hide the widget if there is none."""
        path = engine.get_zip_path()
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    def upload_wrapper(engine, repo_name, oauth_token: gr.OAuthToken | None):
        """Ask the engine to push the tuned model to the Hub as ``repo_name``.

        ``oauth_token`` is supplied by Gradio's OAuth support (it is not listed
        in the event inputs). Surrounding whitespace is stripped from
        ``repo_name`` so the uploaded name matches the preview.

        Returns a status string for ``upload_status``.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        repo_name = (repo_name or "").strip()
        if not repo_name:
            return "❌ Error: Please enter a repository name."

        return engine.upload_model_to_hub(
            repo_name=repo_name,
            oauth_token=oauth_token.token,
        )

    def update_repo_preview(username, repo_name):
        """Updates the markdown preview to show 'username/repo_name'."""
        if not username:
            return "⚠️ Sign in to see the target repository path."

        # A blank or whitespace-only name shows a placeholder, not "user/".
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    def update_hub_interactive(engine, username: Optional[str] = None):
        """Enable the Hub-export controls according to login and training state.

        Returns updates for ``(repo_name_input, push_to_hub_btn)``. The repo
        name box requires a signed-in user. The push button also requires
        ``engine.has_model_tuned`` to be truthy.
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and engine.has_model_tuned

        return gr.update(interactive=is_logged_in), gr.update(interactive=is_logged_in and has_model_tuned)

    # --- UI Layout ---
    with gr.Blocks(title="FunctionGemma Modkit") as demo:
        engine_state = gr.State()
        username_state = gr.State()

        with gr.Column():
            gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
            gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
            gr.Markdown("(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).")
            with gr.Row():
                gr.LoginButton(value="Sign in with Hugging Face")
                with gr.Column(scale=3):
                    gr.Markdown("")

        with gr.Tabs():

            # --- TAB 1: PREPARING DATASET ---
            with gr.TabItem("1. Preparing Dataset"):
                gr.Markdown("### 🛠️ Tool Schema & Data Import")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                        tools_editor = gr.Code(
                            language="json",
                            label="Tool Definitions (JSON Schema)",
                            lines=15
                        )
                        update_tools_btn = gr.Button("💾 Update Tool Schema")
                        tools_status = gr.Markdown("")

                    with gr.Column(scale=1):
                        gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                        gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                        import_file = gr.File(
                            label="Upload Dataset (.csv)",
                            file_types=[".csv"],
                            height=100
                        )
                        import_status = gr.Markdown("")

            # --- TAB 2: TRAINING ---
            with gr.TabItem("2. Training"):
                gr.Markdown("### 🚀 Fine-Tuning Configuration")

                with gr.Group():
                    gr.Markdown("**Hyperparameters**")
                    with gr.Row():
                        default_models = AppConfig().AVAILABLE_MODELS
                        param_model = gr.Dropdown(
                            choices=default_models,
                            allow_custom_value=True,
                            label="Base Model",
                            info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
                            interactive=True
                        )
                        param_epochs = gr.Slider(
                            minimum=1, maximum=20, value=5, step=1,
                            label="Epochs", info="Total training passes"
                        )
                    with gr.Row():
                        param_lr = gr.Number(
                            value=5e-5,
                            label="Learning Rate",
                            info="e.g. 5e-5"
                        )
                        param_test_size = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.2, step=0.05,
                            label="Test Split", info="Validation ratio (0.2 = 20%)"
                        )
                        param_shuffle = gr.Checkbox(
                            value=True,
                            label="Shuffle Data",
                            info="Randomize before split"
                        )

                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
                    stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
                    clear_reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)

                with gr.Row():
                    output_display = gr.Textbox(
                        lines=20,
                        label="Logs & Results",
                        value="Initializing...",
                        interactive=False,
                        autoscroll=True
                    )
                    loss_plot = gr.Plot(label="Training Metrics")

            # --- TAB 3: EXPORT ---
            with gr.TabItem("3. Export"):
                gr.Markdown("### 📦 Export Trained Model")

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Option A: Download ZIP")
                        gr.Markdown("Download the model weights locally.")
                        zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                        download_file = gr.File(label="Download Archive", interactive=False)

                    with gr.Column():
                        gr.Markdown("#### Option B: Save to Hugging Face Hub")
                        gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")

                        repo_name_input = gr.Textbox(
                            label="Target Repository Name",
                            value="my-functiongemma-v1",
                            placeholder="e.g., my-functiongemma-v1",
                            interactive=False
                        )
                        push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                        repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")

                        upload_status = gr.Markdown("")

        # --- EVENT WIRING ---

        # Buttons that are locked while a long-running action is in progress.
        action_buttons = [
            clear_reload_btn,
            run_training_btn,
            zip_btn
        ]

        def set_interactivity(interactive: bool):
            """Return one ``gr.update`` per entry of ``action_buttons``."""
            return [gr.update(interactive=interactive) for _ in action_buttons]

        # Page load: lock actions, create the session engine, render the repo
        # preview, then unlock reset/train. The ZIP button stays locked until
        # a training run finishes.
        demo.load(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=init_session,
            inputs=None,
            outputs=[engine_state, tools_editor, param_model, output_display, repo_name_input, push_to_hub_btn, username_state]
        ).then(
            fn=update_repo_preview,
            inputs=[username_state, repo_name_input],
            outputs=[repo_id_preview]
        ).then(
            fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
        )

        update_tools_btn.click(
            fn=update_tools_wrapper,
            inputs=[engine_state, tools_editor],
            outputs=[tools_status]
        )

        import_file.upload(
            fn=import_file_wrapper,
            inputs=[engine_state, import_file],
            outputs=[import_status]
        )

        # Training: swap Run for Stop and lock the other actions, stream
        # progress, restore the buttons, then re-evaluate Hub export.
        run_training_btn.click(
            fn=lambda: (
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(visible=True)
            ),
            outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn]
        ).then(
            fn=run_training_wrapper,
            inputs=[engine_state, param_epochs, param_lr, param_test_size, param_shuffle, param_model],
            outputs=[output_display, loss_plot],
        ).then(
            fn=lambda: (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=False)
            ),
            outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn]
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn]
        )

        stop_training_btn.click(
            fn=stop_wrapper,
            inputs=[engine_state],
            outputs=None
        )

        # Reset: lock everything (including push), reload, then unlock only
        # reset/train. The ZIP button stays locked because nothing is tuned yet.
        clear_reload_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
        ).then(
            fn=handle_reset,
            inputs=[engine_state, param_model],
            outputs=[output_display]
        ).then(
            fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn]
        )

        zip_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=zip_wrapper,
            inputs=[engine_state],
            outputs=[download_file]
        ).then(
            fn=lambda: set_interactivity(True), outputs=action_buttons
        )

        repo_name_input.change(
            fn=update_repo_preview,
            inputs=[username_state, repo_name_input],
            outputs=[repo_id_preview]
        )

        push_to_hub_btn.click(
            fn=lambda: set_interactivity(False), outputs=action_buttons
        ).then(
            fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
        ).then(
            fn=upload_wrapper,
            inputs=[engine_state, repo_name_input],
            outputs=[upload_status]
        ).then(
            fn=lambda: set_interactivity(True), outputs=action_buttons
        ).then(
            fn=update_hub_interactive,
            inputs=[engine_state, username_state],
            outputs=[repo_name_input, push_to_hub_btn]
        )

    return demo