# Source: amarorn / scripts / hf_sft_train.py (Hugging Face Hub file view, 3.07 kB)
# Last commit: "feat: sync main with feature/superbet-live-inplay" (16c19b8, verified)
# Committer shown on the page: beAnalytic
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "datasets>=2.18.0",
# "transformers>=4.40.0",
# "trl>=0.12.0",
# "peft>=0.12.0",
# "accelerate>=0.30.0",
# "trackio",
# ]
# ///
from __future__ import annotations
import argparse
from datetime import datetime, timezone
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the SFT training job.

    Args:
        argv: Argument list to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is what the script entry
            point relies on. Passing an explicit list allows the parser to be
            driven programmatically.

    Returns:
        Namespace with ``dataset_repo``, ``data_file``, ``base_model``,
        ``hub_model_id``, ``max_steps``, ``learning_rate``, ``batch_size``
        and ``grad_accum``.

    Raises:
        SystemExit: Raised by argparse when a required option
            (``--dataset-repo`` or ``--hub-model-id``) is missing or a value
            cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(
        description="SFT do modelo de bolão no Hugging Face Jobs",
    )
    # Where the training data lives on the Hub.
    parser.add_argument("--dataset-repo", type=str, required=True, help="Repo dataset no Hub")
    parser.add_argument("--data-file", type=str, default="training/bolao_train.jsonl")
    # Which model to fine-tune and where to publish the result.
    parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--hub-model-id", type=str, required=True)
    # Optimisation hyper-parameters.
    parser.add_argument("--max-steps", type=int, default=200)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--grad-accum", type=int, default=8)
    return parser.parse_args(argv)
# Columns every training row is built from; see _format_example.
_REQUIRED_COLUMNS = ("prompt", "completion")


def _as_text(value: object) -> str:
    """Return ``value`` as a string, mapping ``None`` to the empty string."""
    return "" if value is None else str(value)


def _format_example(row: dict) -> dict:
    """Join a row's prompt and completion into the single ``text`` training field."""
    # A null field reaches us as None rather than as a missing key, so a plain
    # ``row.get(key, "")`` would leak the literal string "None" into the text.
    prompt = _as_text(row.get("prompt"))
    completion = _as_text(row.get("completion"))
    return {"text": f"{prompt}\nResposta:\n{completion}".strip()}


def main() -> None:
    """Run the LoRA SFT job end to end and push the result to the Hub.

    Steps: parse the CLI options, load the training file from the dataset
    repo, validate it, build the ``text`` column, hold out 5% of the rows for
    evaluation, train with TRL's ``SFTTrainer`` and push to ``--hub-model-id``.

    Raises:
        ValueError: If the dataset has fewer than 20 rows, or if it lacks the
            ``prompt`` or ``completion`` column.
    """
    args = parse_args()

    ds = load_dataset(
        args.dataset_repo,
        data_files={"train": args.data_file},
        split="train",
    )
    if len(ds) < 20:
        raise ValueError(f"Dataset muito pequeno para treino: {len(ds)} linhas")

    # Fail before any GPU time is spent: without these columns every example
    # would collapse to the bare "Resposta:" separator.
    missing = [name for name in _REQUIRED_COLUMNS if name not in ds.column_names]
    if missing:
        raise ValueError(f"Dataset sem as colunas obrigatórias: {', '.join(missing)}")

    ds = ds.map(_format_example)
    # Keep only the column the trainer consumes.
    ds = ds.remove_columns([c for c in ds.column_names if c != "text"])
    # Fixed seed so the train/eval partition is reproducible across runs.
    split = ds.train_test_split(test_size=0.05, seed=42)

    # UTC timestamp keeps run names unique and sortable.
    run_suffix = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = f"bolao-sft-{run_suffix}"

    # LoRA adapters on the attention projections only.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )

    sft_args = SFTConfig(
        output_dir="outputs/bolao-sft",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        max_steps=args.max_steps,
        warmup_ratio=0.05,
        logging_steps=10,
        # Evaluate and checkpoint on the same 50-step cadence.
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        # NOTE(review): bf16 needs hardware support -- confirm the GPU flavor
        # used for the job provides it.
        bf16=True,
        report_to="trackio",
        project="api-noticia-bolao",
        run_name=run_name,
        push_to_hub=True,
        hub_model_id=args.hub_model_id,
        hub_strategy="every_save",
    )

    trainer = SFTTrainer(
        # The base model is passed as a Hub id string, not a loaded model object.
        model=args.base_model,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        args=sft_args,
        peft_config=peft_config,
    )
    trainer.train()
    # Explicit final push after training completes.
    trainer.push_to_hub()
# Run the training job only when this file is executed as a script.
if __name__ == "__main__":
    main()