# dm-lora-trainer / train_dm_lora.py
# Last commit by zprime: "Resume training from step 200 checkpoint" (374cf10, verified)
"""
Dungeon Master LoRA Training - Qwen3.5-9B via Unsloth
======================================================
Resumes from step 200 checkpoint.
Unsloth bf16 LoRA (no 4-bit quantization).
Hardware: L40S 1x (48GB VRAM, $1.80/hr)
"""
import os
import sys
import threading
import time
from http.server import HTTPServer, BaseHTTPRequestHandler

import torch

# PYTHONUNBUFFERED is read by the interpreter at startup, so assigning it here
# only reaches child processes spawned later (e.g. dataloader workers). The
# running process is switched to line-buffered stdout explicitly, so output
# from libraries that print without flush=True still reaches the logs promptly.
os.environ["PYTHONUNBUFFERED"] = "1"
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)
# ============================================================
# Health check server on port 7860
# ============================================================
# Mutable progress record: the training code below writes it, the health-check
# handler reads it to render the status page.
STATUS = {"stage": "starting", "step": 200, "total": 2563, "loss": 0.0, "t": time.time()}
class H(BaseHTTPRequestHandler):
    """Tiny status page served for the platform health check.

    Any GET, regardless of path, is answered with 200 and an HTML summary of
    the module-level STATUS dict.
    """

    def do_GET(self):
        # Status line and headers are sent first; the body is rendered after.
        self.send_response(200)
        self.send_header("Content-Type", "text/html")
        self.end_headers()
        self.wfile.write(self._render_status_page())

    def _render_status_page(self):
        """Build the status page from STATUS and return it as UTF-8 bytes."""
        snapshot = STATUS
        elapsed_min = int(time.time() - snapshot["t"]) // 60
        page_lines = (
            '<html><body style="font-family:monospace;padding:20px">',
            "<h2>DM LoRA Training (resuming from step 200)</h2>",
            f"<p>Stage: {snapshot['stage']}</p>",
            f"<p>Step: {snapshot['step']}/{snapshot['total']}</p>",
            f"<p>Loss: {snapshot['loss']:.4f}</p>",
            f"<p>Elapsed: {elapsed_min} min</p>",
            "</body></html>",
        )
        return "\n".join(page_lines).encode()

    def log_message(self, *a):
        # Suppress per-request access logs so they don't clutter training output.
        pass
# Serve the status page from a background daemon thread: it does not block the
# training work on the main thread and does not by itself keep the process alive.
srv = HTTPServer(("0.0.0.0", 7860), H)
health_thread = threading.Thread(target=srv.serve_forever, daemon=True)
health_thread.start()
print("Health check server on :7860", flush=True)
# ============================================================
# Auth
# ============================================================
from huggingface_hub import login, snapshot_download
# A token is mandatory: the checkpoint is pulled from, and results pushed to,
# the Hub. Abort immediately (exit status 1) without one.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("ERROR: No HF_TOKEN!", flush=True)
    sys.exit(1)
login(token=hf_token)
print("Logged in to HF Hub", flush=True)
# ============================================================
# Download checkpoint from Hub to resume
# ============================================================
STATUS["stage"] = "downloading checkpoint from Hub"
OUTPUT_REPO = "zprime/qwen3.5-9b-dungeon-master-lora"
CHECKPOINT_DIR = "/tmp/dm-lora/checkpoint-200"
print("Downloading step-200 checkpoint from Hub...", flush=True)
import shutil
# Trainer pushes with hub_strategy="checkpoint", which keeps the latest
# checkpoint under last-checkpoint/ in the repo; fetch only that folder.
hub_dir = snapshot_download(
    repo_id=OUTPUT_REPO,
    allow_patterns="last-checkpoint/*",
    local_dir="/tmp/hub-checkpoint",
)
src = os.path.join(hub_dir, "last-checkpoint")
# Fail fast, before paying for model loading, if there is nothing to resume:
# resume_from_checkpoint restores the global step from trainer_state.json.
if not os.path.isfile(os.path.join(src, "trainer_state.json")):
    print(f"ERROR: no last-checkpoint/trainer_state.json in {OUTPUT_REPO}!", flush=True)
    sys.exit(1)
