AniFileBERT / colab_train.py
ModerRAS's picture
colab: onnxε―Όε‡Ίζ”ΉδΈΊιžι˜»ε‘ž+θ‘₯ε…¨onnxscript
e34dc04
raw
history blame
5.26 kB
# -*- coding: utf-8 -*-
"""AniFileBERT — Google Colab Training Script
=============================================
How to use:
1. Open https://colab.research.google.com/
2. File → Upload notebook → select this file, OR
   Copy the entire content into a new code cell
3. Runtime → Change runtime type → T4 GPU
4. Run all
What it does:
- Mounts Google Drive (for persistent checkpoints)
- Clones AniFileBERT repo + AnimeName dataset submodule
- Installs PyTorch + Transformers dependencies
- Runs training: fine-tune from current checkpoint with 8000-token vocab
- Saves final model to Drive
Output:
- Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
- Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-finetune/final/
"""
import os
import sys
import subprocess
import time
def run(cmd: str, echo: bool = True) -> int:
    """Run a shell command and print its combined output as it arrives.

    Args:
        cmd: Command line passed to the shell (``shell=True``), so pipes,
            redirections and ``||`` fallbacks in the string are honoured.
        echo: When true, print the command (prefixed with ``$``) before
            running it.

    Returns:
        The exit code of the command. This is always 0, because any other
        exit code raises instead of returning.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
    """
    if echo:
        print(f"\n$ {cmd}")
    # stderr is folded into stdout so both streams appear in order;
    # text=True with bufsize=1 gives line-buffered reads.
    # The context manager closes the pipe and waits for the child even if
    # printing raises part-way through the output.
    with subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    ) as proc:
        for line in proc.stdout:
            print(line, end="")
    # Leaving the ``with`` block has already waited on the process, so
    # returncode is populated here.
    if proc.returncode != 0:
        raise RuntimeError(f"Command failed (exit code {proc.returncode}): {cmd}")
    return proc.returncode
# ── 1. Mount Google Drive ──────────────────────────────────────
# Drive is mounted so that checkpoints are written to persistent storage.
print("=" * 60, "STEP 1: Mount Google Drive", "=" * 60, sep="\n")

from google.colab import drive

drive.mount("/content/drive")

# Root folder on Drive under which all checkpoints are stored.
DRIVE_ROOT = "/content/drive/MyDrive/AniFileBERT"
os.makedirs(DRIVE_ROOT, exist_ok=True)
print(f"Checkpoints will be saved to: {DRIVE_ROOT}")
# ── 2. Clone repositories ──────────────────────────────────────
print("", "=" * 60, "STEP 2: Clone AniFileBERT repository", "=" * 60, sep="\n")

REPO_DIR = "/content/AniFileBERT"
if os.path.isdir(REPO_DIR):
    # A checkout already exists: update it and its submodules in place.
    print("Repository already exists, pulling latest...")
    os.chdir(REPO_DIR)
    for git_cmd in ("git pull", "git submodule update --init --recursive"):
        run(git_cmd)
else:
    # No checkout yet: clone (including submodules) under /content.
    os.chdir("/content")
    run("git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT")
# The remaining steps use paths relative to the repository root.
os.chdir(REPO_DIR)
# ── 3. Install dependencies ────────────────────────────────────
print("", "=" * 60, "STEP 3: Install dependencies", "=" * 60, sep="\n")

# Colab comes with PyTorch + CUDA pre-installed. Just install the extras.
EXTRA_PACKAGES = (
    "transformers",
    "accelerate",
    "seqeval",
    "onnx",
    "onnxruntime",
    "onnxscript",
)
run("pip install " + " ".join(EXTRA_PACKAGES))
# ── 4. Verify GPU ──────────────────────────────────────────────
print("\n" + "=" * 60)
print("STEP 4: Verify GPU")
print("=" * 60)
# If nvidia-smi is missing or fails, print a notice instead: the `||`
# fallback keeps the overall exit status at zero so run() does not raise.
run("nvidia-smi 2>/dev/null || echo 'No GPU found — training will be slow on CPU'")
# The Python snippet is wrapped in single quotes so the shell passes its
# double quotes and {...} f-string fields through unchanged.
run("python -c 'import torch; print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")'")
# ── 5. Verify vocabulary ───────────────────────────────────────
print("", "=" * 60, "STEP 5: Verify vocabulary", "=" * 60, sep="\n")

# One-liner executed by a child interpreter; it loads vocab.json from the
# current directory (the repository root) and prints its entry count.
# NOTE(review): training is pointed at datasets/AnimeName/vocab.json, not this
# file — confirm that the root-level vocab.json is the one worth checking.
vocab_probe = 'import json; v=json.load(open("vocab.json")); print(f"Vocab size: {len(v)} tokens")'
run("python -c '" + vocab_probe + "'")
# ── 6. Run training ────────────────────────────────────────────
print("", "=" * 60, "STEP 6: Train model", "=" * 60, sep="\n")

# The 8000-token vocab is already in datasets/AnimeName/vocab.json.
# The old checkpoint (3000-token embedding) gets resized automatically.
SAVE_DIR = os.path.join(DRIVE_ROOT, "checkpoints", "dmhy-finetune")

# Each entry is one fragment of the command line; joining with single spaces
# yields the full train.py invocation.
TRAIN_COMMAND_PARTS = [
    "python train.py",
    "--data-file datasets/AnimeName/dmhy_weak.jsonl",
    "--vocab-file datasets/AnimeName/vocab.json",
    f"--save-dir {SAVE_DIR}",
    "--init-model-dir .",
    "--epochs 10 --batch-size 128",
    "--learning-rate 0.0003 --warmup-steps 300",
    "--seed 42",
    "--no-shuffle",
]
run(" ".join(TRAIN_COMMAND_PARTS))
# ── 7. Export ONNX (optional) ──────────────────────────────────
print("\n" + "=" * 60)
print("STEP 7: Export ONNX (optional — skip if it fails)")
print("=" * 60)
# The ONNX file is written next to the run directory. Using dirname() rather
# than joining ".." keeps the path shown to the user free of a ".." segment.
ONNX_OUT = os.path.join(os.path.dirname(SAVE_DIR), "anime_filename_parser.onnx")
try:
    run(
        f"python export_onnx.py "
        f"--model-dir {SAVE_DIR}/final "
        f"--output {ONNX_OUT}"
    )
except Exception as e:
    # Best-effort step: a failed export is reported but must not abort the
    # script, since training has already finished by this point.
    onnx_exported = False
    print(f"[WARN] ONNX export skipped: {e}")
else:
    onnx_exported = True

# ── 8. Summary ─────────────────────────────────────────────────
print("\n" + "=" * 60)
print("DONE!")
print("=" * 60)
print(f"\nCheckpoints: {SAVE_DIR}/")
print(f"Final model: {SAVE_DIR}/final/")
# Only advertise the ONNX file when the export actually succeeded.
if onnx_exported:
    print(f"ONNX export: {ONNX_OUT}")
else:
    print("ONNX export: skipped (export failed, see warning above)")
print("\nAll files are on Google Drive — they persist across Colab sessions.")
print("You can also download them from the Drive web UI.")