File size: 8,324 Bytes
9bde38f
c7fd3ee
 
374cf10
c7fd3ee
70d4b89
9bde38f
c7fd3ee
6fb4c7e
9bde38f
dde34f4
 
ef53e9f
374cf10
6fb4c7e
374cf10
6fb4c7e
c7fd3ee
6fb4c7e
 
dde34f4
6fb4c7e
c7fd3ee
 
374cf10
c7fd3ee
 
 
 
 
 
 
 
 
 
6fb4c7e
 
 
ef53e9f
374cf10
9bde38f
 
 
dde34f4
ef53e9f
c7fd3ee
ef53e9f
9bde38f
374cf10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7fd3ee
 
 
5a15785
9bde38f
c7fd3ee
 
 
 
 
 
5a15785
c7fd3ee
 
 
 
 
 
 
5a15785
c7fd3ee
5a15785
c7fd3ee
 
 
 
 
 
 
 
 
 
 
 
 
5a15785
c7fd3ee
 
 
 
 
9bde38f
ef53e9f
dde34f4
ef53e9f
9bde38f
 
374cf10
dde34f4
9bde38f
 
dde34f4
9bde38f
 
ef53e9f
c7fd3ee
ef53e9f
6fb4c7e
c7fd3ee
 
9bde38f
c7fd3ee
9bde38f
5a15785
374cf10
5a15785
 
 
 
 
 
 
 
 
 
 
 
 
 
ef53e9f
374cf10
ef53e9f
6fb4c7e
c7fd3ee
 
dde34f4
9bde38f
c7fd3ee
dde34f4
c7fd3ee
 
e146c91
 
c7fd3ee
e146c91
c7fd3ee
5a15785
e146c91
 
9bde38f
e146c91
dde34f4
e146c91
 
 
 
c7fd3ee
dde34f4
e146c91
 
 
 
c7fd3ee
70d4b89
9bde38f
 
dde34f4
9bde38f
e146c91
 
 
 
9bde38f
 
374cf10
c7fd3ee
374cf10
dde34f4
 
c7fd3ee
 
dde34f4
 
 
 
c7fd3ee
9bde38f
c7fd3ee
dde34f4
 
374cf10
dde34f4
374cf10
 
c7fd3ee
dde34f4
374cf10
6fb4c7e
c7fd3ee
 
dde34f4
 
