# Hugging Face Space by bebechien — "Upload folder using huggingface_hub"
# (commit c055e6e, ~19.3 kB). Web-page chrome converted to a comment header
# so the file is valid Python.
import gradio as gr
import os
import json
import torch
import csv
import shutil
import time
import threading
from typing import Final, Optional, List, Any, Generator
from pathlib import Path
from dataclasses import dataclass
from huggingface_hub import login
from trl import SFTConfig, SFTTrainer
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainerCallback,
TrainingArguments,
TrainerControl,
TrainerState
)
from datasets import Dataset, load_dataset
# --- Configuration ---
class AppConfig:
    """Central application settings: paths, model id, dataset id, credentials."""

    # Every generated artifact (checkpoints, export zips) lives under here.
    ARTIFACTS_DIR: Final[Path] = Path("artifacts")
    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

    # Hugging Face access token from the environment; None when unset.
    HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')

    # Local checkpoint directory; loader falls back to a hub id if missing.
    MODEL_NAME: Final[str] = '../hf/270m'
    DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
    OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR / "functiongemma-modkit-demo"
# --- Tool Definitions ---
def search_knowledge_base(query: str) -> str:
    """
    Search internal company documents, policies and project data.

    Args:
        query: query string

    Returns:
        A placeholder result string (stub implementation for the demo).
    """
    # Fixed typo in the stub result: "Interal" -> "Internal".
    return "Internal Result"
def search_google(query: str) -> str:
    """
    Search public information.

    Args:
        query: query string
    """
    # Stub implementation: always returns a fixed placeholder result.
    placeholder = "Public Result"
    return placeholder
# JSON-schema style declaration of `search_knowledge_base`, advertised to the
# model through the chat template so it can emit matching tool calls.
search_knowledge_base_schema = {
    "type": "function",
    "function": {
        "name": "search_knowledge_base",
        "description": "Search internal company documents, policies and project data.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "query string"},
            },
            "required": ["query"],
        },
        "return": {"type": "string"},
    },
}
# JSON-schema style declaration of `search_google`, mirroring the structure of
# `search_knowledge_base_schema` for the public-search tool.
search_google_schema = {
    "type": "function",
    "function": {
        "name": "search_google",
        "description": "Search public information.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "query string"},
            },
            "required": ["query"],
        },
        "return": {"type": "string"},
    },
}
# Tool schemas handed to the chat template / evaluation, and the default
# developer/system prompt prepended to every training conversation.
TOOLS = [search_knowledge_base_schema, search_google_schema]
DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
# --- Callbacks ---
class AbortCallback(TrainerCallback):
    """Trainer callback that halts training when a shared threading.Event is set."""

    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Poll the shared event once per optimizer step; cheap no-op otherwise.
        if not self.stop_event.is_set():
            return
        print("πŸ›‘ Stop signal received. Stopping training...")
        control.should_training_stop = True
# --- Helper Functions ---
def authenticate_hf(token: Optional[str]) -> None:
    """Logs into the Hugging Face Hub."""
    # Missing/empty token is a supported configuration: just skip login.
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
def load_model_and_tokenizer(model_name: str):
    """Load a causal LM and its tokenizer from a local path or the HF Hub.

    Args:
        model_name: Local checkpoint directory (relative paths starting with
            "..") or a Hugging Face Hub model id.

    Returns:
        Tuple of (model, tokenizer).

    Raises:
        Exception: Re-raises whatever ``from_pretrained`` failed with,
            after logging it.
    """
    print(f"Loading Transformer model: {model_name}")
    try:
        # Check if local path exists, otherwise treat as HF Hub ID
        if model_name.startswith("..") and not os.path.exists(model_name):
            print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
            model_name = "google/gemma-2b-it"  # Fallback example
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        print("Model loaded successfully.")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading Transformer model {model_name}: {e}")
        # Bare `raise` preserves the original traceback; `raise e` restarted it here.
        raise
