raunch-training-scripts / train_lora.py
4moha's picture
fix: pre-render chat template via .map() instead of formatting_func
de608a4 verified
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.45",
# "datasets>=2.20",
# "accelerate>=0.34",
# "trackio",
# "unsloth",
# ]
# ///
"""Phase-A LoRA SFT for the raunch page-mode model — runs inside HF Jobs.
Base: Sao10K/Llama-3.1-8B-Stheno-v3.4
Dataset: 4moha/raunch-page-mode-v0 (private)
Output: pushed to 4moha/raunch-stheno-v3.4-lora-v0
NSFW-only: training data is raunch's NSFW Claude-generated prose. The resulting
LoRA is deployed to the raunch server instance, NOT the SFW lili server.
This script is submitted as the body of the HF Job. It expects HF_TOKEN to be
set in the job environment; HF_DATASET_REPO and HF_MODEL_REPO are optional and
override the dataset and output repos listed above.
"""
# Unsloth MUST be imported before transformers/trl/peft — its module-init patches
# don't apply otherwise and you get the "imported late" warning + degraded perf.
from unsloth import FastLanguageModel
import os
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# The base checkpoint is fixed; the dataset and output repos fall back to the
# defaults below unless the job environment overrides them.
BASE_MODEL = "Sao10K/Llama-3.1-8B-Stheno-v3.4"
DATASET_REPO = os.getenv("HF_DATASET_REPO", "4moha/raunch-page-mode-v0")
MODEL_REPO = os.getenv("HF_MODEL_REPO", "4moha/raunch-stheno-v3.4-lora-v0")
# Context length shared by model loading and sample truncation, so the two
# settings cannot silently drift apart.
_MAX_SEQ_LENGTH = 4096


def _load_peft_model():
    """Load BASE_MODEL in 4-bit through Unsloth and attach LoRA adapters.

    Unsloth is used instead of vanilla transformers because it is faster and
    leaner for this QLoRA setup.

    Returns:
        A ``(model, tokenizer)`` pair where ``model`` has the LoRA adapters
        from ``FastLanguageModel.get_peft_model`` attached.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=BASE_MODEL,
        max_seq_length=_MAX_SEQ_LENGTH,
        dtype=None,  # auto
        load_in_4bit=True,  # QLoRA — fits more comfortably on A10G
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=32,
        lora_dropout=0,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    return model, tokenizer


def _load_text_splits(tokenizer):
    """Load train.jsonl from DATASET_REPO and return a 95/5 train/test split.

    Each row's ``messages`` column is rendered through the tokenizer's chat
    template into a single ``text`` column up front. This sidesteps version
    skew in TRL/Unsloth's ``formatting_func`` contracts: SFTTrainer reads
    ``dataset_text_field="text"`` and tokenizes it directly.

    Returns:
        The result of ``train_test_split(test_size=0.05, seed=42)``, with
        ``"train"`` and ``"test"`` splits.
    """
    full = load_dataset(DATASET_REPO, data_files="train.jsonl", split="train")

    def render_chat(example: dict) -> dict:
        # NOTE(review): Llama-3-style chat templates typically begin the
        # rendered string with bos_token; confirm the trainer's tokenization of
        # the "text" column does not prepend a second BOS.
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
                add_generation_prompt=False,
            )
        }

    full = full.map(render_chat, remove_columns=["messages"])
    return full.train_test_split(test_size=0.05, seed=42)


def main() -> None:
    """Run Phase-A LoRA SFT and push the resulting adapter to MODEL_REPO.

    Loads the 4-bit base model with LoRA adapters, renders and splits the
    dataset, and trains with TRL's SFTTrainer. Checkpoints and the final
    adapter are pushed to the private Hub repo MODEL_REPO, and metrics are
    reported to trackio.
    """
    import math

    model, tokenizer = _load_peft_model()
    split = _load_text_splits(tokenizer)

    num_train_epochs = 3
    per_device_train_batch_size = 1
    gradient_accumulation_steps = 8

    # warmup_ratio is deprecated in newer releases, so warmup is expressed as
    # concrete steps. It is derived from the actual train split (~5% of
    # optimizer steps) rather than hard-coded, so it stays right as the dataset
    # changes. Assumes a single device: under data parallelism the effective
    # batch would also be multiplied by the world size.
    steps_per_epoch = math.ceil(
        len(split["train"])
        / (per_device_train_batch_size * gradient_accumulation_steps)
    )
    warmup_steps = max(1, math.ceil(0.05 * steps_per_epoch * num_train_epochs))

    trainer = SFTTrainer(
        model=model,
        # TRL 0.12 (this script's declared minimum) renamed the `tokenizer`
        # argument to `processing_class`; the old name is deprecated.
        processing_class=tokenizer,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        args=SFTConfig(
            dataset_text_field="text",
            output_dir="raunch-stheno-v3.4-lora-v0",
            push_to_hub=True,
            hub_model_id=MODEL_REPO,
            hub_private_repo=True,
            hub_strategy="every_save",
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=5e-5,
            lr_scheduler_type="cosine",
            warmup_steps=warmup_steps,
            max_length=_MAX_SEQ_LENGTH,
            logging_steps=5,
            save_strategy="steps",
            save_steps=200,
            eval_strategy="steps",
            eval_steps=50,
            seed=42,
            report_to="trackio",
            run_name="raunch-stheno-v3.4-lora-v0",
            project="raunch-page-mode",
        ),
    )
    trainer.train()
    # With save_steps=200, a run shorter than 200 optimizer steps never triggers
    # an intermediate save (and so no every_save push); this final push is what
    # publishes the trained adapter.
    trainer.push_to_hub()
    print("Training complete. LoRA pushed to:", MODEL_REPO)
if __name__ == "__main__":
    # Train only when executed as a script (e.g. as the HF Job body), not on import.
    main()