Spaces:
Runtime error
Runtime error
"""
Dungeon Master LoRA Training - Qwen3.5-9B via Unsloth
======================================================
Resumes from step 200 checkpoint.
Unsloth bf16 LoRA (no 4-bit quantization).
Hardware: L40S 1x (48GB VRAM, $1.80/hr)

A small HTML status page is served on port 7860 while training runs.
"""
import os
import sys
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer

import torch

# PYTHONUNBUFFERED only affects Python interpreters started after this point
# (child processes); this interpreter configured its own streams at startup.
# Also switch this process's stdout to line buffering so output printed by
# libraries (without flush=True) shows up promptly in the container log.
os.environ["PYTHONUNBUFFERED"] = "1"
try:
    sys.stdout.reconfigure(line_buffering=True)
except (AttributeError, ValueError):
    # Best-effort: stdout may have been replaced by an object without
    # reconfigure(), or detached. Our own print() calls pass flush=True anyway.
    pass
# ============================================================
# Health check server on port 7860
# ============================================================
# Shared progress record: updated by the main (training) code below and read
# by the HTTP handler when it renders the status page.
STATUS = {
    "stage": "starting",
    "step": 200,  # training resumes from the step-200 checkpoint
    "total": 2563,
    "loss": 0.0,
    "t": time.time(),  # script start time, used for the "Elapsed" line
}


class H(BaseHTTPRequestHandler):
    """Answer every GET with a small HTML page summarising STATUS."""

    def do_GET(self):
        self.send_response(200)
        self.send_header("Content-Type", "text/html")
        self.end_headers()
        elapsed_min = int(time.time() - STATUS["t"]) // 60
        page_lines = [
            '<html><body style="font-family:monospace;padding:20px">',
            "<h2>DM LoRA Training (resuming from step 200)</h2>",
            f"<p>Stage: {STATUS['stage']}</p>",
            f"<p>Step: {STATUS['step']}/{STATUS['total']}</p>",
            f"<p>Loss: {STATUS['loss']:.4f}</p>",
            f"<p>Elapsed: {elapsed_min} min</p>",
            "</body></html>",
        ]
        self.wfile.write("\n".join(page_lines).encode())

    def log_message(self, *args):
        # Suppress the default per-request access log line.
        return None
# Serve the status page from a daemon thread so the main thread can go on to
# training; a daemon thread does not keep the interpreter alive by itself.
srv = HTTPServer(("0.0.0.0", 7860), H)
_health_thread = threading.Thread(target=srv.serve_forever, daemon=True)
_health_thread.start()
print("Health check server on :7860", flush=True)
# ============================================================
# Auth
# ============================================================
# snapshot_download is imported here for the checkpoint download step below.
from huggingface_hub import login, snapshot_download

hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    # Hub access (checkpoint download, pushes) depends on this token; fail fast.
    print("ERROR: No HF_TOKEN!", flush=True)
    sys.exit(1)
login(token=hf_token)
print("Logged in to HF Hub", flush=True)
# ============================================================
# Download checkpoint from Hub to resume
# ============================================================
import shutil

STATUS["stage"] = "downloading checkpoint from Hub"
OUTPUT_REPO = "zprime/qwen3.5-9b-dungeon-master-lora"
CHECKPOINT_DIR = "/tmp/dm-lora/checkpoint-200"
print("Downloading step-200 checkpoint from Hub...", flush=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# Fetch only the "last-checkpoint/" folder of the output repo.
snapshot_download(
    repo_id=OUTPUT_REPO,
    allow_patterns="last-checkpoint/*",
    local_dir="/tmp/hub-checkpoint",
)
src = "/tmp/hub-checkpoint/last-checkpoint"
if not os.path.isdir(src):
    # allow_patterns matched nothing: the repo has no pushed checkpoint.
    print(f"ERROR: no last-checkpoint/ folder found in {OUTPUT_REPO}", flush=True)
    sys.exit(1)
# Move the files into checkpoint-200 under the trainer's output_dir.
for name in os.listdir(src):
    shutil.move(os.path.join(src, name), os.path.join(CHECKPOINT_DIR, name))
files = os.listdir(CHECKPOINT_DIR)
print(f"Checkpoint downloaded to {CHECKPOINT_DIR}", flush=True)
print(f"Files: {files}", flush=True)
# transformers' Trainer restores global_step (and the epoch/LR-schedule
# position) only from trainer_state.json; without it a "resume" would
# silently start counting from step 0 again.
if "trainer_state.json" not in files:
    print("ERROR: checkpoint is missing trainer_state.json; cannot resume", flush=True)
    sys.exit(1)
# ============================================================
# Config
# ============================================================
MODEL_NAME: str = "unsloth/Qwen3.5-9B"  # base model loaded via Unsloth
DATASET_ID: str = "chimbiwide/RolePlay-NPC-Quest"  # HF dataset (train split)
MAX_SEQ_LENGTH: int = 2048  # shared by model loading and SFTConfig(max_length)
# ============================================================
# Load model via Unsloth
# ============================================================
STATUS["stage"] = "loading model via Unsloth"
print(f"Loading {MODEL_NAME} via Unsloth (bf16)...", flush=True)
# Unsloth's guidance is to import it before transformers/trl/peft so its
# patches apply; trl and transformers are imported further down.
from unsloth import FastLanguageModel

# Full bf16 weights; 4-bit quantization is explicitly disabled.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    dtype=torch.bfloat16,
    load_in_4bit=False,
    max_seq_length=MAX_SEQ_LENGTH,
)
print("Model loaded via Unsloth", flush=True)
# ============================================================
# Add LoRA via Unsloth
# ============================================================
STATUS["stage"] = "adding LoRA"
# Keep these in sync with the run that produced the step-200 checkpoint:
# its saved adapter weights are reloaded when training resumes.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    # q/k/v/o and gate/up/down projection modules.
    target_modules=[f"{proj}_proj" for proj in ("q", "k", "v", "o", "gate", "up", "down")],
    use_gradient_checkpointing=True,
    random_state=42,
)
print("LoRA added: r=16, alpha=32", flush=True)
# ============================================================
# Trackio
# ============================================================
# Experiment tracking is optional: if trackio is missing or fails to start,
# training continues with reporting disabled.
REPORT_TO = "none"
try:
    import trackio
    trackio.init(name="dm-lora-resume-200", project=OUTPUT_REPO)
    print("Trackio enabled", flush=True)
    REPORT_TO = "trackio"
