RayMelius Claude Sonnet 4.6 commited on
Commit
ff6464a
Β·
1 Parent(s): ebf88a6

Add resumable training with Google Drive + HF Hub checkpoints

Browse files

- Mounts Google Drive; checkpoints saved to MyDrive/stockex-ch-checkpoints
- Auto-detects latest checkpoint and resumes with resume_from_checkpoint
- CheckpointSyncCallback: on every save β†’ copies to Drive + pushes adapter to HF Hub
- save_total_limit=3 to keep local disk usage bounded
- Falls back gracefully if Drive not available

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. notebooks/ch_trader_finetune.ipynb +15 -1
notebooks/ch_trader_finetune.ipynb CHANGED
@@ -76,6 +76,20 @@
76
  "- Target: a valid JSON trading decision that respects all constraints"
77
  ]
78
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  {
80
  "cell_type": "code",
81
  "execution_count": null,
@@ -398,7 +412,7 @@
398
  "id": "train",
399
  "metadata": {},
400
  "outputs": [],
401
- "source": "sft_config = SFTConfig(\n output_dir=OUTPUT_DIR,\n num_train_epochs=NUM_EPOCHS,\n per_device_train_batch_size=BATCH_SIZE,\n per_device_eval_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n gradient_checkpointing=True,\n optim=\"paged_adamw_32bit\",\n learning_rate=LR,\n lr_scheduler_type=\"cosine\",\n warmup_ratio=0.05,\n fp16=not torch.cuda.is_bf16_supported(),\n bf16=torch.cuda.is_bf16_supported(),\n logging_steps=25,\n eval_strategy=\"steps\",\n eval_steps=100,\n save_strategy=\"steps\",\n save_steps=100,\n load_best_model_at_end=True,\n metric_for_best_model=\"eval_loss\",\n greater_is_better=False,\n report_to=\"none\",\n dataset_text_field=\"text\",\n packing=False,\n)\n\ntrainer = SFTTrainer(\n model=model,\n args=sft_config,\n train_dataset=train_dataset,\n eval_dataset=val_dataset,\n peft_config=lora_config,\n processing_class=tokenizer,\n)\n\nprint(\"Starting training...\")\ntrainer.train()\nprint(\"Training complete.\")"
402
  },
403
  {
404
  "cell_type": "markdown",
 
76
  "- Target: a valid JSON trading decision that respects all constraints"
77
  ]
78
  },
