bebechien committed on
Commit
21257b4
·
verified ·
1 Parent(s): 7e41311

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. app.py +4 -10
  2. engine.py +39 -18
  3. requirements.txt +1 -0
  4. ui.py +101 -44
app.py CHANGED
@@ -1,15 +1,9 @@
1
- from config import AppConfig
2
- from engine import FunctionGemmaEngine
3
  from ui import build_interface
4
 
5
  if __name__ == "__main__":
6
- # Initialize Config
7
- config = AppConfig()
8
-
9
- # Initialize Logic Engine
10
- app_engine = FunctionGemmaEngine(config)
11
-
12
  # Build and Launch UI
13
- demo = build_interface(app_engine)
14
- print("Starting Gradio App...")
 
 
15
  demo.launch()
 
 
 
1
  from ui import build_interface
2
 
3
  if __name__ == "__main__":
 
 
 
 
 
 
4
  # Build and Launch UI
5
+ # Note: Engine creation is now handled per-session inside build_interface
6
+ demo = build_interface()
7
+ print("Starting Gradio App with Multi-User Support...")
8
+ demo.queue() # Enable queueing for concurrent request handling
9
  demo.launch()
engine.py CHANGED
@@ -3,12 +3,14 @@ import torch
3
  import time
4
  import json
5
  import queue
 
6
  import matplotlib.pyplot as plt
7
  from functools import partial
8
  from typing import Generator, Optional, List, Dict, Any, Tuple
9
  from datasets import Dataset, load_dataset
10
  from trl import SFTConfig, SFTTrainer
11
  from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
 
12
 
13
  from config import AppConfig
14
  from tools import DEFAULT_TOOLS
@@ -29,10 +31,6 @@ class AbortCallback(TrainerCallback):
29
  control.should_training_stop = True
30
 
31
  class LogStreamingCallback(TrainerCallback):
32
- """
33
- Intercepts training logs and pushes them to a queue.
34
- Sends tuple: (formatted_string, raw_dict)
35
- """
36
  def __init__(self, log_queue: queue.Queue):
37
  self.log_queue = log_queue
38
 
@@ -63,7 +61,6 @@ class LogStreamingCallback(TrainerCallback):
63
 
64
  log_parts.append(f"{label}: {val_str}")
65
 
66
- # Structure for plotting
67
  log_payload = logs.copy()
68
  log_payload['step'] = state.global_step
69
 
@@ -72,6 +69,11 @@ class LogStreamingCallback(TrainerCallback):
72
  class FunctionGemmaEngine:
73
  def __init__(self, config: AppConfig):
74
  self.config = config
 
 
 
 
 
75
  self.model = None
76
  self.tokenizer = None
77
  self.loaded_model_name = None
@@ -104,17 +106,15 @@ class FunctionGemmaEngine:
104
  # --- Model & Data Management ---
105
 
106
  def _load_model_weights(self):
107
- """Internal helper to load model based on current config."""
108
- print(f"Loading model: {self.config.MODEL_NAME}...")
109
  self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
110
  self.loaded_model_name = self.config.MODEL_NAME
111
 
112
  def refresh_data_and_model(self) -> str:
113
- """Full reset: Reloads model and clears dataset."""
114
  self.imported_dataset = []
115
  try:
116
  self._load_model_weights()
117
- return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady."
118
  except Exception as e:
119
  self.model = None
120
  self.tokenizer = None
@@ -141,7 +141,6 @@ class FunctionGemmaEngine:
141
  output_buffer = ""
142
  last_plot = None
143
 
144
- # 1. Check if model name changed since last load
145
  if self.config.MODEL_NAME != self.loaded_model_name:
146
  output_buffer += f"🔄 Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n"
147
  yield output_buffer, None
@@ -231,7 +230,6 @@ class FunctionGemmaEngine:
231
 
232
  train_thread.join()
233
 
234
- # Flush logs
235
  while not log_queue.empty():
236
  payload = log_queue.get()
237
  if isinstance(payload, tuple):
@@ -287,7 +285,7 @@ class FunctionGemmaEngine:
287
  def _execute_trainer(self, dataset, log_queue: queue.Queue, epochs: int, learning_rate: float) -> List[Dict]:
