bebechien commited on
Commit
9aee162
·
verified ·
1 Parent(s): 43c4111

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.py +3 -4
  2. engine.py +6 -3
  3. ui.py +100 -21
config.py CHANGED
@@ -17,12 +17,11 @@ class AppConfig:
17
 
18
  # Model Configuration
19
  # Mutable: User can change this in the UI
20
- MODEL_NAME: str = '../hf/270m'
21
 
22
  AVAILABLE_MODELS: List[str] = field(default_factory=lambda: [
23
- '../hf/270m',
24
- '../hf/gemma-3-270m-it',
25
- 'google/gemma-3-270m-it'
26
  ])
27
 
28
  DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
 
17
 
18
  # Model Configuration
19
  # Mutable: User can change this in the UI
20
+ MODEL_NAME: str = 'gg-hf-gm/functiongemma-270m-it'
21
 
22
  AVAILABLE_MODELS: List[str] = field(default_factory=lambda: [
23
+ 'gg-hf-gm/functiongemma-270m-it',
24
+ 'google/functiongemma-270m-it'
 
25
  ])
26
 
27
  DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
engine.py CHANGED
@@ -80,10 +80,11 @@ class FunctionGemmaEngine:
80
  self.imported_dataset = []
81
  self.stop_event = threading.Event()
82
  self.current_tools = DEFAULT_TOOLS
 
83
 
84
  authenticate_hf(self.config.HF_TOKEN)
85
  try:
86
- self.refresh_data_and_model()
87
  except Exception as e:
88
  print(f"Initial load warning: {e}")
89
 
@@ -110,8 +111,8 @@ class FunctionGemmaEngine:
110
  self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
111
  self.loaded_model_name = self.config.MODEL_NAME
112
 
113
- def refresh_data_and_model(self) -> str:
114
- self.imported_dataset = []
115
  try:
116
  self._load_model_weights()
117
  return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady (Session {self.session_id})."
@@ -230,6 +231,8 @@ class FunctionGemmaEngine:
230
 
231
  train_thread.join()
232
 
 
 
233
  while not log_queue.empty():
234
  payload = log_queue.get()
235
  if isinstance(payload, tuple):
 
80
  self.imported_dataset = []
81
  self.stop_event = threading.Event()
82
  self.current_tools = DEFAULT_TOOLS
83
+ self.has_model_tuned = False
84
 
85
  authenticate_hf(self.config.HF_TOKEN)
86
  try:
87
+ self.refresh_model()
88
  except Exception as e:
89
  print(f"Initial load warning: {e}")
90
 
 
111
  self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
112
  self.loaded_model_name = self.config.MODEL_NAME
113
 
114
+ def refresh_model(self) -> str:
115
+ self.has_model_tuned = False
116
  try:
117
  self._load_model_weights()
118
  return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady (Session {self.session_id})."
 
231
 
232
  train_thread.join()
233
 
234
+ self.has_model_tuned = True
235
+
236
  while not log_queue.empty():
237
  payload = log_queue.get()
238
  if isinstance(payload, tuple):
ui.py CHANGED
@@ -1,19 +1,25 @@
1
  import gradio as gr
2
  from config import AppConfig
3
  from engine import FunctionGemmaEngine
 
4
 
5
  def build_interface() -> gr.Blocks:
6
 
7
  # --- State Management Wrappers ---
8
 
9
- def init_session():
10
  config = AppConfig()
11
  new_engine = FunctionGemmaEngine(config)
 
 
 
 
12
  return (
13
  new_engine,
14
  new_engine.get_tools_json(),
15
  new_engine.config.MODEL_NAME,
16
- f"Ready. (Session {new_engine.session_id})"
 
17
  )
18
 
19
  def run_training_wrapper(engine, epochs, lr, test_size, shuffle, model_name):
@@ -22,7 +28,7 @@ def build_interface() -> gr.Blocks:
22
 
23
  def handle_reset(engine, model_name):
24
  engine.config.MODEL_NAME = model_name.strip()
25
- return engine.refresh_data_and_model()
26
 
27
  def update_tools_wrapper(engine, json_val):
28
  return engine.update_tools(json_val)
@@ -51,14 +57,33 @@ def build_interface() -> gr.Blocks:
51
  oauth_token=oauth_token.token,
52
  )
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # --- UI Layout ---
55
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
56
  engine_state = gr.State()
 
57
 
58
  with gr.Column():
59
  gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
60
  gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
61
- gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")
 
 
 
 
62
 
63
  with gr.Tabs():
64
 
@@ -99,7 +124,7 @@ def build_interface() -> gr.Blocks:
99
  choices=default_models,
100
  allow_custom_value=True,
101
  label="Base Model",
102
- info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/gemma-3-1b-it')",
103
  interactive=True
104
  )