79
+ {
80
+ "cell_type": "markdown",
81
+ "id": "g7cean6ejyj",
82
+ "source": "## 1b. Checkpoint Setup (Resumable Training)\n\nCheckpoints are saved to **Google Drive** so they survive Colab session disconnects.\n- If a checkpoint is found, training resumes automatically from where it left off.\n- Adapters are also pushed to HF Hub every `SAVE_STEPS` steps as a backup.",
83
+ "metadata": {}
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "id": "egq0dp9csuo",
88
+ "source": "import shutil\nfrom transformers.trainer_utils import get_last_checkpoint\n\n# ── Mount Google Drive for persistent checkpoint storage ──────────────────────\ntry:\n from google.colab import drive\n drive.mount(\"/content/drive\", force_remount=False)\n DRIVE_CKPT_DIR = f\"/content/drive/MyDrive/stockex-ch-checkpoints\"\n USE_DRIVE = True\n print(f\"Google Drive mounted. Checkpoints β†’ {DRIVE_CKPT_DIR}\")\nexcept Exception:\n DRIVE_CKPT_DIR = None\n USE_DRIVE = False\n print(\"Google Drive not available β€” checkpoints saved locally only.\")\n\n# OUTPUT_DIR is the local working copy; DRIVE_CKPT_DIR is the persistent store\nos.makedirs(OUTPUT_DIR, exist_ok=True)\nif USE_DRIVE:\n os.makedirs(DRIVE_CKPT_DIR, exist_ok=True)\n # Restore latest checkpoint from Drive to local dir (if any)\n drive_ckpt = get_last_checkpoint(DRIVE_CKPT_DIR)\n if drive_ckpt:\n local_ckpt_name = os.path.basename(drive_ckpt)\n local_ckpt_path = os.path.join(OUTPUT_DIR, local_ckpt_name)\n if not os.path.exists(local_ckpt_path):\n print(f\"Restoring checkpoint from Drive: {drive_ckpt}\")\n shutil.copytree(drive_ckpt, local_ckpt_path)\n else:\n print(f\"Checkpoint already in local dir: {local_ckpt_path}\")\n\n# ── Detect latest local checkpoint ────────────────────────────────────────────\nRESUME_FROM = get_last_checkpoint(OUTPUT_DIR)\nif RESUME_FROM:\n print(f\"βœ“ Will resume training from: {RESUME_FROM}\")\nelse:\n print(\"No checkpoint found β€” starting fresh.\")\n\nSAVE_STEPS = 50 # save every N steps (also pushed to Drive + HF Hub)",
89
+ "metadata": {},
90
+ "execution_count": null,
91
+ "outputs": []
92
+ },
93
  {
94
  "cell_type": "code",
95
  "execution_count": null,
 
412
  "id": "train",
413
  "metadata": {},
414
  "outputs": [],
415
+ "source": "from transformers import TrainerCallback\n\nclass CheckpointSyncCallback(TrainerCallback):\n \"\"\"After every checkpoint: copy to Google Drive and push adapter to HF Hub.\"\"\"\n\n def on_save(self, args, state, control, **kwargs):\n ckpt_dir = os.path.join(args.output_dir, f\"checkpoint-{state.global_step}\")\n if not os.path.isdir(ckpt_dir):\n return\n\n # 1. Sync to Google Drive\n if USE_DRIVE and DRIVE_CKPT_DIR:\n dest = os.path.join(DRIVE_CKPT_DIR, f\"checkpoint-{state.global_step}\")\n if not os.path.exists(dest):\n shutil.copytree(ckpt_dir, dest)\n print(f\"[Checkpoint] Saved to Drive: {dest}\")\n\n # 2. Push adapter to HF Hub (lightweight β€” only LoRA weights)\n try:\n kwargs[\"model\"].push_to_hub(\n OUTPUT_REPO,\n commit_message=f\"Checkpoint step {state.global_step} \"\n f\"(epoch {state.epoch:.2f})\",\n token=HF_TOKEN,\n )\n print(f\"[Checkpoint] Pushed adapter step {state.global_step} β†’ HF Hub\")\n except Exception as e:\n print(f\"[Checkpoint] HF Hub push failed (non-fatal): {e}\")\n\n\nsft_config = SFTConfig(\n output_dir=OUTPUT_DIR,\n num_train_epochs=NUM_EPOCHS,\n per_device_train_batch_size=BATCH_SIZE,\n per_device_eval_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n gradient_checkpointing=True,\n optim=\"paged_adamw_32bit\",\n learning_rate=LR,\n lr_scheduler_type=\"cosine\",\n warmup_ratio=0.05,\n fp16=not torch.cuda.is_bf16_supported(),\n bf16=torch.cuda.is_bf16_supported(),\n logging_steps=10,\n eval_strategy=\"steps\",\n eval_steps=SAVE_STEPS,\n save_strategy=\"steps\",\n save_steps=SAVE_STEPS,\n save_total_limit=3, # keep only 3 latest local checkpoints\n load_best_model_at_end=True,\n metric_for_best_model=\"eval_loss\",\n greater_is_better=False,\n report_to=\"none\",\n dataset_text_field=\"text\",\n packing=False,\n)\n\ntrainer = SFTTrainer(\n model=model,\n args=sft_config,\n train_dataset=train_dataset,\n eval_dataset=val_dataset,\n peft_config=lora_config,\n processing_class=tokenizer,\n callbacks=[CheckpointSyncCallback()],\n)\n\nif RESUME_FROM:\n print(f\"β–Ά Resuming from checkpoint: {RESUME_FROM}\")\nelse:\n print(\"β–Ά Starting training from scratch\")\n\ntrainer.train(resume_from_checkpoint=RESUME_FROM)\nprint(\"βœ“ Training complete.\")"
416
  },
417
  {
418
  "cell_type": "markdown",