File size: 16,675 Bytes
fdf7bd6
0a2e0b5
21257b4
fdf7bd6
 
0a2e0b5
 
 
 
 
 
 
 
 
a6b7ef6
21257b4
 
9aee162
0a2e0b5
 
a6b7ef6
9aee162
21257b4
 
 
 
9aee162
a6b7ef6
 
 
0a2e0b5
21257b4
 
0a2e0b5
 
 
 
 
 
 
7e41311
 
 
c07c868
 
 
 
 
 
 
 
 
0a2e0b5
 
7e41311
9aee162
7e41311
0a2e0b5
 
21257b4
 
0a2e0b5
 
21257b4
 
0a2e0b5
 
21257b4
5d576bb
21257b4
0a2e0b5
 
21257b4
 
 
 
 
0a2e0b5
a6b7ef6
21257b4
 
a6b7ef6
 
21257b4
 
a6b7ef6
21257b4
 
 
0a2e0b5
a6b7ef6
 
 
 
 
9aee162
0a2e0b5
 
9aee162
0a2e0b5
9aee162
0a2e0b5
a6b7ef6
 
 
0a2e0b5
 
 
9aee162
0a2e0b5
 
a48dc21
0a2e0b5
9edc5b0
4947547
0a2e0b5
 
 
4947547
0a2e0b5
 
 
 
050e228
611e3a3
524f485
0a2e0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645faa6
0a2e0b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c07c868
0a2e0b5
88fd2e2
0a2e0b5
 
 
 
 
 
 
 
645faa6
c07c868
0a2e0b5
 
 
 
 
 
 
 
 
 
 
 
 
c07c868
0a2e0b5
 
 
a6b7ef6
 
 
 
0a2e0b5
 
 
 
 
 
a6b7ef6
0a2e0b5
 
 
 
 
a48dc21
21257b4
9aee162
21257b4
0a2e0b5
fdf7bd6
 
0a2e0b5
 
 
 
 
a6b7ef6
 
645faa6
 
0a2e0b5
a6b7ef6
 
0a2e0b5
 
 
 
 
a6b7ef6
9aee162
0a2e0b5
 
 
9aee162
0a2e0b5
fdf7bd6
0a2e0b5
a6b7ef6
0a2e0b5
a6b7ef6
0a2e0b5
 
 
 
 
a6b7ef6
0a2e0b5
a6b7ef6
0a2e0b5
 
9aee162
0a2e0b5
a6b7ef6
 
0a2e0b5
21257b4
0a2e0b5
 
 
 
 
fdf7bd6
 
0a2e0b5
 
 
 
fdf7bd6
0a2e0b5
645faa6
9aee162
645faa6
4a1ace6
a6b7ef6
 
 
 
 
 
 
 
4a1ace6
 
0a2e0b5
 
 
fdf7bd6
 
21257b4
 
645faa6
21257b4
fdf7bd6
645faa6
9aee162
0a2e0b5
 
9aee162
a6b7ef6
fdf7bd6
 
645faa6
4a1ace6
645faa6
 
 
a6b7ef6
645faa6
 
 
4a1ace6
 
645faa6
 
 
 
 
 
 
 
 
 
 
 
 
4a1ace6
 
 
5d576bb
 
 
4a1ace6
0a2e0b5
a6b7ef6
0a2e0b5
 
 
 
 
9aee162
a6b7ef6
fdf7bd6
0a2e0b5
 
a6b7ef6
0a2e0b5
21257b4
0a2e0b5
 
 
1cc09c6
a6b7ef6
fdf7bd6
0a2e0b5
a6b7ef6
0a2e0b5
a6b7ef6
 
9aee162
 
a6b7ef6
0a2e0b5
a6b7ef6
 
0a2e0b5
 
9aee162
a6b7ef6
21257b4
fdf7bd6
645faa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import gradio as gr
from typing import Optional, Tuple, Generator, List, Any
from config import AppConfig
from engine import FunctionGemmaEngine

# --- Controller / Logic Layer ---

