eyalnof123's picture
Upload train.py with huggingface_hub
b2938e0 verified
#!/usr/bin/env python3
import os, threading
os.environ["WANDB_DISABLED"] = "true"
# Health server: answers HTTP requests on port 7860 so Hugging Face does not kill the process for a startup timeout.
from http.server import HTTPServer, BaseHTTPRequestHandler
class H(BaseHTTPRequestHandler):
    """Request handler for a tiny status page; every GET reports ``H.status``."""

    # Progress text shown on the page. It is a class attribute so the script
    # can update it with a plain ``H.status = ...`` assignment.
    status = "starting"

    def do_GET(self):
        """Answer any GET with a 200 HTML page embedding the current status."""
        page = f"<h1>FunctionGemma Training</h1><p>Status: {H.status}</p>"
        payload = page.encode()
        self.send_response(200)
        self.send_header("Content-Type", "text/html")
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, *a):
        """Discard the per-request log line instead of emitting it."""
        return None
# Bind the status endpoint on all interfaces, port 7860, and run its accept
# loop on a daemon thread so it does not block the steps that follow.
server = HTTPServer(("0.0.0.0", 7860), H)
threading.Thread(
    target=server.serve_forever,
    daemon=True,
).start()
print("Health server on :7860")
# Install the training stack at runtime. The third-party imports further down
# in this script only run after this step has completed.
print("=== Installing ===")
H.status = "installing dependencies"
import subprocess
import sys

# Run pip through the interpreter that is executing this script, so the
# packages land in the environment the later imports read from rather than in
# whatever "pip" happens to be first on PATH. check_call raises
# CalledProcessError if the installation fails.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "unsloth", "trl", "datasets", "peft", "accelerate", "bitsandbytes",
])
print("=== Loading model ===")
H.status = "loading model"
from unsloth import FastLanguageModel

# Load the base checkpoint and its tokenizer with a 4096-token sequence limit;
# 4-bit loading is switched off.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/functiongemma-270m-it",
    max_seq_length=4096,
    load_in_4bit=False,
)
print("=== Applying LoRA ===")
H.status = "applying LoRA"
# Wrap the model with LoRA adapters: rank 32, alpha 64, no dropout, no bias,
# applied to the projection modules listed below.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)
print("=== Loading dataset ===")
H.status = "loading dataset"
from datasets import load_dataset

dataset = load_dataset("eyalnof123/su-lab-functiongemma-dataset")
if "train" in dataset:
    train_dataset = dataset["train"]
else:
    # No "train" entry came back: load train.jsonl from the same repo
    # explicitly as the train split.
    train_dataset = load_dataset(
        "eyalnof123/su-lab-functiongemma-dataset",
        data_files="train.jsonl",
        split="train",
    )
print(f"Training examples: {len(train_dataset)}")
# Announce the training phase on stdout and on the status page, then bring in
# the TRL trainer classes used below.
print("=== Training ===")
H.status = "training epoch 1/3"
from trl import (
    SFTTrainer,
    SFTConfig,
)
class StatusCallback:
    """Plain class with an ``on_log`` hook that copies progress into ``H.status``.

    NOTE(review): nothing in the visible script instantiates this class, and it
    does not derive from a trainer callback base class - confirm whether it is
    still needed.
    """

    def on_log(self, args, state, control, logs=None, **kw):
        """Write the global step, max steps and epoch from ``state`` to the status page."""
        # A missing/zero epoch is reported as 0 so the format spec never sees None.
        current_epoch = state.epoch if state.epoch else 0
        H.status = f"training step {state.global_step}/{state.max_steps} epoch {current_epoch:.1f}"
# Supervised fine-tuning configuration, one option per line.
training_args = SFTConfig(
    # Output location and schedule length.
    output_dir="./output",
    num_train_epochs=3,
    # 2 samples per device, accumulated over 4 steps before each update.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # Optimiser and learning-rate schedule.
    learning_rate=2e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=5,
    optim="adamw_8bit",
    # Log every 10 steps; save once per epoch.
    logging_steps=10,
    save_strategy="epoch",
    # bfloat16 on, float16 off.
    bf16=True,
    fp16=False,
    # Training text is read from the "text" field, up to 4096 tokens.
    max_seq_length=4096,
    dataset_text_field="text",
    seed=42,
)
from transformers import TrainerCallback
class SC(TrainerCallback):
    """Trainer callback that mirrors training progress into ``H.status``."""

    def on_log(self, args, state, control, logs=None, **kw):
        """On a logging event, publish the step counter and fractional epoch."""
        # A missing/zero epoch is reported as 0 so the format spec never sees None.
        current_epoch = state.epoch if state.epoch else 0
        H.status = f"training step {state.global_step}/{state.max_steps} epoch {current_epoch:.1f}"

    def on_epoch_end(self, args, state, control, **kw):
        """At the end of an epoch, set the status to the checkpoint-saving message."""
        finished_epoch = int(state.epoch)
        H.status = f"saving checkpoint epoch {finished_epoch}"
# Assemble the trainer with the status-reporting callback attached, then run
# the full training loop.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    callbacks=[SC()],
)
trainer.train()
print("=== Saving ===")
H.status = "saving model"
# Write the trained model first, then the tokenizer, into the same directory
# so ./output/final holds both.
model.save_pretrained("./output/final")
tokenizer.save_pretrained("./output/final")
print("=== Pushing to Hub ===")
H.status = "uploading to hub"
from huggingface_hub import HfApi, create_repo

api = HfApi()
# exist_ok=True turns "repo already exists" into a no-op, so only genuine
# failures (auth, network, ...) reach the handler. Creation stays best-effort:
# the error is reported rather than silently swallowed, and the upload below is
# still attempted. Catching Exception instead of using a bare except lets
# KeyboardInterrupt/SystemExit propagate.
try:
    create_repo(
        "eyalnof123/functiongemma-270m-su-lab",
        private=False,
        exist_ok=True,
    )
except Exception as e:
    print(f"create_repo failed: {e}")
# Upload the saved model directory to the model repo.
api.upload_folder(
    folder_path="./output/final",
    repo_id="eyalnof123/functiongemma-270m-su-lab",
    repo_type="model",
)
print("=== GGUF ===")
H.status = "exporting GGUF"
# Extra step: export a q4_k_m GGUF build and upload it under gguf/ in the same
# model repo. A failure is printed and the script carries on.
try:
    model.save_pretrained_gguf(
        "./output/gguf",
        tokenizer,
        quantization_method="q4_k_m",
    )
    api.upload_folder(
        folder_path="./output/gguf",
        repo_id="eyalnof123/functiongemma-270m-su-lab",
        repo_type="model",
        path_in_repo="gguf",
    )
    print("GGUF uploaded!")
except Exception as err:
    print(f"GGUF failed: {err}")
# Publish the final status (with the model URL) on the page and on stdout.
H.status = "DONE! Model at https://huggingface.co/eyalnof123/functiongemma-270m-su-lab"
print(H.status)
# Park the main thread forever so the process keeps running and the status
# page can still be viewed.
import time

while True:
    time.sleep(60)