105
  param_epochs = gr.Slider(
@@ -145,29 +170,47 @@ def build_interface() -> gr.Blocks:
145
  with gr.Column():
146
  gr.Markdown("#### Option A: Download ZIP")
147
  gr.Markdown("Download the model weights locally.")
148
- zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary")
149
  download_file = gr.File(label="Download Archive", interactive=False)
150
 
151
  with gr.Column():
152
- gr.Markdown("#### Option B: Upload to Hugging Face Hub")
153
- gr.Markdown("Publish to your HF profile. **Requires Login**.")
154
 
155
- with gr.Group():
156
- repo_id_input = gr.Textbox(
157
- label="Repository Name",
158
- placeholder="my-function-gemma-v1",
159
- info="Will be created under your username (e.g. user/repo)"
160
- )
161
- upload_hub_btn = gr.Button("☁️ Upload to Hub", variant="primary")
 
162
 
163
  upload_status = gr.Markdown("")
164
 
165
  # --- EVENT WIRING ---
 
 
 
 
 
 
 
 
 
166
 
167
  demo.load(
 
 
168
  fn=init_session,
169
  inputs=None,
170
- outputs=[engine_state, tools_editor, param_model, output_display]
 
 
 
 
 
 
171
  )
172
 
173
  update_tools_btn.click(
@@ -181,14 +224,15 @@ def build_interface() -> gr.Blocks:
181
  inputs=[engine_state, import_file],
182
  outputs=[import_status]
183
  )
184
-
185
  run_training_btn.click(
186
  fn=lambda: (
187
  gr.update(visible=False),
188
  gr.update(interactive=False),
 
189
  gr.update(visible=True)
190
  ),
191
- outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
192
  ).then(
193
  fn=run_training_wrapper,
194
  inputs=[engine_state, param_epochs, param_lr, param_test_size, param_shuffle, param_model],
@@ -197,9 +241,14 @@ def build_interface() -> gr.Blocks:
197
  fn=lambda: (
198
  gr.update(visible=True),
199
  gr.update(interactive=True),
 
200
  gr.update(visible=False)
201
  ),
202
- outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
 
 
 
 
203
  )
204
 
205
  stop_training_btn.click(
@@ -209,21 +258,51 @@ def build_interface() -> gr.Blocks:
209
  )
210
 
211
  clear_reload_btn.click(
 
 
 
 
212
  fn=handle_reset,
213
  inputs=[engine_state, param_model],
214
  outputs=[output_display]
 
 
 
 
 
 
215
  )
216
 
217
  zip_btn.click(
 
 
218
  fn=zip_wrapper,
219
  inputs=[engine_state],
220
  outputs=[download_file]
 
 
221
  )
222
 
223
- upload_hub_btn.click(
 
 
 
 
 
 
 
 
 
 
224
  fn=upload_wrapper,
225
- inputs=[engine_state, repo_id_input],
226
  outputs=[upload_status]
 
 
 
 
 
 
227
  )
228
 
229
  return demo
 
1
  import gradio as gr
2
  from config import AppConfig
3
  from engine import FunctionGemmaEngine
4
+ from typing import Optional
5
 
6
  def build_interface() -> gr.Blocks:
7
 
8
  # --- State Management Wrappers ---
9
 
10
+ def init_session(profile: Optional[gr.OAuthProfile] = None):
11
  config = AppConfig()
12
  new_engine = FunctionGemmaEngine(config)
13
+
14
+ username = profile.username if profile else None
15
+ repo_update, push_update = update_hub_interactive(new_engine, username)
16
+
17
  return (
18
  new_engine,
19
  new_engine.get_tools_json(),
20
  new_engine.config.MODEL_NAME,
21
+ f"Ready. (Session {new_engine.session_id})",
22
+ repo_update, push_update, username
23
  )
24
 
25
  def run_training_wrapper(engine, epochs, lr, test_size, shuffle, model_name):
 
28
 
29
  def handle_reset(engine, model_name):
30
  engine.config.MODEL_NAME = model_name.strip()
31
+ return engine.refresh_model()
32
 
33
  def update_tools_wrapper(engine, json_val):
34
  return engine.update_tools(json_val)
 
57
  oauth_token=oauth_token.token,
58
  )
59
 
60
+ def update_repo_preview(username, repo_name):
61
+ """Updates the markdown preview to show 'username/repo_name'."""
62
+ if not username:
63
+ return "⚠️ Sign in to see the target repository path."
64
+
65
+ clean_repo = repo_name.strip() if repo_name else "..."
66
+ return f"Target Repository: **`{username}/{clean_repo}`**"
67
+
68
+ def update_hub_interactive(engine, username: Optional[str] = None):
69
+ is_logged_in = username is not None
70
+ has_model_tuned = engine is not None and engine.has_model_tuned
71
+
72
+ return gr.update(interactive=is_logged_in), gr.update(interactive=is_logged_in and has_model_tuned)
73
+
74
  # --- UI Layout ---
75
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
76
  engine_state = gr.State()
77
+ username_state = gr.State()
78
 
79
  with gr.Column():
80
  gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
81
  gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