class UIController:
    """
    Handles the business logic and interaction with the Engine.
    Stateless methods that operate on the passed Engine state.

    Each handler receives the session's ``FunctionGemmaEngine`` from a
    ``gr.State``. That state is empty (``None``) until ``init_session`` has
    populated it, so every handler that touches the engine checks for it first.
    """

    # Message shown when a handler fires before init_session populated the state.
    _NOT_INITIALIZED = "⚠️ Engine not initialized."

    @staticmethod
    def _apply_model_name(engine: FunctionGemmaEngine, model_name: Optional[str]) -> None:
        """Copy the "Base Model" dropdown value onto the engine's config.

        Surrounding whitespace is stripped. A missing or blank value is ignored
        so the previously configured model is kept instead of being replaced by
        an empty model id (or crashing on ``None.strip()``).
        """
        cleaned = (model_name or "").strip()
        if cleaned:
            engine.config.MODEL_NAME = cleaned

    @staticmethod
    def init_session(profile: Optional[gr.OAuthProfile] = None) -> Tuple[Any, ...]:
        """Create a fresh engine for a new browser session.

        ``profile`` is supplied by Gradio (the event is wired with no inputs);
        it is ``None`` when the user is not signed in.

        Returns:
            ``(engine, tools_json, model_name, status_message, repo_update,
            push_update, zip_update, username)`` in the order the ``demo.load``
            chain lists its outputs.
        """
        config = AppConfig()
        new_engine = FunctionGemmaEngine(config)
        username = profile.username if profile else None

        # Calculate initial interactivity state
        repo_update, push_update, zip_update = UIController.update_hub_interactive(new_engine, username)

        return (
            new_engine,
            new_engine.get_tools_json(),
            new_engine.config.MODEL_NAME,
            f"Ready. (Session {new_engine.session_id})",
            repo_update,
            push_update,
            zip_update,
            username,
        )

    @staticmethod
    def run_training(engine: FunctionGemmaEngine, epochs: int, lr: float,
                     test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream ``(log_text, plot)`` updates from the engine's training pipeline."""
        if not engine:
            yield UIController._NOT_INITIALIZED, None
            return

        UIController._apply_model_name(engine, model_name)
        yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)

    @staticmethod
    def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool, model_name: str) -> Generator:
        """Stream log-text updates from the engine's evaluation run."""
        if not engine:
            yield UIController._NOT_INITIALIZED
            return

        UIController._apply_model_name(engine, model_name)
        yield from engine.run_evaluation(test_size, shuffle)

    @staticmethod
    def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
        """Reload the (possibly newly selected) base model; returns a status message."""
        if not engine:
            return UIController._NOT_INITIALIZED
        UIController._apply_model_name(engine, model_name)
        return engine.refresh_model()

    @staticmethod
    def update_tools(engine: FunctionGemmaEngine, json_val: str) -> str:
        """Forward the edited tool-schema JSON to the engine; returns a status message."""
        if not engine:
            return UIController._NOT_INITIALIZED
        return engine.update_tools(json_val)

    @staticmethod
    def import_file(engine: FunctionGemmaEngine, file_obj: Any) -> str:
        """Forward an uploaded CSV to the engine; returns a status message."""
        if not engine:
            return UIController._NOT_INITIALIZED
        return engine.load_csv(file_obj)

    @staticmethod
    def stop_process(engine: FunctionGemmaEngine) -> None:
        """Ask the engine to stop the running job (the event is wired with no outputs)."""
        if engine:
            engine.trigger_stop()

    @staticmethod
    def zip_model(engine: FunctionGemmaEngine) -> Any:
        """Return a ``gr.update`` that shows the archive path, or hides the file widget."""
        path = engine.get_zip_path() if engine else None
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    @staticmethod
    def upload_model(engine: FunctionGemmaEngine, repo_name: str, oauth_token: Optional[gr.OAuthToken]) -> str:
        """Push the tuned model to ``<user>/<repo_name>`` on the Hub.

        ``oauth_token`` is supplied by Gradio (it is not listed in the event's
        inputs); it is ``None`` when the user is not signed in. The repo name
        is stripped the same way ``update_repo_preview`` displays it.
        """
        if oauth_token is None:
            return "❌ Error: You must log in (top right) to upload models."
        clean_repo = (repo_name or "").strip()
        if not clean_repo:
            return "❌ Error: Please enter a repository name."
        if not engine:
            return UIController._NOT_INITIALIZED

        return engine.upload_model_to_hub(
            repo_name=clean_repo,
            oauth_token=oauth_token.token,
        )

    @staticmethod
    def update_repo_preview(username: Optional[str], repo_name: str) -> str:
        """Return Markdown previewing the target ``username/repo`` path."""
        if not username:
            return "⚠️ Sign in to see the target repository path."
        # A blank or whitespace-only name falls back to the "..." placeholder.
        clean_repo = (repo_name or "").strip() or "..."
        return f"Target Repository: **`{username}/{clean_repo}`**"

    @staticmethod
    def update_hub_interactive(engine: Optional[FunctionGemmaEngine], username: Optional[str] = None) -> Tuple[Any, Any, Any]:
        """Compute interactivity for the (repo input, push button, zip button) trio.

        - Repo name input: editable only when signed in.
        - Push button: signed in AND the engine reports ``has_model_tuned``.
        - Zip button: the engine reports ``has_model_tuned``.
        """
        is_logged_in = username is not None
        has_model_tuned = engine is not None and getattr(engine, 'has_model_tuned', False)

        return (
            gr.update(interactive=is_logged_in),
            gr.update(interactive=is_logged_in and has_model_tuned),
            gr.update(interactive=has_model_tuned),
        )