288
  torch_dtype = self.model.dtype
289
  args = SFTConfig(
290
- output_dir=str(self.config.OUTPUT_DIR),
291
  max_length=512,
292
  packing=False,
293
  num_train_epochs=epochs,
@@ -319,8 +317,6 @@ class FunctionGemmaEngine:
319
 
320
  def _generate_loss_plot(self, history: list):
321
  if not history: return None
322
-
323
- # CHANGED: Close previous figures to prevent memory warning
324
  plt.close('all')
325
 
326
  train_steps = [x['step'] for x in history if 'loss' in x]
@@ -372,7 +368,32 @@ class FunctionGemmaEngine:
372
  yield f"Error during inference: {e}"
373
 
374
  def get_zip_path(self) -> Optional[str]:
375
- if not self.config.OUTPUT_DIR.exists(): return None
376
- timestamp = int(time.time())
377
- base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}"))
378
- return zip_directory(str(self.config.OUTPUT_DIR), base_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import time
4
  import json
5
  import queue
6
+ import uuid
7
  import matplotlib.pyplot as plt
8
  from functools import partial
9
  from typing import Generator, Optional, List, Dict, Any, Tuple
10
  from datasets import Dataset, load_dataset
11
  from trl import SFTConfig, SFTTrainer
12
  from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
13
+ from huggingface_hub import HfApi # Added for Hub Upload
14
 
15
  from config import AppConfig
16
  from tools import DEFAULT_TOOLS
 
31
  control.should_training_stop = True
32
 
33
  class LogStreamingCallback(TrainerCallback):
 
 
 
 
34
  def __init__(self, log_queue: queue.Queue):
35
  self.log_queue = log_queue
36
 
 
61
 
62
  log_parts.append(f"{label}: {val_str}")
63
 
 
64
  log_payload = logs.copy()
65
  log_payload['step'] = state.global_step
66
 
 
69
  class FunctionGemmaEngine:
70
  def __init__(self, config: AppConfig):
71
  self.config = config
72
+
73
+ self.session_id = str(uuid.uuid4())[:8]
74
+ self.output_dir = self.config.ARTIFACTS_DIR.joinpath(f"session_{self.session_id}")
75
+ self.output_dir.mkdir(parents=True, exist_ok=True)
76
+
77
  self.model = None
78
  self.tokenizer = None
79
  self.loaded_model_name = None
 
106
  # --- Model & Data Management ---
107
 
108
  def _load_model_weights(self):
109
+ print(f"[{self.session_id}] Loading model: {self.config.MODEL_NAME}...")
 
110
  self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
111
  self.loaded_model_name = self.config.MODEL_NAME
112
 
113
  def refresh_data_and_model(self) -> str:
 
114
  self.imported_dataset = []
115
  try:
116
  self._load_model_weights()
117
+ return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady (Session {self.session_id})."
118
  except Exception as e:
119
  self.model = None
120
  self.tokenizer = None
 
141
  output_buffer = ""
142
  last_plot = None
143
 
 
144
  if self.config.MODEL_NAME != self.loaded_model_name:
145
  output_buffer += f"🔄 Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n"
146
  yield output_buffer, None
 
230
 
231
  train_thread.join()
232
 
 
233
  while not log_queue.empty():
234
  payload = log_queue.get()
235
  if isinstance(payload, tuple):
 
285
  def _execute_trainer(self, dataset, log_queue: queue.Queue, epochs: int, learning_rate: float) -> List[Dict]:
286
  torch_dtype = self.model.dtype
287
  args = SFTConfig(
288
+ output_dir=str(self.output_dir),
289
  max_length=512,
290
  packing=False,
291
  num_train_epochs=epochs,
 
317
 
318
  def _generate_loss_plot(self, history: list):
319
  if not history: return None
 
 
320
  plt.close('all')
321
 
322
  train_steps = [x['step'] for x in history if 'loss' in x]
 
368
  yield f"Error during inference: {e}"
369
 
370
  def get_zip_path(self) -> Optional[str]:
371
+ if not self.output_dir.exists(): return None
372
+ base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{self.session_id}"))
373
+ return zip_directory(str(self.output_dir), base_name)
374
+
375
+ def upload_model_to_hub(self, repo_name: str, oauth_token: str) -> str:
376
+ """Uploads the trained model to Hugging Face Hub."""
377
+ if not self.output_dir.exists() or not any(self.output_dir.iterdir()):
378
+ return "❌ No trained model found in current session. Run training first."
379
+
380
+ try:
381
+ api = HfApi(token=oauth_token)
382
+
383
+ # Create Repo (if needed)
384
+ print(f"Creating/Checking repo {repo_name}...")
385
+ repo_url = api.create_repo(
386
+ repo_id=repo_name,
387
+ exist_ok=True
388
+ )
389
+
390
+ # Upload
391
+ print(f"Uploading to {repo_url.repo_id}...")
392
+ api.upload_folder(
393
+ folder_path=str(self.output_dir),
394
+ repo_id=repo_url.repo_id,
395
+ repo_type="model"
396
+ )
397
+ return f"✅ Success! Model uploaded to: {repo_url}"
398
+ except Exception as e:
399
+ return f"❌ Upload failed: {str(e)}"
requirements.txt CHANGED
@@ -2,5 +2,6 @@ accelerate
2
  datasets
3
  gradio
4
  matplotlib
 
5
  transformers
6
  trl
 
2
  datasets
3
  gradio
4
  matplotlib
5
+ oauth
6
  transformers
7
  trl
ui.py CHANGED
@@ -1,21 +1,64 @@
1
  import gradio as gr
 
2
  from engine import FunctionGemmaEngine
3
 
4
- def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
5
 
6
- # Wrapper: Update config with selected model, then Run Training
7
- def run_training_wrapper(epochs, lr, test_size, shuffle, model_name):
 
 
 
 
 
 
 
 
 
 
 
8
  engine.config.MODEL_NAME = model_name.strip()
9
  yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)
10
 
11
- # Wrapper: Update config with selected model, then Reset/Reload
12
- def handle_reset(model_name):
13
  engine.config.MODEL_NAME = model_name.strip()
14
  return engine.refresh_data_and_model()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
17
- gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
18
- gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
 
 
 
 
19
 
20
  with gr.Tabs():
21
 
@@ -27,7 +70,6 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
27
  with gr.Column(scale=1):
28
  gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
29
  tools_editor = gr.Code(
30
- value=engine.get_tools_json(),
31
  language="json",
32
  label="Tool Definitions (JSON Schema)",
33
  lines=15
@@ -52,10 +94,9 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
52
  with gr.Group():
53
  gr.Markdown("**Hyperparameters**")
54
  with gr.Row():
55
- # Dropdown that allows custom typing
56
  param_model = gr.Dropdown(
57
- choices=engine.config.AVAILABLE_MODELS,
58
- value=engine.config.MODEL_NAME,
59
  allow_custom_value=True,
60
  label="Base Model",
61
  info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/gemma-3-1b-it')",
@@ -82,91 +123,107 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
82
  )
83
 
84
  with gr.Row():
85
- run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=2)
86
  stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
87
  clear_reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)