c7fd3ee
dde34f4
c7fd3ee
dde34f4
9bde38f
dde34f4
374cf10
c7fd3ee
dde34f4
c7fd3ee
dde34f4
 
 
c7fd3ee
dde34f4
c7fd3ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Dungeon Master LoRA Training - Qwen3.5-9B via Unsloth
======================================================
Resumes from step 200 checkpoint.
Unsloth bf16 LoRA (no 4-bit quantization).
Hardware: L40S 1x (48GB VRAM, $1.80/hr)
"""
import os, sys, time, torch, threading
from http.server import HTTPServer, BaseHTTPRequestHandler

os.environ["PYTHONUNBUFFERED"] = "1"
# PYTHONUNBUFFERED is only read when a Python interpreter starts, so the line
# above only affects Python subprocesses launched later, not this process.
# Switch this process's stdout to line buffering explicitly so output from code
# that does not pass flush=True still appears promptly in the log.
try:
    sys.stdout.reconfigure(line_buffering=True)
except (AttributeError, ValueError, OSError):
    pass  # best effort: a replaced/non-standard stdout may not support this

# ============================================================
# Health check server on port 7860
# ============================================================
# Shared progress state: written by the training code below and read by the
# health-check handler thread (class H) to render the status page.
#   stage       - human-readable phase label
#   step/total  - current / total optimizer steps (starts at 200: resumed run)
#   loss        - most recently logged training loss
#   t           - time.time() at startup, used for the "Elapsed" display
STATUS = {"stage": "starting", "step": 200, "total": 2563, "loss": 0.0, "t": time.time()}

class H(BaseHTTPRequestHandler):
    """Minimal HTTP handler that renders the module-level STATUS dict as HTML.

    Every GET (any path) returns the same status page; HEAD returns the same
    headers without a body.
    """

    def _status_page(self):
        """Return the UTF-8 encoded status page built from STATUS."""
        m = int(time.time() - STATUS["t"]) // 60
        return f"""<html><body style="font-family:monospace;padding:20px">
        <h2>DM LoRA Training (resuming from step 200)</h2>
        <p>Stage: {STATUS['stage']}</p>
        <p>Step: {STATUS['step']}/{STATUS['total']}</p>
        <p>Loss: {STATUS['loss']:.4f}</p>
        <p>Elapsed: {m} min</p>
        </body></html>""".encode("utf-8")

    def _send_status_headers(self, body):
        """Send a 200 response line plus content headers for *body*."""
        self.send_response(200)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()

    def do_GET(self):
        # Render before sending anything so a formatting error cannot leave a
        # half-written "200 OK" response on the wire.
        body = self._status_page()
        self._send_status_headers(body)
        self.wfile.write(body)

    def do_HEAD(self):
        # Without this, BaseHTTPRequestHandler answers HEAD probes with 501.
        self._send_status_headers(self._status_page())

    def log_message(self, *a):
        # Silence per-request access logs so they don't clutter the training log.
        pass

# Serve the status page from a background daemon thread: it never keeps the
# interpreter alive by itself, and the main thread moves on to training.
srv = HTTPServer(("0.0.0.0", 7860), H)
health_thread = threading.Thread(target=srv.serve_forever, daemon=True)
health_thread.start()
print("Health check server on :7860", flush=True)

# ============================================================
# Auth
# ============================================================
from huggingface_hub import login, snapshot_download

# A token is mandatory: later steps download the checkpoint and push to the Hub.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("ERROR: No HF_TOKEN!", flush=True)
    sys.exit(1)

login(token=hf_token)
print("Logged in to HF Hub", flush=True)

# ============================================================
# Download checkpoint from Hub to resume
# ============================================================
STATUS["stage"] = "downloading checkpoint from Hub"
OUTPUT_REPO = "zprime/qwen3.5-9b-dungeon-master-lora"
CHECKPOINT_DIR = "/tmp/dm-lora/checkpoint-200"

print("Downloading step-200 checkpoint from Hub...", flush=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

import json
import shutil

# Fetch only the `last-checkpoint/` folder of the output repo.
# NOTE: with hub_strategy="checkpoint" (training args below) the Trainer
# re-pushes `last-checkpoint/` on saves, so on a re-run this can be a later
# step than 200; the actual step is read from trainer_state.json further down.
hub_dir = snapshot_download(
    repo_id=OUTPUT_REPO,
    allow_patterns="last-checkpoint/*",
    local_dir="/tmp/hub-checkpoint",
)
src = os.path.join(hub_dir, "last-checkpoint")
if not os.path.isdir(src) or not os.listdir(src):
    # Fail with a clear message instead of a bare FileNotFoundError here or an
    # opaque "no valid checkpoint" error from trainer.train() later.
    print(f"ERROR: no last-checkpoint/ files found in {OUTPUT_REPO}", flush=True)
    sys.exit(1)

# Move the files into CHECKPOINT_DIR. A stale directory with the same name is
# removed first, because shutil.move() into an existing directory would nest
# the source inside it instead of replacing it.
for name in os.listdir(src):
    dst = os.path.join(CHECKPOINT_DIR, name)
    if os.path.isdir(dst):
        shutil.rmtree(dst)
    shutil.move(os.path.join(src, name), dst)
print(f"Checkpoint downloaded to {CHECKPOINT_DIR}", flush=True)
print(f"Files: {os.listdir(CHECKPOINT_DIR)}", flush=True)

# Report the step actually stored in the checkpoint (if the Trainer state file
# is present) so the log and status page don't rely on the hard-coded 200.
state_path = os.path.join(CHECKPOINT_DIR, "trainer_state.json")
if os.path.isfile(state_path):
    with open(state_path, encoding="utf-8") as fh:
        resume_step = json.load(fh).get("global_step")
    if isinstance(resume_step, int):
        STATUS["step"] = resume_step
        print(f"Checkpoint global_step: {resume_step}", flush=True)

# ============================================================
# Config
# ============================================================
MAX_SEQ_LENGTH = 2048
DATASET_ID = "chimbiwide/RolePlay-NPC-Quest"
MODEL_NAME = "unsloth/Qwen3.5-9B"

# ============================================================
# Load model via Unsloth
# ============================================================
STATUS["stage"] = "loading model via Unsloth"
print(f"Loading {MODEL_NAME} via Unsloth (bf16)...", flush=True)

from unsloth import FastLanguageModel

# bf16 base weights; 4-bit quantization is explicitly turned off.
model_load_kwargs = {
    "model_name": MODEL_NAME,
    "max_seq_length": MAX_SEQ_LENGTH,
    "dtype": torch.bfloat16,
    "load_in_4bit": False,
}
model, tokenizer = FastLanguageModel.from_pretrained(**model_load_kwargs)
print("Model loaded via Unsloth", flush=True)

# ============================================================
# Add LoRA via Unsloth
# ============================================================
STATUS["stage"] = "adding LoRA"

# Linear layers (matched by module name) that receive LoRA adapters.
# NOTE(review): these settings presumably must match the adapter stored in the
# resumed checkpoint; confirm before changing any of them.
lora_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    target_modules=lora_targets,
    use_gradient_checkpointing=True,
    random_state=42,
)
print("LoRA added: r=16, alpha=32", flush=True)

# ============================================================
# Trackio
# ============================================================
# Optional experiment tracking: fall back to no reporting if trackio is
# unavailable or fails to initialise.
try:
    import trackio
    trackio.init(name="dm-lora-resume-200", project=OUTPUT_REPO)
except Exception as e:
    print(f"Trackio warning: {e}", flush=True)
    REPORT_TO = "none"
else:
    print("Trackio enabled", flush=True)
    REPORT_TO = "trackio"

# ============================================================
# Load dataset
# ============================================================
STATUS["stage"] = "loading dataset"
print(f"Loading dataset: {DATASET_ID}", flush=True)
from datasets import load_dataset

# Only the "train" split is loaded.
dataset = load_dataset(DATASET_ID, split="train")
n_train = len(dataset)
print(f"Dataset: {n_train} examples", flush=True)

# ============================================================
# Formatting function
# ============================================================
def formatting_func(examples):
    """Render each conversation in a batch to chat-template text.

    Args:
        examples: a batch as passed by ``Dataset.map(batched=True)``; its
            ``"messages"`` column holds one conversation per example.

    Returns:
        ``{"text": [...]}`` with one untokenized string per conversation,
        in input order, produced by the module-level ``tokenizer``.
    """
    rendered = [
        tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )
        for conversation in examples["messages"]
    ]
    return {"text": rendered}

print("Formatting dataset with chat template...", flush=True)
# Replace the raw "messages" column with the rendered "text" column.
dataset = dataset.map(
    formatting_func,
    batched=True,
    remove_columns=["messages"],
)
print(f"Dataset formatted: {len(dataset)} examples", flush=True)

# ============================================================
# Training config — same as before so resume works
# ============================================================
STATUS["stage"] = "initializing trainer"
from trl import SFTConfig, SFTTrainer
from transformers import TrainerCallback

training_args = SFTConfig(
    # --- output / reproducibility ---
    output_dir="/tmp/dm-lora",
    seed=42,
    # --- optimisation schedule ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # 1 x 8 = 8 sequences per optimizer step per device
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.01,
    optim="adamw_8bit",
    # --- data / precision / memory ---
    max_length=MAX_SEQ_LENGTH,
    dataset_text_field="text",
    dataloader_num_workers=2,
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # --- logging ---
    logging_strategy="steps",
    logging_steps=5,
    logging_first_step=True,
    disable_tqdm=True,
    report_to=REPORT_TO,
    # --- checkpointing / Hub upload ---
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_strategy="checkpoint",
)

print("Initializing SFTTrainer...", flush=True)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dataset,
)

# Hard-coded optimizer-step count for the whole run; used for the status page
# and the per-log progress line.
total_steps = 2563
STATUS["total"] = total_steps
print(f"Resuming training from step 200 / {total_steps}", flush=True)
print("=" * 60, flush=True)

# Status callback
class SC(TrainerCallback):
    """Trainer callback that mirrors log events into STATUS and stdout."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record the current step/loss in STATUS and print a progress line."""
        if not logs:
            return
        STATUS["step"] = state.global_step
        # Only per-step training logs carry "loss". Other log events (e.g. the
        # end-of-training summary, which reports "train_loss") must not reset
        # the displayed loss to 0.0.
        if "loss" in logs:
            STATUS["loss"] = logs["loss"]
        print(
            f"[Step {state.global_step}/{total_steps}] "
            f"loss={logs.get('loss','?')}, lr={logs.get('learning_rate','?')}",
            flush=True,
        )