# --- View / Layout Layer ---

def _render_header():
    """Render the page title, intro text and the Hugging Face sign-in button."""
    # Emoji are written as \N{...} named escapes so the source stays ASCII-only
    # and cannot be re-garbled by an encoding round-trip (the previous literal
    # had been mis-decoded into mojibake).
    with gr.Column():
        gr.Markdown("# \N{ROBOT FACE} FunctionGemma Tuning Lab: Fine-Tuning")
        gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>"
                    "See [README](https://huggingface.co/spaces/google/functiongemma-tuning-lab/blob/main/README.md) for more details.")
        gr.Markdown("(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).<br>⚠️ **Warning:** Signing in will refresh the page and reset your current session (including data and model progress).")
        with gr.Row():
            gr.LoginButton(value="Sign in with Hugging Face")
            # Empty wide column used as a spacer so the login button stays compact.
            with gr.Column(scale=3):
                gr.Markdown("")

def _render_dataset_tab(engine_state):
    """Render tab "1. Preparing Dataset" (tool-schema editor and CSV upload).

    Args:
        engine_state: Session engine ``gr.State``; not read while rendering
            (event handlers are wired in ``build_interface``).

    Returns:
        Dict of the components that ``build_interface`` wires events to.
    """
    # Emoji are written as \N{...} named escapes so the source stays ASCII-only
    # and cannot be re-garbled by an encoding round-trip.
    with gr.TabItem("1. Preparing Dataset"):
        gr.Markdown("### \N{HAMMER AND WRENCH}\N{VARIATION SELECTOR-16} Tool Schema & Data Import")
        gr.Markdown("**Important Limitation:** This configuration will fail if the defined tools require **different parameter structures**.<br>The framework cannot currently handle a mix of tools with distinct signatures. For example, the following combination will not work:")
        gr.Markdown("* `sum(int a, int b)`\n* `query(string q)`")
        gr.Markdown("Ensure that all tools within this specific schema definition share a consistent parameter format.")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
                tools_editor = gr.Code(language="json", label="Tool Definitions (JSON Schema)", lines=15)
                update_tools_btn = gr.Button("\N{FLOPPY DISK} Update Tool Schema")
                tools_status = gr.Markdown("")

            with gr.Column(scale=1):
                gr.Markdown("**Step 2: Upload Data (Optional)**<br>To train on your own data, upload a CSV file to replace the [default dataset](https://huggingface.co/datasets/bebechien/SimpleToolCalling).")
                gr.Markdown("**Example CSV Row:** No header required.<br>Format: `[User Prompt, Tool Name, Tool Args JSON]`\n```csv\n\"What is the weather in London?\", \"get_weather\", \"{\"\"location\"\": \"\"London, UK\"\"}\"\n```")
                import_file = gr.File(label="Upload Dataset (.csv)", file_types=[".csv"], height=100)
                import_status = gr.Markdown("")

    # Return controls needed for wiring
    return {
        "tools_editor": tools_editor,
        "update_tools_btn": update_tools_btn,
        "tools_status": tools_status,
        "import_file": import_file,
        "import_status": import_status,
    }