82
+ gr.Markdown("(Optional) Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (3. Export).")
83
+ with gr.Row():
84
+ gr.LoginButton(value="Sign in with Hugging Face")
85
+ with gr.Column(scale=3):
86
+ gr.Markdown("")
87
 
88
  with gr.Tabs():
89
 
 
124
  choices=default_models,
125
  allow_custom_value=True,
126
  label="Base Model",
127
+ info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/functiongemma-270m-it')",
128
  interactive=True
129
  )
130
  param_epochs = gr.Slider(
 
170
  with gr.Column():
171
  gr.Markdown("#### Option A: Download ZIP")
172
  gr.Markdown("Download the model weights locally.")
173
+ zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
174
  download_file = gr.File(label="Download Archive", interactive=False)
175
 
176
  with gr.Column():
177
+ gr.Markdown("#### Option B: Save to Hugging Face Hub")
178
+ gr.Markdown("Publish your fine-tuned model to your personal Hugging Face account.")
179
 
180
+ repo_name_input = gr.Textbox(
181
+ label="TargetRepository Name",
182
+ value="my-functiongemma-v1",
183
+ placeholder="e.g., my-functiongemma-v1",
184
+ interactive=False
185
+ )
186
+ push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
187
+ repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
188
 
189
  upload_status = gr.Markdown("")
190
 
191
  # --- EVENT WIRING ---
192
+
193
+ action_buttons = [
194
+ clear_reload_btn,
195
+ run_training_btn,
196
+ zip_btn
197
+ ]
198
+
199
+ def set_interactivity(interactive: bool):
200
+ return [gr.update(interactive=interactive) for _ in action_buttons]
201
 
202
  demo.load(
203
+ fn=lambda: set_interactivity(False), outputs=action_buttons
204
+ ).then(
205
  fn=init_session,
206
  inputs=None,
207
+ outputs=[engine_state, tools_editor, param_model, output_display, repo_name_input, push_to_hub_btn, username_state]
208
+ ).then(
209
+ fn=update_repo_preview,
210
+ inputs=[username_state, repo_name_input],
211
+ outputs=[repo_id_preview]
212
+ ).then(
213
+ fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
214
  )
215
 
216
  update_tools_btn.click(
 
224
  inputs=[engine_state, import_file],
225
  outputs=[import_status]
226
  )
227
+
228
  run_training_btn.click(
229
  fn=lambda: (
230
  gr.update(visible=False),
231
  gr.update(interactive=False),
232
+ gr.update(interactive=False),
233
  gr.update(visible=True)
234
  ),
235
+ outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn]
236
  ).then(
237
  fn=run_training_wrapper,
238
  inputs=[engine_state, param_epochs, param_lr, param_test_size, param_shuffle, param_model],
 
241
  fn=lambda: (
242
  gr.update(visible=True),
243
  gr.update(interactive=True),
244
+ gr.update(interactive=True),
245
  gr.update(visible=False)
246
  ),
247
+ outputs=[run_training_btn, clear_reload_btn, zip_btn, stop_training_btn]
248
+ ).then(
249
+ fn=update_hub_interactive,
250
+ inputs=[engine_state, username_state],
251
+ outputs=[repo_name_input, push_to_hub_btn]
252
  )
253
 
254
  stop_training_btn.click(
 
258
  )
259
 
260
  clear_reload_btn.click(
261
+ fn=lambda: set_interactivity(False), outputs=action_buttons
262
+ ).then(
263
+ fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
264
+ ).then(
265
  fn=handle_reset,
266
  inputs=[engine_state, param_model],
267
  outputs=[output_display]
268
+ ).then(
269
+ fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
270
+ ).then(
271
+ fn=update_hub_interactive,
272
+ inputs=[engine_state, username_state],
273
+ outputs=[repo_name_input, push_to_hub_btn]
274
  )
275
 
276
  zip_btn.click(
277
+ fn=lambda: set_interactivity(False), outputs=action_buttons
278
+ ).then(
279
  fn=zip_wrapper,
280
  inputs=[engine_state],
281
  outputs=[download_file]
282
+ ).then(
283
+ fn=lambda: set_interactivity(True), outputs=action_buttons
284
  )
285
 
286
+ repo_name_input.change(
287
+ fn=update_repo_preview,
288
+ inputs=[username_state, repo_name_input],
289
+ outputs=[repo_id_preview]
290
+ )
291
+
292
+ push_to_hub_btn.click(
293
+ fn=lambda: set_interactivity(False), outputs=action_buttons
294
+ ).then(
295
+ fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
296
+ ).then(
297
  fn=upload_wrapper,
298
+ inputs=[engine_state, repo_name_input],
299
  outputs=[upload_status]
300
+ ).then(
301
+ fn=lambda: set_interactivity(True), outputs=action_buttons
302
+ ).then(
303
+ fn=update_hub_interactive,
304
+ inputs=[engine_state, username_state],
305
+ outputs=[repo_name_input, push_to_hub_btn]
306
  )
307
 
308
  return demo