#!/usr/bin/env python3
import os, threading
os.environ["WANDB_DISABLED"] = "true"
# Health server so HF doesn't kill us for timeout
from http.server import HTTPServer, BaseHTTPRequestHandler
class H(BaseHTTPRequestHandler):
    """Tiny status page: every GET returns the script's current stage.

    The rest of the script reports progress by assigning ``H.status``; the
    handler reads that class attribute when it builds each response.
    """

    # Human-readable progress string, overwritten by the main script as it
    # moves from stage to stage.
    status = "starting"

    def do_GET(self):
        """Answer any GET with a 200 and a small HTML page showing the status."""
        # Build the body first so its exact byte length can be advertised.
        body = f"<h1>FunctionGemma Training</h1><p>Status: {H.status}</p>".encode("utf-8")
        self.send_response(200)
        # The body is UTF-8 encoded, so say so instead of leaving the client
        # to guess the charset.
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *a):
        """Discard the per-request log line the base handler would emit."""
        pass
# Serve the status page on every interface, port 7860, from a daemon thread
# so the rest of the script keeps running on the main thread.
_bind_address = ("0.0.0.0", 7860)
server = HTTPServer(_bind_address, H)
_health_thread = threading.Thread(target=server.serve_forever, daemon=True)
_health_thread.start()
print("Health server on :7860")
# Stage: install the training stack at start-up, before any of it is imported.
print("=== Installing ===")
H.status = "installing dependencies"
import subprocess
import sys

# Run pip as a module of the interpreter that is executing this script. A bare
# "pip" is looked up on PATH and can belong to a different Python, in which
# case the imports further down would not see the freshly installed packages.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "unsloth", "trl", "datasets", "peft", "accelerate", "bitsandbytes",
])
# Stage: fetch the base model and its tokenizer through Unsloth.
print("=== Loading model ===")
H.status = "loading model"
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/functiongemma-270m-it",
    max_seq_length=4096,
    load_in_4bit=False,
)
# Stage: wrap the loaded model with LoRA adapters on the listed projections.
print("=== Applying LoRA ===")
H.status = "applying LoRA"
_lora_targets = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    lora_alpha=64,
    target_modules=_lora_targets,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)
# Stage: load the training split, falling back to the raw train.jsonl file
# when the loaded dataset exposes no "train" key.
print("=== Loading dataset ===")
H.status = "loading dataset"
from datasets import load_dataset

dataset = load_dataset("eyalnof123/su-lab-functiongemma-dataset")
if "train" in dataset:
    train_dataset = dataset["train"]
else:
    train_dataset = load_dataset(
        "eyalnof123/su-lab-functiongemma-dataset",
        data_files="train.jsonl",
        split="train",
    )
print(f"Training examples: {len(train_dataset)}")
# Stage: training. Publish an initial status for the status page; the "1/3"
# is a fixed string here (num_train_epochs is set to 3 in the SFTConfig below).
print("=== Training ===")
H.status = "training epoch 1/3"
# trl is imported only at this point, after the pip install and the unsloth
# import earlier in the script have already run.
from trl import SFTTrainer, SFTConfig
class StatusCallback:
    """Plain-object progress hook that copies trainer progress into ``H.status``.

    NOTE(review): nothing in the visible script instantiates this class, and it
    does not derive from a trainer callback base; the ``SC`` class defined
    further down is the callback actually handed to the trainer. The class is
    kept so the name stays defined.
    """

    def on_log(self, args, state, control, logs=None, **kw):
        """Write the current step, total steps and epoch into ``H.status``.

        Only ``state`` is read (``global_step``, ``max_steps``, ``epoch``);
        the remaining arguments are accepted and ignored.
        """
        # state.epoch may be falsy (None or 0) before progress is recorded;
        # fall back to 0 so the ``:.1f`` format below cannot fail.
        epoch = state.epoch or 0
        H.status = f"training step {state.global_step}/{state.max_steps} epoch {epoch:.1f}"
# Hyper-parameters for the supervised fine-tuning run, one option per line.
training_args = SFTConfig(
    output_dir="./output",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=5,
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    fp16=False,
    optim="adamw_8bit",
    max_seq_length=4096,
    dataset_text_field="text",
    seed=42,
)
from transformers import TrainerCallback
class SC(TrainerCallback):
    """Trainer callback that mirrors training progress into ``H.status``."""

    def on_log(self, args, state, control, logs=None, **kw):
        """Publish the current step, total steps and fractional epoch.

        Only ``state`` is read; the other arguments are accepted and ignored.
        """
        # state.epoch may be falsy (None or 0); fall back to 0 so the
        # ``:.1f`` format cannot fail.
        epoch = state.epoch or 0
        H.status = f"training step {state.global_step}/{state.max_steps} epoch {epoch:.1f}"

    def on_epoch_end(self, args, state, control, **kw):
        """Publish a "saving checkpoint" status carrying the whole epoch number."""
        # Apply the same None guard as on_log so int() never receives None.
        epoch = state.epoch or 0
        H.status = f"saving checkpoint epoch {int(epoch)}"
# Build the trainer with the status callback attached, then run training.
_progress_callbacks = [SC()]
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    callbacks=_progress_callbacks,
)
trainer.train()
# Stage: write the trained model, then the tokenizer, to the final directory.
print("=== Saving ===")
H.status = "saving model"
_final_dir = "./output/final"
for _artifact in (model, tokenizer):
    _artifact.save_pretrained(_final_dir)
# Stage: publish the saved model folder to the Hugging Face Hub.
print("=== Pushing to Hub ===")
H.status = "uploading to hub"
from huggingface_hub import HfApi, create_repo

api = HfApi()
# Repository creation stays best-effort, as before, but no longer hides what
# went wrong: exist_ok=True makes "repo already exists" a non-error, and any
# other failure is printed before the upload is attempted. Catching Exception
# (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
try:
    create_repo("eyalnof123/functiongemma-270m-su-lab", private=False, exist_ok=True)
except Exception as e:
    print(f"create_repo failed, attempting upload anyway: {e}")
api.upload_folder(
    folder_path="./output/final",
    repo_id="eyalnof123/functiongemma-270m-su-lab",
    repo_type="model",
)
# Stage: optional GGUF export. Any failure here is reported and then ignored
# so the script still reaches its final status.
print("=== GGUF ===")
H.status = "exporting GGUF"
try:
    # Export a q4_k_m-quantized GGUF, then upload it under "gguf/" in the
    # same model repository.
    model.save_pretrained_gguf("./output/gguf", tokenizer, quantization_method="q4_k_m")
    api.upload_folder(
        folder_path="./output/gguf",
        repo_id="eyalnof123/functiongemma-270m-su-lab",
        repo_type="model",
        path_in_repo="gguf",
    )
    print("GGUF uploaded!")
except Exception as e:
    print(f"GGUF failed: {e}")
# Publish the final status, echo it to stdout, then block forever so the
# status page stays reachable.
_done_message = "DONE! Model at https://huggingface.co/eyalnof123/functiongemma-270m-su-lab"
H.status = _done_message
print(H.status)
import time

while True:
    time.sleep(60)
|