# Mirror the downloaded folder into CHECKPOINT_DIR. Clearing it first makes a
# re-run in the same container idempotent (no stale files, no nested dirs);
# copying leaves the download folder intact for snapshot_download to reuse.
shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
shutil.copytree(src, CHECKPOINT_DIR)
print(f"Checkpoint downloaded to {CHECKPOINT_DIR}", flush=True)
print(f"Files: {os.listdir(CHECKPOINT_DIR)}", flush=True)
# ============================================================
# Config
# ============================================================
# Base model on the Hub, training corpus, and the sequence-length cap passed to
# both the Unsloth loader and SFTConfig(max_length=...).
MODEL_NAME: str = "unsloth/Qwen3.5-9B"
DATASET_ID: str = "chimbiwide/RolePlay-NPC-Quest"
MAX_SEQ_LENGTH: int = 2_048
# ============================================================
# Load model via Unsloth
# ============================================================
STATUS["stage"] = "loading model via Unsloth"
print(f"Loading {MODEL_NAME} via Unsloth (bf16)...", flush=True)
# NOTE: unsloth is imported here, ahead of trl/transformers further down;
# Unsloth recommends being imported before those libraries so its patches apply.
from unsloth import FastLanguageModel
_load_kwargs = dict(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=torch.bfloat16,
    load_in_4bit=False,  # full bf16 weights, no 4-bit quantization
)
model, tokenizer = FastLanguageModel.from_pretrained(**_load_kwargs)
print("Model loaded via Unsloth", flush=True)
# ============================================================
# Add LoRA via Unsloth
# ============================================================
STATUS["stage"] = "adding LoRA"
# Adapter hyper-parameters; these must match the run that produced the checkpoint.
lora_rank, lora_scale_alpha = 16, 32
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=attention_projections + mlp_projections,
    lora_alpha=lora_scale_alpha,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=42,
)
print(f"LoRA added: r={lora_rank}, alpha={lora_scale_alpha}", flush=True)
# ============================================================
# Trackio
# ============================================================
# Experiment tracking is optional: if trackio cannot be imported or initialised
# the error is printed and the Trainer reports to nothing.
# NOTE(review): with report_to="trackio" the Trainer's own integration may start
# its own run; confirm this manual init is the run that receives the metrics.
REPORT_TO = "none"
try:
    import trackio
    trackio.init(name="dm-lora-resume-200", project=OUTPUT_REPO)
except Exception as e:
    print(f"Trackio warning: {e}", flush=True)
else:
    print("Trackio enabled", flush=True)
    REPORT_TO = "trackio"
# ============================================================
# Load dataset
# ============================================================
STATUS["stage"] = "loading dataset"
print(f"Loading dataset: {DATASET_ID}", flush=True)
from datasets import load_dataset
dataset = load_dataset(DATASET_ID, split="train")
n_examples = len(dataset)
print(f"Dataset: {n_examples} examples", flush=True)
# ============================================================
# Formatting function
# ============================================================
def formatting_func(examples):
    """Render every conversation in a batch to a single chat-templated string.

    Args:
        examples: a batched column dict from ``datasets.Dataset.map``;
            ``examples["messages"]`` holds one conversation per row, passed
            as-is to ``tokenizer.apply_chat_template``.

    Returns:
        ``{"text": [...]}`` with one rendered string per input row, in order.
    """
    return {
        "text": [
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
            for conversation in examples["messages"]
        ]
    }