def _render_training_tab(engine_state):
    """Render tab "2. Training & Eval" (hyperparameters, action buttons, logs, plot).

    Args:
        engine_state: Session engine ``gr.State``; not read while rendering
            (event handlers are wired in ``build_interface``).

    Returns:
        Dict with:
        - ``params``: inputs for ``UIController.run_training`` (in its argument order).
        - ``eval_params``: inputs for ``UIController.run_evaluation``.
        - ``buttons``: ``[run, stop, reload, eval]``.
        - ``outputs``: ``[log textbox, metrics plot]``.
        - ``model_input``: the base-model dropdown (filled in at session init).
    """
    # Emoji are written as \N{...} named escapes so the source stays ASCII-only
    # and cannot be re-garbled by an encoding round-trip.
    with gr.TabItem("2. Training & Eval"):
        gr.Markdown("### \N{ROCKET} Fine-Tuning Configuration")
        with gr.Group():
            gr.Markdown("**Hyperparameters**")
            with gr.Row():
                default_models = AppConfig().AVAILABLE_MODELS
                param_model = gr.Dropdown(
                    choices=default_models, allow_custom_value=True, label="Base Model", info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')", interactive=True
                )
                param_epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs", info="Total training passes")
            with gr.Row():
                param_lr = gr.Number(value=5e-5, label="Learning Rate", info="e.g. 5e-5")
                param_test_size = gr.Slider(0.1, 0.9, value=0.2, step=0.05, label="Test Split", info="Validation ratio (0.2 = 20%)")
                param_shuffle = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")

        with gr.Row():
            run_eval_btn = gr.Button("\N{TEST TUBE} Run Evaluation", variant="secondary", scale=1)
            # Hidden until a job starts; the wiring toggles it with the Run/Eval buttons.
            stop_training_btn = gr.Button("\N{OCTAGONAL SIGN} Stop", variant="stop", visible=False, scale=1)
            run_training_btn = gr.Button("\N{ROCKET} Run Fine-Tuning", variant="primary", scale=1)
            clear_reload_btn = gr.Button("\N{ANTICLOCKWISE DOWNWARDS AND UPWARDS OPEN CIRCLE ARROWS} Reload Model & Reset Data", variant="secondary", scale=1)

        with gr.Row():
            output_display = gr.Textbox(lines=20, label="Logs", value="Initializing...", interactive=False, autoscroll=True)
            loss_plot = gr.Plot(label="Training Metrics")

    return {
        "params": [param_epochs, param_lr, param_test_size, param_shuffle, param_model],
        "eval_params": [param_test_size, param_shuffle, param_model],
        "buttons": [run_training_btn, stop_training_btn, clear_reload_btn, run_eval_btn],
        "outputs": [output_display, loss_plot],
        "model_input": param_model,  # specifically needed for initialization
    }

def _render_export_tab(engine_state, username_state):
    """Render tab "3. Export" (ZIP download and push-to-Hub controls).

    Args:
        engine_state: Session engine ``gr.State``; not read while rendering.
        username_state: Signed-in username ``gr.State``; not read while rendering.

    Returns:
        Dict with ``zip_controls`` = ``[zip button, download file]`` and
        ``hub_controls`` = ``[repo name input, push button, repo preview, upload status]``.
    """
    # Emoji are written as \N{...} named escapes so the source stays ASCII-only
    # and cannot be re-garbled by an encoding round-trip.
    with gr.TabItem("3. Export"):
        gr.Markdown("### \N{PACKAGE} Export Trained Model")
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Option A: Download ZIP")
                gr.Markdown("Download the model weights locally.")
                # Starts disabled; enabled once the engine reports a tuned model.
                zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
                download_file = gr.File(label="Download Archive", interactive=False)
                gr.Markdown("NOTE: Zipping usually takes 1~2 min.")

            with gr.Column():
                gr.Markdown("#### Option B: Save to Hugging Face Hub")
                gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
                # Starts disabled; enabled for signed-in users.
                repo_name_input = gr.Textbox(
                    label="Target Repository Name", value="functiongemma-270m-it-tuning-lab", placeholder="e.g., functiongemma-270m-it-tuned", interactive=False
                )
                push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
                upload_status = gr.Markdown("")

    return {
        "zip_controls": [zip_btn, download_file],
        "hub_controls": [repo_name_input, push_to_hub_btn, repo_id_preview, upload_status],
    }

# --- Main Build Function ---