trainer.add_callback(SC())

# ============================================================
# Resume training from checkpoint
# ============================================================
STATUS["stage"] = "training (resumed from step 200)"
print(f"Resuming from {CHECKPOINT_DIR}...", flush=True)

train_started_at = time.time()
trainer.train(resume_from_checkpoint=CHECKPOINT_DIR)
train_minutes = (time.time() - train_started_at) / 60
print(f"Training done in {train_minutes:.1f} min!", flush=True)

# ============================================================
# Save & push
# ============================================================
STATUS["stage"] = "saving"
print("Saving final model...", flush=True)
trainer.save_model()
print("Pushing to Hub...", flush=True)
trainer.push_to_hub(commit_message="Dungeon Master LoRA - FINAL - Unsloth bf16 r=16")
print(f"DONE! https://huggingface.co/{OUTPUT_REPO}", flush=True)

STATUS["stage"] = "COMPLETE - SET HARDWARE TO CPU!"
print("=" * 60, flush=True)
print("TRAINING COMPLETE!", flush=True)
print(f"Adapter: https://huggingface.co/{OUTPUT_REPO}", flush=True)
print("GO TO SETTINGS -> SET HARDWARE TO CPU TO STOP BILLING!", flush=True)
print("=" * 60, flush=True)

# Keep the process alive so the status page stays reachable (if the script
# exited, the hosting platform may restart it and re-run training). `srv` is
# already being served by the daemon thread started at the top of the file;
# calling srv.serve_forever() again here would run a second accept loop on the
# same server, which socketserver is not designed for. Just park the main
# thread instead and let the daemon thread keep serving.
threading.Event().wait()