88
 
89
  with gr.Row():
90
- # Left column: Text Logs
91
  output_display = gr.Textbox(
92
  lines=20,
93
  label="Logs & Results",
94
- value="Ready.",
95
  interactive=False,
96
  autoscroll=True
97
  )
98
- # Right column: Plot
99
  loss_plot = gr.Plot(label="Training Metrics")
100
 
101
  # --- TAB 3: EXPORT ---
102
  with gr.TabItem("3. Export"):
103
  gr.Markdown("### 📦 Export Trained Model")
104
- gr.Markdown("Download the fine-tuned LoRA adapters or full model weights (depending on configuration) as a ZIP file.")
105
 
106
  with gr.Row():
107
- zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="primary", scale=1)
108
- download_file = gr.File(label="Download Archive", interactive=False, scale=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # --- EVENT WIRING ---
111
 
112
- # Tab 1: Tools
 
 
 
 
 
113
  update_tools_btn.click(
114
- fn=engine.update_tools,
115
- inputs=[tools_editor],
116
  outputs=[tools_status]
117
  )
118
 
119
- # Tab 1: File Import
120
  import_file.upload(
121
- fn=engine.load_csv,
122
- inputs=[import_file],
123
  outputs=[import_status]
124
  )
125
 
126
- # Tab 2: Training (Uses Wrapper)
127
  run_training_btn.click(
128
  fn=lambda: (
129
- gr.update(visible=False), # Hide Run
130
- gr.update(interactive=False), # Disable Reset
131
- gr.update(visible=True) # Show Stop
132
  ),
133
  outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
134
  ).then(
135
  fn=run_training_wrapper,
136
- inputs=[param_epochs, param_lr, param_test_size, param_shuffle, param_model],
137
  outputs=[output_display, loss_plot],
138
  ).then(
139
  fn=lambda: (
140
- gr.update(visible=True), # Show Run
141
- gr.update(interactive=True), # Enable Reset
142
- gr.update(visible=False) # Hide Stop
143
  ),
144
  outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
145
  )
146
 
147
- # Tab 2: Stop
148
  stop_training_btn.click(
149
- fn=lambda: (engine.trigger_stop(), "Stopping...")[1],
 
150
  outputs=None
151
  )
152
 
153
- # Tab 2: Reset (Uses Wrapper to capture model name)
154
  clear_reload_btn.click(
155
  fn=handle_reset,
156
- inputs=[param_model],
157
  outputs=[output_display]
158
  )
159
 
160
- # Tab 3: Download
161
- def handle_zip():
162
- path = engine.get_zip_path()
163
- if path:
164
- return gr.update(value=path, visible=True)
165
- return gr.update(value=None, visible=False)
166
-
167
  zip_btn.click(
168
- fn=handle_zip,
 
169
  outputs=[download_file]
170
  )
 
 
 
 
 
 
171
 
172
  return demo
 
1
  import gradio as gr
2
+ from config import AppConfig
3
  from engine import FunctionGemmaEngine
4
 
5
+ def build_interface() -> gr.Blocks:
6
 
7
+ # --- State Management Wrappers ---
8
+
9
+ def init_session():
10
+ config = AppConfig()
11
+ new_engine = FunctionGemmaEngine(config)
12
+ return (
13
+ new_engine,
14
+ new_engine.get_tools_json(),
15
+ new_engine.config.MODEL_NAME,
16
+ f"Ready. (Session {new_engine.session_id})"
17
+ )
18
+
19
+ def run_training_wrapper(engine, epochs, lr, test_size, shuffle, model_name):
20
  engine.config.MODEL_NAME = model_name.strip()
21
  yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)
