bebechien committed on
Commit 6be610b · verified · 1 Parent(s): 7b994a4

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +5 -3
  2. app.py +311 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Functiongemma Modkit
+title: FunctionGemma Modkit
 emoji: 📊
 colorFrom: gray
 colorTo: indigo
@@ -8,7 +8,9 @@ sdk_version: 6.0.1
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: FunctionGemma Modkit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# FunctionGemma Modkit
+
+This project provides a set of tools to fine-tune FunctionGemma to understand your personal needs.
+
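The `app.py` added below saves the tuned checkpoint to `artifacts/functiongemma-270m-it-modkit-demo` and offers it as a ZIP download. A minimal sketch (not part of this commit) of loading that checkpoint and letting it pick a tool; the path assumes the default `OUTPUT_DIR` and the prompt is illustrative:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import get_json_schema

def search_knowledge_base(query: str) -> str:
    """
    Search internal company documents, policies and project data.

    Args:
        query: query string
    """
    return "Internal Result"

model_dir = "artifacts/functiongemma-270m-it-modkit-demo"  # default OUTPUT_DIR in app.py
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

messages = [
    {"role": "developer", "content": "You are a model that can do function calling with the following functions"},
    {"role": "user", "content": "What is our travel reimbursement policy?"},  # illustrative prompt
]
inputs = tokenizer.apply_chat_template(
    messages,
    tools=[get_json_schema(search_knowledge_base)],
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
out = model.generate(**inputs, max_new_tokens=128)
# Decode only the newly generated tokens, which should contain the tool call.
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
```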
app.py ADDED
@@ -0,0 +1,311 @@
+import gradio as gr
+import csv      # used by the CSV dataset importer
+import json
+import os
+import shutil   # used to zip the fine-tuned model for download
+import time
+import torch
+
+from typing import Final, Optional, List
+from pathlib import Path
+from huggingface_hub import login
+from trl import SFTConfig, SFTTrainer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from datasets import Dataset, load_dataset
+from transformers.utils import get_json_schema
+
+
+ARTIFACTS_DIR: Final[Path] = Path("artifacts")
+
+def authenticate_hf(token: Optional[str]) -> None:
+    """Logs into the Hugging Face Hub."""
+    if token:
+        print("Logging into Hugging Face Hub...")
+        login(token=token)
+    else:
+        print("Skipping Hugging Face login: HF_TOKEN not set.")
+
+def load_model(model_name: str):
+    print(f"Loading Transformer model: {model_name}")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        print("Model loaded successfully.")
+        return (model, tokenizer)
+    except Exception as e:
+        print(f"Error loading Transformer model {model_name}: {e}")
+        raise
+
+# --- Tool Definitions ---
+def search_knowledge_base(query: str) -> str:
+    """
+    Search internal company documents, policies and project data.
+
+    Args:
+        query: query string
+    """
+    return "Internal Result"
+
+def search_google(query: str) -> str:
+    """
+    Search public information.
+
+    Args:
+        query: query string
+    """
+    return "Public Result"
+
+
+TOOLS = [get_json_schema(search_knowledge_base), get_json_schema(search_google)]
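+# get_json_schema() derives a JSON schema for each tool from its type hints and
+# docstring, roughly {"type": "function", "function": {"name": ..., "description": ...,
+# "parameters": {...}}}; the chat template serializes these schemas into the prompt.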
+
+DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
+
+
+def create_conversation(sample):
+    return {
+        "messages": [
+            {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
+            {"role": "user", "content": sample["user_content"]},
+            {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": sample["tool_name"], "arguments": json.loads(sample["tool_arguments"])}}]},
+        ],
+        "tools": TOOLS
+    }
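+# For example (illustrative values), the sample
+#   {"user_content": "Find our leave policy",
+#    "tool_name": "search_knowledge_base",
+#    "tool_arguments": "{\"query\": \"leave policy\"}"}
+# becomes a developer/user/assistant conversation whose assistant turn is the
+# expected tool call with JSON-decoded arguments.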
+
+
+def train_with_dataset(
+    model: AutoModelForCausalLM,
+    tokenizer: AutoTokenizer,
+    dataset: Dataset,
+    output_dir: Path,
+    learning_rate: float = 5e-5
+) -> None:
+
+    torch_dtype = model.dtype
+
+    args = SFTConfig(
+        output_dir=output_dir,                  # directory to save and repository id
+        max_length=512,                         # max sequence length for model and packing of the dataset
+        packing=False,                          # groups multiple samples in the dataset into a single sequence
+        num_train_epochs=5,                     # number of training epochs
+        per_device_train_batch_size=4,          # batch size per device during training
+        gradient_checkpointing=False,           # caching is incompatible with gradient checkpointing
+        optim="adamw_torch_fused",              # use fused adamw optimizer
+        logging_steps=1,                        # log every step
+        # save_strategy="epoch",                # save checkpoint every epoch
+        eval_strategy="epoch",                  # evaluate checkpoint every epoch
+        learning_rate=learning_rate,            # learning rate
+        fp16=(torch_dtype == torch.float16),    # use float16 precision if the model is fp16
+        bf16=(torch_dtype == torch.bfloat16),   # use bfloat16 precision if the model is bf16
+        lr_scheduler_type="constant",           # use constant learning rate scheduler
+        push_to_hub=False,                      # do not push the model to the Hub
+        report_to="none",                       # do not report metrics to any tracker
+        dataset_kwargs={
+            "add_special_tokens": False,        # template already adds special tokens
+            "append_concat_token": True,        # add EOS token as separator token between examples
+        }
+    )
+
+    # Create Trainer object
+    trainer = SFTTrainer(
+        model=model,
+        args=args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["test"],
+        processing_class=tokenizer,
+    )
+
+    trainer.train()
+
+    print("Training finished. Model weights are updated in memory.")
+
+    # Save the final fine-tuned model
+    trainer.save_model()
+
+    print(f"Model saved locally to: {output_dir}")
+
+class AppConfig:
+    """
+    Central configuration class for the Fine-Tuner application.
+    """
+    ARTIFACTS_DIR: Final[Path] = ARTIFACTS_DIR
+    HF_TOKEN: Final[str | None] = os.getenv('HF_TOKEN')
+    MODEL_NAME: Final[str] = '../hf/270m'
+    DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
+    OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("functiongemma-270m-it-modkit-demo")
+
+
+class FunctionGemmaTuner:
+    def __init__(self, config: AppConfig = AppConfig):
+        self.config = config
+
+        os.makedirs(self.config.ARTIFACTS_DIR, exist_ok=True)
+        print(f"Created artifact directory: {self.config.ARTIFACTS_DIR}")
+
+        authenticate_hf(self.config.HF_TOKEN)
+
+        self._initial_load()
+
+    def _initial_load(self):
+        """Helper to run the refresh function once at startup."""
+        print("--- Running Initial Data Load ---")
+        self.refresh_data_and_model()
+        print("--- Initial Load Complete ---")
+
+    def refresh_data_and_model(self):
+        print("\n" + "=" * 50)
+        print("RELOADING MODEL and RE-FETCHING DATA")
+
+        # Reset dataset state
+        self.imported_dataset = []
+
+        # 1. Reload the base model
+        try:
+            self.model, self.tokenizer = load_model(self.config.MODEL_NAME)
+        except Exception as e:
+            # gr.Error is an exception and only shows when raised; use gr.Warning
+            # so the toast appears while execution continues.
+            gr.Warning(f"Model load failed: {e}")
+            self.model = None
+            self.tokenizer = None
+            return gr.update(value=f"CRITICAL ERROR: Model failed to load. {e}")
+
+        status_value: str = "Model and data reloaded. Click 'Run Fine-Tuning' to begin."
+
+        # Return a Gradio update for the status Textbox
+        return gr.update(value=status_value)
+
+    # --- Import Dataset/Export ---
+    def import_additional_dataset(self, file_path: str) -> str:
+        if not file_path:
+            return "Please upload a CSV file."
+        new_dataset, num_imported = [], 0
+        try:
+            with open(file_path, 'r', newline='', encoding='utf-8') as f:
+                reader = csv.reader(f)
+                try:
+                    # Skip an optional header row; otherwise rewind and parse from the top.
+                    header = next(reader)
+                    if not (header and header[0].lower().strip() == 'user_content'):
+                        f.seek(0)
+                except StopIteration:
+                    return "Error: Uploaded file is empty."
+
+                for row in reader:
+                    if len(row) == 3:
+                        new_dataset.append([s.strip() for s in row])
+                        num_imported += 1
+            if num_imported == 0:
+                raise ValueError("No valid [user_content, tool_name, tool_arguments] rows found in the CSV.")
+            self.imported_dataset = new_dataset
+            return f"Successfully imported {num_imported} additional training examples."
+        except Exception as e:
+            gr.Warning(f"Import failed. Ensure the CSV format is: [user_content, tool_name, tool_arguments]. Error: {e}")
+            return "Import failed. Check console for details."
+
+    def download_model(self) -> Optional[str]:
+        if not os.path.exists(self.config.OUTPUT_DIR):
+            gr.Warning(f"The model directory '{self.config.OUTPUT_DIR}' does not exist. Please run training first.")
+            return None
+        timestamp = int(time.time())
+        try:
+            base_name = os.path.join(self.config.ARTIFACTS_DIR, f"functiongemma_finetuned_{timestamp}")
+            archive_path = shutil.make_archive(
+                base_name=base_name,
+                format='zip',
+                root_dir=self.config.OUTPUT_DIR,
+            )
+            gr.Info(f"Model files successfully zipped to: {archive_path}")
+            return archive_path
+        except Exception as e:
+            # gr.Error must be raised to display; warn and return None instead.
+            gr.Warning(f"Failed to create the model ZIP file. Error: {e}")
+            return None
+
+    def training(self, test_size: float = 0.5) -> str:
+        """
+        Generates a training dataset from the user's selection and runs the fine-tuning process.
+        """
+        if self.model is None:
+            raise gr.Error("Training failed: Model is not loaded.")
+
+        if not self.imported_dataset:
+            print("No imported dataset, using the default.")
+            dataset = load_dataset(self.config.DEFAULT_DATASET, split="train")
+        else:
+            dataset_as_dicts = [
+                {"user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
+                for row in self.imported_dataset
+            ]
+            dataset = Dataset.from_list(dataset_as_dicts)
+
+        dataset = dataset.map(create_conversation, batched=False)
+        dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
+        print(dataset)
+        print("--- dataset input ---")
+        print(json.dumps(dataset["train"][0], indent=2))
+        debug_msg = self.tokenizer.apply_chat_template(dataset["train"][0]["messages"], tools=dataset["train"][0]["tools"], add_generation_prompt=False, tokenize=False)
+        print("--- Formatted prompt ---")
+        print(debug_msg)
+
+        result = "### Success Rate (Before Training):\n" + f"{self.check_success_rate(dataset['test'])}\n\n"
+        print("-" * 50 + "\nStarting Fine-tuning...")
+        train_with_dataset(model=self.model, tokenizer=self.tokenizer, dataset=dataset, output_dir=self.config.OUTPUT_DIR)
+        print("Fine-tuning Complete.\n" + "-" * 50)
+
+        result += "### Success Rate (After Training):\n" + f"{self.check_success_rate(dataset['test'])}\n\n"
+        return result
+
+    def check_success_rate(self, test_dataset):
+        result = []
+        success_count = 0
+        for idx, item in enumerate(test_dataset):
+            # Re-prompt with only the developer and user turns; the assistant
+            # turn holds the expected tool call.
+            messages = [
+                item["messages"][0],
+                item["messages"][1],
+            ]
+
+            inputs = self.tokenizer.apply_chat_template(messages, tools=TOOLS, add_generation_prompt=True, return_dict=True, return_tensors="pt")
+
+            out = self.model.generate(**inputs.to(self.model.device), pad_token_id=self.tokenizer.eos_token_id, max_new_tokens=128)
+            # Decode only the newly generated tokens, skipping the prompt.
+            output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=False)
+
+            result.append(f"{idx+1} Prompt: {item['messages'][1]['content']}")
+            result.append(f" Output: {output}")
+            # Count a success when the expected tool name appears in the output.
+            if item['messages'][2]['tool_calls'][0]['function']['name'] in output:
+                result.append(" `-> ✅ correct!")
+                success_count += 1
+            else:
+                result.append(" `-> ❌ wrong tool")
+
+        result.append(f"Success : {success_count} / {len(test_dataset)}")
+
+        return "\n".join(result)
+
+    def build_interface(self) -> gr.Blocks:
+        with gr.Blocks(title="FunctionGemma Modkit") as demo:
+            gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
+            gr.Markdown("This project provides a set of tools to fine-tune FunctionGemma to understand your personal needs.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
+            self._build_training_interface()
+        return demo
+
+    def _build_training_interface(self):
+        with gr.Column():
+            gr.Markdown("## Fine-Tuning")
+            with gr.Row():
+                output = gr.Textbox(lines=14, label="Training and Search Results", value="Click 'Run Fine-Tuning' to begin.")
+            with gr.Row():
+                clear_reload_btn = gr.Button("Clear & Reload Model/Data")
+                run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
+            gr.Markdown("--- \n ## Dataset & Model Management")
+            import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
+            with gr.Row():
+                download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
+                download_status = gr.Markdown("Ready.")
+            with gr.Row():
+                model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
+
+            run_training_btn.click(fn=self.training, outputs=output)
+            clear_reload_btn.click(fn=self.refresh_data_and_model, inputs=None, outputs=[output], queue=False)
+            import_file.change(fn=self.import_additional_dataset, inputs=[import_file], outputs=download_status)
+            # Chain: reset the file widget, zip the model, then reveal the download.
+            download_model_btn.click(
+                lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
+            ).then(
+                self.download_model, None, model_output
+            ).then(
+                lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
+            )
+
+
+if __name__ == "__main__":
+    app = FunctionGemmaTuner(AppConfig)
+    demo = app.build_interface()
+    print("Starting Gradio App...")
+    demo.launch()
+
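For reference, `import_additional_dataset` above expects three columns per CSV row: `user_content`, `tool_name`, and `tool_arguments` (a JSON string). A minimal sketch of writing a compatible file; the file name and rows are illustrative:

```python
import csv

# Hypothetical training rows: user request, expected tool, JSON-encoded arguments.
rows = [
    ["What is our travel reimbursement policy?", "search_knowledge_base", '{"query": "travel reimbursement policy"}'],
    ["When is the next solar eclipse?", "search_google", '{"query": "next solar eclipse"}'],
]
with open("my_tool_calls.csv", "w", newline="", encoding="utf-8") as f:
    csv.writer(f).writerows(rows)
```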
requirements.txt ADDED
@@ -0,0 +1,5 @@
+accelerate
+datasets
+gradio
+transformers
+trl