def create_conversation_format(sample):
    """Formats a dataset row into the conversational format required for SFT."""
    # Arguments arrive as a JSON string; malformed or non-string input
    # degrades to an empty argument mapping rather than failing the map().
    try:
        parsed_args = json.loads(sample["tool_arguments"])
    except (json.JSONDecodeError, TypeError):
        parsed_args = {}
    assistant_call = {
        "type": "function",
        "function": {"name": sample["tool_name"], "arguments": parsed_args},
    }
    return {
        "messages": [
            {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
            {"role": "user", "content": sample["user_content"]},
            {"role": "assistant", "tool_calls": [assistant_call]},
        ],
        "tools": TOOLS,
    }
# --- Main Application Logic ---
class FunctionGemmaTuner:
    """Gradio demo app that fine-tunes FunctionGemma for tool calling.

    Owns the model/tokenizer pair, an optional user-imported CSV dataset, and
    a threading.Event used to cooperatively interrupt the evaluation loops and
    the SFT training run.
    """

    def __init__(self, config: AppConfig = AppConfig):
        # NOTE(review): `config` is used as the AppConfig *class* itself
        # (class attributes are read); an instance would also work.
        self.config = config
        self.model = None       # populated by refresh_data_and_model()
        self.tokenizer = None   # populated by refresh_data_and_model()
        # Rows imported from CSV: [user_content, tool_name, tool_arguments].
        self.imported_dataset = []
        # Threading event to control stopping
        self.stop_event = threading.Event()
        authenticate_hf(self.config.HF_TOKEN)
        # Initial load attempt
        print("--- Running Initial Data Load ---")
        try:
            self.refresh_data_and_model()
            print("--- Initial Load Complete ---")
        except Exception as e:
            print(f"Initial load failed (this is common if model path is invalid): {e}")

    def refresh_data_and_model(self):
        """Reloads the model and clears imported data.

        Returns a status string for the UI textbox; never raises so the
        interface can still render on failure.
        """
        print("\n" + "=" * 50)
        print("RELOADING MODEL and RE-FETCHING DATA")
        self.imported_dataset = []
        try:
            self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
            status_value = "Model and data reloaded. Ready."
        except Exception as e:
            self.model = None
            self.tokenizer = None
            status_value = f"CRITICAL ERROR: Model failed to load. {e}"
            # We don't raise here to allow the UI to render the error message
        return status_value

    def import_additional_dataset(self, file_path: str) -> str:
        """Parses an uploaded CSV file.

        Rows need at least 3 columns (extras ignored); values are stripped and
        stored as [user_content, tool_name, tool_arguments]. Returns a
        human-readable status string for the UI.
        """
        if not file_path:
            return "Please upload a CSV file."
        new_dataset = []
        num_imported = 0
        try:
            # Open file handle properly
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                # Basic header validation
                try:
                    header = next(reader)
                    # Simple heuristic check, allows skipping header or rewinding
                    # NOTE(review): the sniff looks for "anchor" in the first
                    # header cell, but rows are consumed as
                    # [user_content, tool_name, tool_arguments] — confirm the
                    # expected header naming with the dataset producer.
                    if not (header and "anchor" in header[0].lower()):
                        f.seek(0)
                except StopIteration:
                    return "Error: Uploaded file is empty."
                for row in reader:
                    # Expecting: [User Prompt, Tool Name, Tool Args JSON/String]
                    if len(row) >= 3:
                        new_dataset.append([s.strip() for s in row[:3]])
                        num_imported += 1
            if num_imported == 0:
                return "No valid rows found. CSV format: [Anchor, Positive, Negative]"
            self.imported_dataset = new_dataset
            return f"Successfully imported {num_imported} additional training samples."
        except Exception as e:
            return f"Import failed. Error: {e}"

    def stop_training(self):
        """Signal the training loop to stop."""
        print("Set stop event")
        self.stop_event.set()
        return "Stopping initiated... please wait for the current step to finish."

    def run_training(self, test_size: float = 0.5) -> Generator[str, None, None]:
        """
        Main training logic. Yields status strings to the UI.

        Pipeline: validate model -> build dataset (imported CSV or default hub
        dataset) -> pre-training eval -> SFT fine-tune -> save -> post-training
        eval. The stop_event is checked between stages for cooperative abort.
        """
        # 1. Validation
        if self.model is None:
            yield "Training failed: Model is not loaded."
            return
        self.stop_event.clear()  # Reset stop flag
        yield "⏳ Preparing Dataset..."
        # 2. Dataset Preparation
        if not self.imported_dataset:
            print("No imported dataset, using default HF dataset")
            try:
                dataset = load_dataset(self.config.DEFAULT_DATASET, split="train")
            except Exception as e:
                yield f"Error loading default dataset: {e}"
                return
        else:
            dataset_as_dicts = [{
                "user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
                for row in self.imported_dataset
            ]
            dataset = Dataset.from_list(dataset_as_dicts)
        # Apply formatting
        dataset = dataset.map(create_conversation_format, batched=False)
        # Split
        if len(dataset) > 1:
            dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
        else:
            # Fallback for very small datasets (mostly for debugging)
            dataset = {"train": dataset, "test": dataset}
        output_buffer = "πŸ“Š Evaluating Pre-Training Success Rate...\n### Success Rate (Before Training):\n"
        yield output_buffer
        pre_training_report = ""
        # check_success_rate is a generator: drive it manually so each sample's
        # log line streams to the UI, then capture its final summary from
        # StopIteration.value when it returns.
        gen = self.check_success_rate(dataset["test"])
        while not self.stop_event.is_set():
            try:
                pre_training_report += f"{next(gen)}\n"
                yield f"{output_buffer}{pre_training_report}"
            except StopIteration as e:
                pre_training_report = e.value
                break
        if self.stop_event.is_set():
            output_buffer += f"{pre_training_report}\n\nπŸ›‘ Manual Eval interrupted by user.\n"
            yield output_buffer
            return
        output_buffer += f"{pre_training_report}\n\n"
        output_buffer += "-" * 30 + "\nStarting Fine-tuning...\n"
        yield output_buffer
        # 3. Training Setup
        torch_dtype = self.model.dtype
        args = SFTConfig(
            output_dir=str(self.config.OUTPUT_DIR),
            max_length=512,
            packing=False,
            num_train_epochs=5,
            per_device_train_batch_size=4,
            gradient_checkpointing=False,
            optim="adamw_torch_fused",
            logging_steps=1,
            save_strategy="no",  # Speed up demo
            eval_strategy="epoch",
            learning_rate=5e-5,
            # Mixed-precision flags follow the dtype the model was loaded with.
            fp16=True if torch_dtype == torch.float16 else False,
            bf16=True if torch_dtype == torch.bfloat16 else False,
            lr_scheduler_type="constant",
            push_to_hub=False,
            report_to="none",
            dataset_kwargs={
                "add_special_tokens": False,
                "append_concat_token": True,
            }
        )
        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['test'],
            processing_class=self.tokenizer,
            callbacks=[AbortCallback(self.stop_event)]  # Inject our stopper
        )
        # 4. Run Training
        try:
            output_buffer += "πŸš€ Training in progress... (Click Stop to interrupt)\n"
            yield output_buffer
            trainer.train()
            if self.stop_event.is_set():
                output_buffer += "\nπŸ›‘ Training interrupted by user.\n"
            else:
                output_buffer += "\nβœ… Training finished. Model weights updated in memory.\n"
            yield output_buffer
            # Save locally
            trainer.save_model()
            output_buffer += f"Model saved locally to: {self.config.OUTPUT_DIR}\n"
            yield output_buffer
        except Exception as e:
            output_buffer += f"\n❌ Error during training: {e}\n"
            yield output_buffer
            return
        if self.stop_event.is_set():
            return
        # 5. Post-Evaluation
        output_buffer += "πŸ“Š Evaluating Post-Training Success Rate...\n"
        post_report = ""
        yield output_buffer
        gen = self.check_success_rate(dataset["test"])
        while not self.stop_event.is_set():
            try:
                post_report += f"{next(gen)}\n"
                yield f"{output_buffer}{post_report}"
            except StopIteration as e:
                post_report = e.value
                break
        if self.stop_event.is_set():
            output_buffer += f"{post_report}\n\nπŸ›‘ Manual Eval interrupted by user.\n"
            yield output_buffer
            return
        output_buffer += f"{post_report}\n\n"
        yield output_buffer

    def check_success_rate(self, test_dataset):
        """Runs inference on test set to calculate accuracy.

        Generator: yields one log line per evaluated sample as it happens, and
        returns the complete summary string via StopIteration.value.
        """
        results = []
        success_count = 0
        total = len(test_dataset)  # NOTE(review): computed but unused below
        for idx, item in enumerate(test_dataset):
            # Cap at 5 samples to keep the interactive demo fast.
            if idx >= 5:
                break
            if self.stop_event.is_set():
                break
            messages = [item["messages"][0], item["messages"][1]]  # System + User
            try:
                inputs = self.tokenizer.apply_chat_template(
                    messages,
                    tools=TOOLS,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt"
                )
                out = self.model.generate(
                    **inputs.to(self.model.device),
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=128
                )
                # Decode only the new tokens
                output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True)
                results.append(f"{idx+1}. Prompt: {item['messages'][1]['content']}")
                yield results[-1]
                results.append(f" Output: {output[:100]}...")
                yield results[-1]
                # Check for correct tool name usage
                # NOTE(review): substring match — mentioning the tool name
                # without a well-formed call still counts as success.
                expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
                if expected_tool in output:
                    results.append(" -> βœ… Correct Tool")
                    yield results[-1]
                    success_count += 1
                else:
                    results.append(f" -> ❌ Wrong Tool (Expected: {expected_tool})")
                    yield results[-1]
            except Exception as e:
                results.append(f" -> Error: {e}")
                yield results[-1]
        summary = "\n".join(results)
        # NOTE(review): denominator is the full test set even though at most
        # 5 samples are evaluated above.
        summary += f"\n\nTotal Success : {success_count} / {len(test_dataset)}"
        return summary

    def download_model_zip(self) -> Optional[str]:
        """Zips the output directory for download.

        Returns the archive path, or None if no model has been saved yet or
        archiving failed.
        """
        if not os.path.exists(self.config.OUTPUT_DIR):
            return None
        # Timestamp prevents successive exports from overwriting each other.
        timestamp = int(time.time())
        try:
            base_name = self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}")
            archive_path = shutil.make_archive(
                base_name=str(base_name),
                format='zip',
                root_dir=str(self.config.OUTPUT_DIR),
            )
            return archive_path
        except Exception as e:
            print(f"Zip failed: {e}")
            return None

    # --- UI Builder ---
    def build_interface(self) -> gr.Blocks:
        """Assemble the Gradio Blocks UI and wire events to instance methods."""
        with gr.Blocks(title="FunctionGemma Modkit") as demo:
            gr.Markdown("# πŸ€– FunctionGemma Modkit: Fine-Tuning")
            gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.")
            with gr.Column():
                gr.Markdown("## 1. Training Controls")
                with gr.Row():
                    run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary")
                    # Hidden until a run starts; swapped with the Run button.
                    stop_training_btn = gr.Button("πŸ›‘ Stop Training", variant="stop", visible=False)
                output_display = gr.Textbox(
                    lines=14,
                    label="Training Logs & Search Results",
                    value="Ready. Click 'Run' to begin.",
                    interactive=False
                )
                clear_reload_btn = gr.Button("πŸ”„ Reset Model & Data")
                gr.Markdown("--- \n ## 2. Data Management")
                import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=80)
                import_status = gr.Markdown("")
                gr.Markdown("--- \n ## 3. Export")
                with gr.Row():
                    zip_btn = gr.Button("⬇️ Prepare Model ZIP")
                    download_file = gr.File(label="Download ZIP", height=80, visible=True, interactive=False)
            # --- Event Wiring ---
            # Start Training (Generator updates output_display)
            # Chain: swap buttons -> stream run_training -> restore buttons.
            run_training_btn.click(
                fn=lambda: (
                    gr.update(visible=False),
                    gr.update(interactive=False),
                    gr.update(visible=True)
                ),
                inputs=None,
                outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
            ).then(
                fn=self.run_training,
                inputs=[],
                outputs=[output_display],
            ).then(
                fn=lambda: (
                    gr.update(visible=True),
                    gr.update(interactive=True),
                    gr.update(visible=False)
                ),
                inputs=None,
                outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
            )
            # Stop Training
            stop_training_btn.click(
                fn=self.stop_training,
                inputs=None,
                outputs=None  # We don't need to return anything, status updates via the training generator
            ).then(
                fn=lambda: (
                    gr.update(visible=True),
                    gr.update(interactive=True),
                    gr.update(visible=False)
                ),
                inputs=None,
                outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
            )
            # Reload
            clear_reload_btn.click(
                fn=self.refresh_data_and_model,
                inputs=None,
                outputs=[output_display]
            )
            # File Import
            import_file.upload(
                fn=self.import_additional_dataset,
                inputs=[import_file],
                outputs=[import_status]
            )
            # Download Logic
            def handle_zip():
                # Only reveal the file component when a zip was actually produced.
                path = self.download_model_zip()
                if path:
                    return gr.update(value=path, visible=True)
                return gr.update(value=None, visible=False)
            zip_btn.click(
                fn=handle_zip,
                inputs=None,
                outputs=[download_file]
            )
        return demo
if __name__ == "__main__":
    # Construct the tuner (eagerly attempts model load in __init__),
    # build the Gradio UI, and serve it.
    app = FunctionGemmaTuner(AppConfig)
    demo = app.build_interface()
    print("Starting Gradio App...")
    demo.launch()