22
 
23
+ def handle_reset(engine, model_name):
 
24
  engine.config.MODEL_NAME = model_name.strip()
25
  return engine.refresh_data_and_model()
26
 
27
+ def update_tools_wrapper(engine, json_val):
28
+ return engine.update_tools(json_val)
29
+
30
+ def import_file_wrapper(engine, file_obj):
31
+ return engine.load_csv(file_obj)
32
+
33
+ def stop_wrapper(engine):
34
+ engine.trigger_stop()
35
+ return "Stopping..."
36
+
37
+ def zip_wrapper(engine):
38
+ path = engine.get_zip_path()
39
+ if path:
40
+ return gr.update(value=path, visible=True)
41
+ return gr.update(value=None, visible=False)
42
+
43
+ def upload_wrapper(engine, repo_name, oauth_token: gr.OAuthToken | None):
44
+ if oauth_token is None:
45
+ return "❌ Error: You must log in (top right) to upload models."
46
+ if not repo_name:
47
+ return "❌ Error: Please enter a repository name."
48
+
49
+ return engine.upload_model_to_hub(
50
+ repo_name=repo_name,
51
+ oauth_token=oauth_token.token,
52
+ )
53
+
54
+ # --- UI Layout ---
55
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
56
+ engine_state = gr.State()
57
+
58
+ with gr.Column():
59
+ gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
60
+ gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
61
+ gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")
62
 
63
  with gr.Tabs():
64
 
 
70
  with gr.Column(scale=1):
71
  gr.Markdown("**Step 1: Define Functions**<br>Edit the JSON schema below to define the tools the model should learn.")