print("Formatting dataset with chat template...", flush=True)
# Swap the raw "messages" column for the rendered "text" column, which the
# trainer reads via dataset_text_field="text".
dataset = dataset.map(formatting_func, batched=True, remove_columns=["messages"])
print(f"Dataset formatted: {len(dataset)} examples", flush=True)
# ============================================================
# Training config - same as before so resume works
# ============================================================
STATUS["stage"] = "initializing trainer"
from trl import SFTConfig, SFTTrainer
from transformers import TrainerCallback
# Arguments are grouped by concern; the values are those of the original run.
# Resuming skips the batches already seen, which presumably only lines up if
# batching, accumulation, schedule and seed are unchanged.
training_args = SFTConfig(
    # --- output, checkpointing, Hub ---
    output_dir="/tmp/dm-lora",
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_strategy="checkpoint",
    # --- schedule and batching ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    seed=42,
    dataloader_num_workers=2,
    # --- optimizer ---
    optim="adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.01,
    # --- data, precision, memory ---
    max_length=MAX_SEQ_LENGTH,
    dataset_text_field="text",
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # --- logging ---
    logging_strategy="steps",
    logging_steps=5,
    logging_first_step=True,
    disable_tqdm=True,
    report_to=REPORT_TO,
)
print("Initializing SFTTrainer...", flush=True)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dataset,
)
# Hard-coded expected optimizer-step count for the full run (presumably one
# epoch at 1 x 8 samples per step - confirm against the Trainer's max_steps).
total_steps = 2563
STATUS["total"] = total_steps
banner = "=" * 60
print(f"Resuming training from step 200 / {total_steps}", flush=True)
print(banner, flush=True)
# Status callback
class SC(TrainerCallback):
    """Forward Trainer log events to STATUS (health page) and to stdout."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record progress from one Trainer log event.

        Args:
            args, control: unused; part of the TrainerCallback hook signature.
            state: Trainer state; ``global_step`` and ``max_steps`` are read.
            logs: metrics dict for this log event; ignored when None/empty.
        """
        if not logs:
            return
        # The Trainer's own step budget is authoritative once training starts;
        # fall back to the hard-coded estimate while it is still unset (0).
        total = state.max_steps or total_steps
        STATUS["total"] = total
        STATUS["step"] = state.global_step
        # Not every log event carries "loss" (the end-of-training summary reports
        # "train_loss" instead); keep the last real value rather than showing 0.0.
        if "loss" in logs:
            STATUS["loss"] = logs["loss"]
        print(
            f"[Step {state.global_step}/{total}] "
            f"loss={logs.get('loss', '?')}, lr={logs.get('learning_rate', '?')}",
            flush=True,
        )


trainer.add_callback(SC())
# ============================================================
# Resume training from checkpoint
# ============================================================
STATUS["stage"] = "training (resumed from step 200)"
print(f"Resuming from {CHECKPOINT_DIR}...", flush=True)
started_at = time.time()
# The Trainer restores its state (including the global step) from CHECKPOINT_DIR.
trainer.train(resume_from_checkpoint=CHECKPOINT_DIR)
elapsed_minutes = (time.time() - started_at) / 60
print(f"Training done in {elapsed_minutes:.1f} min!", flush=True)
# ============================================================
# Save & push
# ============================================================
STATUS["stage"] = "saving"
print("Saving final model...", flush=True)
trainer.save_model()
print("Pushing to Hub...", flush=True)
trainer.push_to_hub(commit_message="Dungeon Master LoRA - FINAL - Unsloth bf16 r=16")
print(f"DONE! https://huggingface.co/{OUTPUT_REPO}", flush=True)
STATUS["stage"] = "COMPLETE - SET HARDWARE TO CPU!"
print("=" * 60, flush=True)
print("TRAINING COMPLETE!", flush=True)
print(f"Adapter: https://huggingface.co/{OUTPUT_REPO}", flush=True)
print("GO TO SETTINGS -> SET HARDWARE TO CPU TO STOP BILLING!", flush=True)
print("=" * 60, flush=True)
# Keep the process alive so the status page keeps reporting COMPLETE
# (presumably also so the hosting platform does not restart the job).
# The health server already runs srv.serve_forever() on its daemon thread;
# starting a second serve_forever() loop on the same server from this thread
# is not supported by socketserver, so just block the main thread instead.
threading.Event().wait()