def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks app and wire every event handler.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(title="FunctionGemma Tuning Lab") as demo:
        # Per-session state: the engine instance and the signed-in username (or None).
        engine_state = gr.State()
        username_state = gr.State()

        _render_header()

        with gr.Tabs():
            data_ui = _render_dataset_tab(engine_state)
            train_ui = _render_training_tab(engine_state)
            export_ui = _render_export_tab(engine_state, username_state)

        # Unpack the component handles returned by the tab renderers.
        run_btn, stop_btn, reload_btn, eval_btn = train_ui["buttons"]
        log_output, loss_plot = train_ui["outputs"]
        zip_btn, download_file = export_ui["zip_controls"]
        repo_input, push_btn, repo_id_preview, upload_status = export_ui["hub_controls"]

        # 'action_buttons' only contains buttons that are always re-enabled after a
        # process. Zip and Push are excluded because their state depends on
        # has_model_tuned (and login); update_hub_interactive re-derives them.
        action_buttons = [reload_btn, run_btn, eval_btn]
        lock_targets = action_buttons + [push_btn, zip_btn]
        # Order must match the tuple returned by UIController.update_hub_interactive.
        hub_targets = [repo_input, push_btn, zip_btn]
        session_inputs = [engine_state, username_state]

        def lock_ui():
            """Lock all buttons (including Zip/Push) during processing."""
            return [gr.update(interactive=False) for _ in lock_targets]

        def unlock_ui():
            """Unlock general action buttons only. Zip/Push are handled by update_hub_interactive."""
            return [gr.update(interactive=True) for _ in action_buttons]

        def _run_locked(trigger, fn, inputs, outputs):
            """Wire ``trigger`` to: lock UI -> fn -> unlock action buttons -> re-derive Hub/Zip state."""
            return trigger(lock_ui, outputs=lock_targets).then(
                fn=fn,
                inputs=inputs,
                outputs=outputs,
            ).then(unlock_ui, outputs=action_buttons).then(
                fn=UIController.update_hub_interactive,
                inputs=session_inputs,
                outputs=hub_targets,
            )

        # --- Event Wiring ---

        # 1. Initialization
        demo.load(lock_ui, outputs=lock_targets).then(
            fn=UIController.init_session,
            inputs=None,
            outputs=[
                engine_state,
                data_ui["tools_editor"],
                train_ui["model_input"],
                log_output,
                *hub_targets,  # repo / push / zip state derived from the fresh engine
                username_state,
            ],
        ).then(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_id_preview],
        ).then(unlock_ui, outputs=action_buttons)

        # 2. Data Tab
        data_ui["update_tools_btn"].click(
            fn=UIController.update_tools,
            inputs=[engine_state, data_ui["tools_editor"]],
            outputs=[data_ui["tools_status"]],
        )

        data_ui["import_file"].upload(
            fn=UIController.import_file,
            inputs=[engine_state, data_ui["import_file"]],
            outputs=[data_ui["import_status"]],
        )

        # 3. Training & Eval Tab

        # 3a. Training
        # NOTE(review): each *_run_event below is the LAST link of its chain; the
        # running generator is stopped cooperatively via engine.trigger_stop().
        # Confirm Gradio's `cancels` semantics before retargeting it.
        train_run_event = run_btn.click(
            fn=lambda: (
                gr.update(visible=False),      # Hide Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(interactive=False),  # Lock Eval
                gr.update(interactive=False),  # Lock Zip
                gr.update(interactive=False),  # Lock Push: the model is being overwritten
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, zip_btn, push_btn, stop_btn],
        ).then(
            fn=UIController.run_training,
            inputs=[engine_state, *train_ui["params"]],
            outputs=[log_output, loss_plot],
        ).then(
            fn=lambda: (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=False),
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        ).then(
            # Final check determines if Zip/Push should unlock
            fn=UIController.update_hub_interactive,
            inputs=session_inputs,
            outputs=hub_targets,
        )

        # 3b. Evaluation
        eval_run_event = eval_btn.click(
            fn=lambda: (
                gr.update(interactive=False),  # Lock Run
                gr.update(interactive=False),  # Lock Reload
                gr.update(visible=False),      # Hide self (optional, or lock)
                gr.update(visible=True),       # Show Stop
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        ).then(
            fn=UIController.run_evaluation,
            inputs=[engine_state, *train_ui["eval_params"]],
            outputs=[log_output],  # Output only to log, not plot
        ).then(
            fn=lambda: (
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(visible=True),
                gr.update(visible=False),
            ),
            outputs=[run_btn, reload_btn, eval_btn, stop_btn],
        )

        stop_btn.click(
            fn=UIController.stop_process,
            inputs=[engine_state],
            cancels=[train_run_event, eval_run_event],
            outputs=None,
            queue=False,  # run immediately, not behind the job it is meant to stop
        )

        _run_locked(
            reload_btn.click,
            UIController.handle_reset,
            inputs=[engine_state, train_ui["model_input"]],
            outputs=[log_output],
        )

        # 4. Export Tab
        _run_locked(
            zip_btn.click,
            UIController.zip_model,
            inputs=[engine_state],
            outputs=[download_file],
        )

        repo_input.change(
            fn=UIController.update_repo_preview,
            inputs=[username_state, repo_input],
            outputs=[repo_id_preview],
        )

        _run_locked(
            push_btn.click,
            UIController.upload_model,
            inputs=[engine_state, repo_input],
            outputs=[upload_status],
        )

    return demo