72
  tools_editor = gr.Code(
 
73
  language="json",
74
  label="Tool Definitions (JSON Schema)",
75
  lines=15
 
94
  with gr.Group():
95
  gr.Markdown("**Hyperparameters**")
96
  with gr.Row():
97
+ default_models = AppConfig().AVAILABLE_MODELS
98
  param_model = gr.Dropdown(
99
+ choices=default_models,
 
100
  allow_custom_value=True,
101
  label="Base Model",
102
  info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/gemma-3-1b-it')",
 
123
  )
124
 
125
  with gr.Row():
126
+ run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
127
  stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
128
  clear_reload_btn = gr.Button("🔄 Reload Model & Reset Data", variant="secondary", scale=1)
129
 
130
  with gr.Row():
 
131
  output_display = gr.Textbox(
132
  lines=20,
133
  label="Logs & Results",
134
+ value="Initializing...",
135
  interactive=False,
136
  autoscroll=True
137
  )
 
138
  loss_plot = gr.Plot(label="Training Metrics")
139
 
140
  # --- TAB 3: EXPORT ---
141
  with gr.TabItem("3. Export"):
142
  gr.Markdown("### 📦 Export Trained Model")
 
143
 
144
  with gr.Row():
145
+ with gr.Column():
146
+ gr.Markdown("#### Option A: Download ZIP")
147
+ gr.Markdown("Download the model weights locally.")
148
+ zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary")
149
+ download_file = gr.File(label="Download Archive", interactive=False)
150
+
151
+ with gr.Column():
152
+ gr.Markdown("#### Option B: Upload to Hugging Face Hub")
153
+ gr.Markdown("Publish to your HF profile. **Requires Login**.")
154
+
155
+ with gr.Group():
156
+ repo_id_input = gr.Textbox(
157
+ label="Repository Name",
158
+ placeholder="my-function-gemma-v1",
159
+ info="Will be created under your username (e.g. user/repo)"
160
+ )
161
+ upload_hub_btn = gr.Button("☁️ Upload to Hub", variant="primary")
162
+
163
+ upload_status = gr.Markdown("")
164
 
165
  # --- EVENT WIRING ---
166
 
167
+ demo.load(
168
+ fn=init_session,
169
+ inputs=None,
170
+ outputs=[engine_state, tools_editor, param_model, output_display]
171
+ )
172
+
173
  update_tools_btn.click(
174
+ fn=update_tools_wrapper,
175
+ inputs=[engine_state, tools_editor],
176
  outputs=[tools_status]
177
  )
178
 
 
179
  import_file.upload(
180
+ fn=import_file_wrapper,
181
+ inputs=[engine_state, import_file],
182
  outputs=[import_status]
183
  )
184
 
 
185
  run_training_btn.click(
186
  fn=lambda: (
187
+ gr.update(visible=False),
188
+ gr.update(interactive=False),
189
+ gr.update(visible=True)
190
  ),
191
  outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
192
  ).then(
193
  fn=run_training_wrapper,
194
+ inputs=[engine_state, param_epochs, param_lr, param_test_size, param_shuffle, param_model],
195
  outputs=[output_display, loss_plot],
196
  ).then(
197
  fn=lambda: (
198
+ gr.update(visible=True),
199
+ gr.update(interactive=True),
200
+ gr.update(visible=False)
201
  ),
202
  outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
203
  )
204
 
 
205
  stop_training_btn.click(
206
+ fn=stop_wrapper,
207
+ inputs=[engine_state],
208
  outputs=None
209
  )
210
 
 
211
  clear_reload_btn.click(
212
  fn=handle_reset,
213
+ inputs=[engine_state, param_model],
214
  outputs=[output_display]
215
  )
216
 
 
 
 
 
 
 
 
217
  zip_btn.click(
218
+ fn=zip_wrapper,
219
+ inputs=[engine_state],
220
  outputs=[download_file]
221
  )
222
+
223
+ upload_hub_btn.click(
224
+ fn=upload_wrapper,
225
+ inputs=[engine_state, repo_id_input],
226
+ outputs=[upload_status]
227
+ )
228
 
229
  return demo