except Exception as exc:
    print(f"Trackio warning: {exc}", flush=True)
# ============================================================
# Load dataset
# ============================================================
STATUS["stage"] = "loading dataset"
print(f"Loading dataset: {DATASET_ID}", flush=True)
from datasets import load_dataset

dataset = load_dataset(DATASET_ID, split="train")  # only the train split is used
num_rows = len(dataset)
print(f"Dataset: {num_rows} examples", flush=True)
# ============================================================
# Formatting function
# ============================================================
def formatting_func(examples):
    """Render every conversation in a batch to one chat-template string.

    Args:
        examples: a batched dataset slice with a ``"messages"`` column; each
            row is passed as-is to ``tokenizer.apply_chat_template`` (assumed
            to be a list of chat-message dicts - confirm against the dataset).

    Returns:
        ``{"text": [...]}`` with one untokenized string per row, rendered
        without a trailing generation prompt.
    """
    return {
        "text": [
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
            for conversation in examples["messages"]
        ]
    }
print("Formatting dataset with chat template...", flush=True)
# Replace the raw "messages" column with the rendered "text" column that
# SFTConfig(dataset_text_field="text") trains on.
dataset = dataset.map(
    formatting_func,
    remove_columns=["messages"],
    batched=True,
)
print(f"Dataset formatted: {len(dataset)} examples", flush=True)
# ============================================================
# Training config - same as before so resume works
# ============================================================
STATUS["stage"] = "initializing trainer"
from trl import SFTConfig, SFTTrainer
from transformers import TrainerCallback

training_args = SFTConfig(
    # Output / checkpointing. checkpoint-200 was placed inside output_dir, and
    # hub_strategy="checkpoint" keeps the latest checkpoint in the repo's
    # last-checkpoint/ folder, which is what the download step restores from.
    output_dir="/tmp/dm-lora",
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_strategy="checkpoint",
    # Optimization: 1 sequence x 8 accumulation steps per optimizer step.
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.01,
    optim="adamw_8bit",
    seed=42,
    # Data, precision and memory.
    max_length=MAX_SEQ_LENGTH,
    dataset_text_field="text",
    dataloader_num_workers=2,
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Logging.
    logging_strategy="steps",
    logging_steps=5,
    logging_first_step=True,
    disable_tqdm=True,
    report_to=REPORT_TO,
)
print("Initializing SFTTrainer...", flush=True)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dataset,
)
# Hard-coded optimizer-step count for this dataset/config; it is only used
# for progress display (status page and log lines).
total_steps = 2563
STATUS["total"] = total_steps
print(f"Resuming training from step 200 / {total_steps}", flush=True)
print("=" * 60, flush=True)
# Status callback
class SC(TrainerCallback):
    """Mirror Trainer log events into STATUS and print a one-line progress summary."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        STATUS["step"] = state.global_step
        # Not every log event carries a training loss (e.g. the end-of-training
        # summary logs train_loss/train_runtime instead). Keep the last reported
        # value rather than resetting the status page to 0.0.
        if "loss" in logs:
            STATUS["loss"] = logs["loss"]
        print(
            f"[Step {state.global_step}/{total_steps}] "
            f"loss={logs.get('loss','?')}, lr={logs.get('learning_rate','?')}",
            flush=True,
        )


trainer.add_callback(SC())
# ============================================================
# Resume training from checkpoint
# ============================================================
STATUS["stage"] = "training (resumed from step 200)"
print(f"Resuming from {CHECKPOINT_DIR}...", flush=True)
# Passing a checkpoint path makes Trainer reload the saved model, optimizer,
# scheduler and trainer state before continuing.
started = time.time()
trainer.train(resume_from_checkpoint=CHECKPOINT_DIR)
elapsed_min = (time.time() - started) / 60
print(f"Training done in {elapsed_min:.1f} min!", flush=True)
# ============================================================
# Save & push
# ============================================================
STATUS["stage"] = "saving"
print("Saving final model...", flush=True)
trainer.save_model()
print("Pushing to Hub...", flush=True)
trainer.push_to_hub(commit_message="Dungeon Master LoRA - FINAL - Unsloth bf16 r=16")
print(f"DONE! https://huggingface.co/{OUTPUT_REPO}", flush=True)
STATUS["stage"] = "COMPLETE - SET HARDWARE TO CPU!"
print("=" * 60, flush=True)
print("TRAINING COMPLETE!", flush=True)
print(f"Adapter: https://huggingface.co/{OUTPUT_REPO}", flush=True)
print("GO TO SETTINGS -> SET HARDWARE TO CPU TO STOP BILLING!", flush=True)
print("=" * 60, flush=True)
# Keep the process alive so the daemon health-check thread keeps serving the
# status page. Block on an Event that is never set instead of calling
# srv.serve_forever() a second time: that thread is already running
# serve_forever() on the same single-threaded HTTPServer, and two such loops
# would race on one listening socket.
threading.Event().wait()