Add Codex Colab training workflow
- AGENTS.md +60 -0
- README.md +11 -3
- colab/README.md +75 -0
- colab/configs/dmhy_char_train.json +42 -0
- colab/configs/dmhy_regex_finetune.json +42 -0
- colab/start_worker.ipynb +45 -0
- colab_client.py +184 -0
- colab_train.py +526 -122
- colab_worker.py +446 -0
- train.py +165 -17
AGENTS.md
CHANGED
@@ -67,6 +67,66 @@ Export for Android:
 python export_onnx.py --model-dir checkpoints/dmhy-finetune/final --android-assets-dir ../../scraper/src/main/assets/anime_parser
 ```
 
+## Codex-Controlled Colab Training
+
+Free Colab cannot be treated as an always-on remote machine. Use it as a
+short-lived GPU worker only after the user manually opens a Colab runtime and
+starts the worker cell. Do not assume Codex can wake Colab by itself.
+
+Before relying on the Colab flow, make sure the Colab helper files have been
+pushed to the Hugging Face model repo, or the user has uploaded them manually:
+`colab_worker.py`, `colab_client.py`, `colab_train.py`, and `colab/`.
+
+Ask the user to start a Colab GPU runtime with:
+
+```python
+from google.colab import drive
+drive.mount("/content/drive")
+
+!git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT /content/AniFileBERT || true
+%cd /content/AniFileBERT
+!git pull --ff-only || true
+!git submodule update --init --recursive
+!python colab_worker.py
+```
+
+The worker prints `COLAB_WORKER_URL=...` and `COLAB_WORKER_TOKEN=...`. After
+the user provides those values, set them for local commands:
+
+```powershell
+$env:ANIFILEBERT_COLAB_URL="https://...trycloudflare.com"
+$env:ANIFILEBERT_COLAB_TOKEN="..."
+python colab_client.py health
+```
+
+Submit the default regex fine-tune:
+
+```powershell
+python colab_client.py submit --profile dmhy_regex_finetune --wait
+```
+
+Submit the character tokenizer run only when intentional:
+
+```powershell
+python colab_client.py submit --profile dmhy_char_train --wait
+```
+
+Useful follow-up commands:
+
+```powershell
+python colab_client.py jobs
+python colab_client.py status <job-id>
+python colab_client.py logs <job-id> --tail 200
+python colab_client.py manifest <job-id>
+python colab_client.py cancel <job-id>
+```
+
+The default Colab profiles save checkpoints to Google Drive every 1000 steps
+and resume with `resume_from_checkpoint: "auto"`, so if free Colab disconnects,
+ask the user to restart the worker and submit the same profile again. Artifacts
+land under `MyDrive/AniFileBERT/checkpoints/<profile-name>/`, and worker logs
+land under `MyDrive/AniFileBERT/worker/jobs/<job-id>/`.
+
 ## Validation Expectations
 
 - For parser or tokenizer changes, run `python inference.py --model-dir . ...`
|
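The `resume_from_checkpoint: "auto"` behaviour described in the AGENTS.md hunk depends on finding the newest checkpoint under the profile's save directory. This diff does not show how `train.py` does that, so the following is only a minimal sketch. It assumes checkpoints are stored in `checkpoint-<step>` subdirectories (the naming used by the Hugging Face `Trainer`), and the helper name `find_latest_checkpoint` is hypothetical.

```python
import re
from pathlib import Path

# Assumption: checkpoint directories are named "checkpoint-<global step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(save_dir):
    """Return the checkpoint directory with the highest step, or None."""
    root = Path(save_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        # Compare steps numerically so that 10000 sorts after 2000.
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path
```

Comparing the step as an integer matters here: a plain string sort would rank `checkpoint-2000` above `checkpoint-10000`.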
README.md
CHANGED
@@ -199,9 +199,17 @@ python export_onnx.py --model-dir checkpoints/dmhy-finetune/final --output expor
 
 ## Google Colab Training
 
-
-
-
+For Codex-controlled short Colab sessions, see [`colab/README.md`](colab/README.md).
+Free Colab still has to be started manually, but once `colab_worker.py` is
+running Codex can submit jobs through `colab_client.py`, tail logs, and inspect
+status. Checkpoints live on Google Drive and default profiles resume from the
+latest checkpoint automatically.
+
+Manual one-shot runs are also supported:
+
+```bash
+python colab_train.py --profile dmhy_regex_finetune
+```
 
 ## Repository Layout
 
|
colab/README.md
ADDED
@@ -0,0 +1,75 @@
+# Codex + Colab Training
+
+Free Colab cannot be used as an always-on remote machine. The practical setup is:
+
+1. Open a Colab GPU runtime when you want to train.
+2. Start the lightweight worker in one cell.
+3. Give Codex the printed worker URL and token.
+4. Codex submits jobs while that Colab session is alive.
+5. Checkpoints and manifests stay on Google Drive, so the next session can resume.
+
+## Start a Colab Session
+
+Run this in a Colab code cell:
+
+```python
+from google.colab import drive
+drive.mount("/content/drive")
+
+!git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT /content/AniFileBERT || true
+%cd /content/AniFileBERT
+!git pull --ff-only || true
+!git submodule update --init --recursive
+!python colab_worker.py
+```
+
+The cell prints:
+
+```text
+COLAB_WORKER_URL=https://...trycloudflare.com
+COLAB_WORKER_TOKEN=...
+```
+
+Keep that cell running. If Colab disconnects, start it again; default profiles
+save every 1000 steps and resume from the latest Drive checkpoint because they
+use `checkpoint_steps: 1000` and `resume_from_checkpoint: "auto"`.
+
+## Let Codex Submit a Job
+
+On the local machine:
+
+```powershell
+$env:ANIFILEBERT_COLAB_URL="https://...trycloudflare.com"
+$env:ANIFILEBERT_COLAB_TOKEN="..."
+python colab_client.py health
+python colab_client.py submit --profile dmhy_regex_finetune --wait
+```
+
+Codex can run the same commands from this repository after you provide the URL
+and token.
+
+## Profiles
+
+- `colab/configs/dmhy_regex_finetune.json`: default regex tokenizer fine-tune
+  from the published root checkpoint.
+- `colab/configs/dmhy_char_train.json`: character tokenizer training from
+  scratch.
+
+You can submit a local edited profile instead of a remote profile:
+
+```powershell
+python colab_client.py submit --config colab/configs/dmhy_regex_finetune.json --wait
+```
+
+The worker writes per-job logs under:
+
+```text
+MyDrive/AniFileBERT/worker/jobs/<job-id>/
+```
+
+The training runner writes:
+
+```text
+MyDrive/AniFileBERT/checkpoints/<profile-name>/
+MyDrive/AniFileBERT/last_run_manifest.json
+```
|
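Submitting a locally edited profile, as colab/README.md describes, would typically start from a copy of an existing profile. The sketch below derives a short trial variant and writes it to a new JSON file that could then be passed to `colab_client.py submit --config`. The names `make_trial_profile` and `write_profile`, the `-trial` suffix, and the choice of override are illustrative assumptions. `limit_samples` does appear in `DEFAULT_CONFIG` in `colab_train.py`, but its exact effect is not visible in this diff.

```python
import copy
import json
from pathlib import Path


def make_trial_profile(profile, limit_samples=2000):
    """Return a copy of a profile dict adjusted for a short trial run."""
    trial = copy.deepcopy(profile)
    # Assumption: save_dir is derived from "name", so renaming the profile
    # keeps trial checkpoints apart from the real run's checkpoints.
    trial["name"] = profile["name"] + "-trial"
    trial.setdefault("training", {})["limit_samples"] = limit_samples
    return trial


def write_profile(profile, path):
    """Write the profile as UTF-8 JSON and return the output path."""
    out = Path(path)
    out.write_text(json.dumps(profile, ensure_ascii=False, indent=2), encoding="utf-8")
    return out
```

Working on a deep copy leaves the original profile dict untouched, so the same base profile can be reused for several variants.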
colab/configs/dmhy_char_train.json
ADDED
@@ -0,0 +1,42 @@
+{
+  "name": "dmhy-char-train",
+  "repo_url": "https://huggingface.co/ModerRAS/AniFileBERT",
+  "repo_ref": "main",
+  "repo_dir": "/content/AniFileBERT",
+  "drive_root": "/content/drive/MyDrive/AniFileBERT",
+  "mount_drive": true,
+  "pull": true,
+  "install": {
+    "requirements": true,
+    "git_lfs": true,
+    "extra_packages": []
+  },
+  "training": {
+    "tokenizer": "char",
+    "data_file": "datasets/AnimeName/dmhy_weak_char.jsonl",
+    "vocab_file": "datasets/AnimeName/vocab.char.json",
+    "save_dir": "{drive_root}/checkpoints/{name}",
+    "init_model_dir": null,
+    "epochs": 1,
+    "batch_size": 128,
+    "learning_rate": 0.0003,
+    "warmup_steps": 300,
+    "train_split": 0.9,
+    "max_seq_length": 128,
+    "seed": 42,
+    "resume_from_checkpoint": "auto",
+    "checkpoint_steps": 1000,
+    "save_total_limit": 3
+  },
+  "export": {
+    "enabled": true,
+    "required": false,
+    "output": "{save_dir}/exports/anime_filename_parser.onnx",
+    "max_length": "{max_seq_length}"
+  },
+  "smoke": {
+    "enabled": true,
+    "required": true,
+    "sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub"
+  }
+}
|
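The character profile lists only some of the keys that `DEFAULT_CONFIG` in `colab_train.py` defines. That file also contains a `deep_merge` helper, so a plausible reading is that a profile is overlaid on the defaults. How the script actually combines them is not visible in this excerpt. The sketch below shows such an overlay in isolation, with trimmed-down stand-ins for both dictionaries.

```python
import copy
from collections.abc import Mapping


def deep_merge(base, override):
    """Recursively overlay `override` on `base` without mutating either."""
    merged = copy.deepcopy(dict(base))
    for key, value in override.items():
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Trimmed stand-ins for the real defaults and the character profile.
defaults = {
    "name": "dmhy-regex-finetune",
    "training": {"tokenizer": "regex", "max_seq_length": 64, "limit_samples": None},
}
char_profile = {
    "name": "dmhy-char-train",
    "training": {"tokenizer": "char", "max_seq_length": 128},
}

merged = deep_merge(defaults, char_profile)
print(merged["training"])
```

Because nested mappings are merged key by key, a profile that omits `limit_samples` keeps the default value instead of dropping the key.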
colab/configs/dmhy_regex_finetune.json
ADDED
@@ -0,0 +1,42 @@
+{
+  "name": "dmhy-regex-finetune",
+  "repo_url": "https://huggingface.co/ModerRAS/AniFileBERT",
+  "repo_ref": "main",
+  "repo_dir": "/content/AniFileBERT",
+  "drive_root": "/content/drive/MyDrive/AniFileBERT",
+  "mount_drive": true,
+  "pull": true,
+  "install": {
+    "requirements": true,
+    "git_lfs": true,
+    "extra_packages": []
+  },
+  "training": {
+    "tokenizer": "regex",
+    "data_file": "datasets/AnimeName/dmhy_weak.jsonl",
+    "vocab_file": "datasets/AnimeName/vocab.json",
+    "save_dir": "{drive_root}/checkpoints/{name}",
+    "init_model_dir": ".",
+    "epochs": 1,
+    "batch_size": 128,
+    "learning_rate": 0.0003,
+    "warmup_steps": 300,
+    "train_split": 0.9,
+    "max_seq_length": 64,
+    "seed": 42,
+    "resume_from_checkpoint": "auto",
+    "checkpoint_steps": 1000,
+    "save_total_limit": 3
+  },
+  "export": {
+    "enabled": true,
+    "required": false,
+    "output": "{save_dir}/exports/anime_filename_parser.onnx",
+    "max_length": "{max_seq_length}"
+  },
+  "smoke": {
+    "enabled": true,
+    "required": true,
+    "sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub"
+  }
+}
|
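Values such as `{drive_root}/checkpoints/{name}` and `{save_dir}/exports/...` in the profiles are placeholders. `colab_train.py` contains a `SafeFormatDict` class and a `render_templates` function for expanding them, and the sketch below reproduces that mechanism in isolation. The two-pass order (top-level keys first, then `save_dir`) is an assumption, since the call site is not part of this excerpt, and the function is named `render` here only to mark it as a stand-in.

```python
class SafeFormatDict(dict):
    """Leave unknown placeholders untouched instead of raising KeyError."""

    def __missing__(self, key):
        return "{" + key + "}"


def render(value, context):
    """Expand {placeholders} in strings, recursing into lists and dicts."""
    if isinstance(value, str):
        return value.format_map(SafeFormatDict(context))
    if isinstance(value, list):
        return [render(item, context) for item in value]
    if isinstance(value, dict):
        return {key: render(item, context) for key, item in value.items()}
    return value


context = {
    "drive_root": "/content/drive/MyDrive/AniFileBERT",
    "name": "dmhy-regex-finetune",
}
save_dir = render("{drive_root}/checkpoints/{name}", context)

# Second pass: save_dir is now known, max_seq_length is not in the context
# and is therefore left as a literal placeholder.
export = render(
    {"output": "{save_dir}/exports/anime_filename_parser.onnx",
     "max_length": "{max_seq_length}"},
    {**context, "save_dir": save_dir},
)
print(save_dir)
print(export)
```

Note that `str.format_map` always produces a string, so a rendered `max_length` would need converting back to an integer wherever a number is expected.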
colab/start_worker.ipynb
ADDED
@@ -0,0 +1,45 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 5,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "# AniFileBERT Colab Worker\n",
+        "\n",
+        "Run the next cell in a GPU runtime. Keep it running while Codex submits training jobs. If free Colab disconnects, open the notebook again and rerun the cell; default profiles resume from the latest Drive checkpoint."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from google.colab import drive\n",
+        "drive.mount('/content/drive')\n",
+        "\n",
+        "!git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT /content/AniFileBERT || true\n",
+        "%cd /content/AniFileBERT\n",
+        "!git pull --ff-only || true\n",
+        "!git submodule update --init --recursive\n",
+        "!python colab_worker.py\n"
+      ]
+    }
+  ]
+}
|
colab_client.py
ADDED
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+"""Local client for controlling an active AniFileBERT Colab worker.
+
+The worker still has to be started manually in Colab, but once it prints a
+public URL and token this client lets Codex submit training jobs, tail logs, and
+inspect status from the local workspace.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+from pathlib import Path
+import sys
+import time
+from typing import Any
+import urllib.error
+import urllib.parse
+import urllib.request
+
+
+TERMINAL_STATES = {"success", "failed", "cancelled"}
+
+
+def load_json(path: str) -> Any:
+    return json.loads(Path(path).read_text(encoding="utf-8"))
+
+
+class ColabClient:
+    def __init__(self, base_url: str, token: str, timeout: int = 30):
+        self.base_url = base_url.rstrip("/")
+        self.token = token
+        self.timeout = timeout
+
+    def request(self, method: str, path: str, payload: Any | None = None) -> Any:
+        url = self.base_url + path
+        data = None
+        headers = {"Authorization": f"Bearer {self.token}"}
+        if payload is not None:
+            data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
+            headers["Content-Type"] = "application/json; charset=utf-8"
+
+        req = urllib.request.Request(url, data=data, headers=headers, method=method)
+        try:
+            with urllib.request.urlopen(req, timeout=self.timeout) as response:
+                return json.loads(response.read().decode("utf-8"))
+        except urllib.error.HTTPError as exc:
+            body = exc.read().decode("utf-8", errors="replace")
+            raise RuntimeError(f"{method} {url} failed: HTTP {exc.code}: {body}") from exc
+
+    def health(self) -> Any:
+        return self.request("GET", "/health")
+
+    def submit(self, payload: dict[str, Any]) -> Any:
+        return self.request("POST", "/jobs", payload)
+
+    def jobs(self) -> Any:
+        return self.request("GET", "/jobs")
+
+    def status(self, job_id: str) -> Any:
+        return self.request("GET", f"/jobs/{job_id}")
+
+    def logs(self, job_id: str, tail: int) -> Any:
+        query = urllib.parse.urlencode({"tail": tail})
+        return self.request("GET", f"/jobs/{job_id}/logs?{query}")
+
+    def manifest(self, job_id: str) -> Any:
+        return self.request("GET", f"/jobs/{job_id}/manifest")
+
+    def cancel(self, job_id: str) -> Any:
+        return self.request("POST", f"/jobs/{job_id}/cancel", {})
+
+
+def print_json(data: Any) -> None:
+    print(json.dumps(data, ensure_ascii=False, indent=2))
+
+
+def require_connection(args: argparse.Namespace) -> ColabClient:
+    url = args.url or os.environ.get("ANIFILEBERT_COLAB_URL")
+    token = args.token or os.environ.get("ANIFILEBERT_COLAB_TOKEN")
+    if not url or not token:
+        raise SystemExit(
+            "Set ANIFILEBERT_COLAB_URL and ANIFILEBERT_COLAB_TOKEN, "
+            "or pass --url and --token."
+        )
+    return ColabClient(url, token, timeout=args.timeout)
+
+
+def build_submit_payload(args: argparse.Namespace) -> dict[str, Any]:
+    payload: dict[str, Any] = {}
+    if args.config:
+        payload["config"] = load_json(args.config)
+    if args.profile:
+        payload["profile"] = args.profile
+    extra_args = list(args.args or []) + list(args.extra_args or [])
+    if extra_args:
+        payload["args"] = extra_args
+    if not payload:
+        payload["profile"] = "dmhy_regex_finetune"
+    return payload
+
+
+def wait_for_job(client: ColabClient, job_id: str, poll: int, tail: int) -> dict[str, Any]:
+    last_status = None
+    while True:
+        status = client.status(job_id)
+        if status.get("status") != last_status:
+            print_json(status)
+            last_status = status.get("status")
+        logs = client.logs(job_id, tail=tail)
+        log_text = logs.get("log", "")
+        if log_text:
+            print("\n--- log tail ---")
+            print(log_text.rstrip())
+        if status.get("status") in TERMINAL_STATES:
+            return status
+        time.sleep(poll)
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Control an active AniFileBERT Colab worker")
+    parser.add_argument("--url", help="Worker URL, or ANIFILEBERT_COLAB_URL")
+    parser.add_argument("--token", help="Worker token, or ANIFILEBERT_COLAB_TOKEN")
+    parser.add_argument("--timeout", type=int, default=30)
+
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    subparsers.add_parser("health", help="Check worker health")
+    subparsers.add_parser("jobs", help="List known jobs")
+
+    submit = subparsers.add_parser("submit", help="Submit a training job")
+    submit.add_argument("--config", help="Local JSON config to send to the worker")
+    submit.add_argument("--profile", help="Remote profile name under colab/configs")
+    submit.add_argument("--arg", dest="args", action="append", default=[], help="Extra arg for colab_train.py")
+    submit.add_argument("--wait", action="store_true", help="Poll until the job finishes")
+    submit.add_argument("--poll", type=int, default=60, help="Polling interval in seconds")
+    submit.add_argument("--tail", type=int, default=80, help="Log lines to show while waiting")
+    submit.add_argument("extra_args", nargs=argparse.REMAINDER,
+                        help="Arguments after -- are passed to colab_train.py")
+
+    status = subparsers.add_parser("status", help="Show job status")
+    status.add_argument("job_id")
+
+    logs = subparsers.add_parser("logs", help="Show job logs")
+    logs.add_argument("job_id")
+    logs.add_argument("--tail", type=int, default=200)
+
+    manifest = subparsers.add_parser("manifest", help="Show job manifest")
+    manifest.add_argument("job_id")
+
+    cancel = subparsers.add_parser("cancel", help="Cancel a running job")
+    cancel.add_argument("job_id")
+
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    client = require_connection(args)
+
+    if args.command == "health":
+        print_json(client.health())
+    elif args.command == "jobs":
+        print_json(client.jobs())
+    elif args.command == "submit":
+        job = client.submit(build_submit_payload(args))
+        print_json(job)
+        if args.wait:
+            final_status = wait_for_job(client, job["job_id"], poll=args.poll, tail=args.tail)
+            if final_status.get("status") != "success":
+                sys.exit(1)
+    elif args.command == "status":
+        print_json(client.status(args.job_id))
+    elif args.command == "logs":
+        print(client.logs(args.job_id, args.tail).get("log", ""), end="")
+    elif args.command == "manifest":
+        print_json(client.manifest(args.job_id))
+    elif args.command == "cancel":
+        print_json(client.cancel(args.job_id))
+
+
+if __name__ == "__main__":
+    main()
|
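`ColabClient.request` in `colab_client.py` authenticates each call with a bearer token and sends JSON bodies. To inspect what would go over the wire without contacting a worker, the request can be built and examined offline. This is a minimal sketch: `build_request` is a hypothetical helper that mirrors the header and body construction, and `worker.example` and the token are placeholder values.

```python
import json
import urllib.request


def build_request(base_url, token, method, path, payload=None):
    """Build (but do not send) an authenticated JSON request."""
    headers = {"Authorization": f"Bearer {token}"}
    data = None
    if payload is not None:
        data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        headers["Content-Type"] = "application/json; charset=utf-8"
    return urllib.request.Request(
        base_url.rstrip("/") + path, data=data, headers=headers, method=method
    )


req = build_request(
    "https://worker.example/", "secret-token", "POST", "/jobs",
    {"profile": "dmhy_regex_finetune"},
)
print(req.get_method(), req.full_url)
print(req.get_header("Authorization"))
print(req.data.decode("utf-8"))
```

`urllib.request.Request` stores header names in capitalized form, so the content type has to be read back with `req.get_header("Content-type")`.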
colab_train.py
CHANGED
@@ -1,139 +1,543 @@
 # -*- coding: utf-8 -*-
-"""
-
-
-
-
-
-
-
-
-
-What it does:
-- Mounts Google Drive (for persistent checkpoints)
-- Clones AniFileBERT repo + AnimeName dataset submodule
-- Installs PyTorch + Transformers dependencies
-- Runs training: train a character-token model with the full DMHY vocab
-- Saves final model to Drive
-
-Output:
-- Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
-- Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-weak-char/final/
 """
 
 import os
-import
 import subprocess
-import
 
 
-def run(cmd, echo=True):
-    """Run a shell command and print output in real time."""
-    if echo:
-        print(f"\n$ {cmd}")
     proc = subprocess.Popen(
-
-
     )
     for line in proc.stdout:
         print(line, end="")
     proc.wait()
-
-
     return proc.returncode
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 run(
-
-
-
 )
-
-
-
-
-
-
-
-
-
-
-
-
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
+
"""Codex-friendly Google Colab runner for AniFileBERT training.
|
| 3 |
+
|
| 4 |
+
Typical Colab usage:
|
| 5 |
+
|
| 6 |
+
python colab_train.py --config colab/configs/dmhy_regex_finetune.json
|
| 7 |
+
|
| 8 |
+
This script keeps the Colab side reproducible by putting run parameters in JSON
|
| 9 |
+
profiles. It can clone/update the repo, mount Drive, install dependencies,
|
| 10 |
+
train, optionally export ONNX, run an inference smoke check, and write a run
|
| 11 |
+
manifest that Codex can inspect later.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import copy
|
| 18 |
+
import datetime as dt
|
| 19 |
+
import json
|
| 20 |
import os
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
import shlex
|
| 23 |
+
import shutil
|
| 24 |
import subprocess
|
| 25 |
+
import sys
|
| 26 |
+
import traceback
|
| 27 |
+
from typing import Any, Mapping, Sequence
|
| 28 |
+
import urllib.request
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DEFAULT_CONFIG: dict[str, Any] = {
|
| 32 |
+
"name": "dmhy-regex-finetune",
|
| 33 |
+
"repo_url": "https://huggingface.co/ModerRAS/AniFileBERT",
|
| 34 |
+
"repo_ref": "main",
|
| 35 |
+
"repo_dir": "/content/AniFileBERT",
|
| 36 |
+
"drive_root": "/content/drive/MyDrive/AniFileBERT",
|
| 37 |
+
"mount_drive": True,
|
| 38 |
+
"pull": True,
|
| 39 |
+
"install": {
|
| 40 |
+
"requirements": True,
|
| 41 |
+
"git_lfs": True,
|
| 42 |
+
"extra_packages": [],
|
| 43 |
+
},
|
| 44 |
+
"training": {
|
| 45 |
+
"tokenizer": "regex",
|
| 46 |
+
"data_file": "datasets/AnimeName/dmhy_weak.jsonl",
|
| 47 |
+
"vocab_file": "datasets/AnimeName/vocab.json",
|
| 48 |
+
"save_dir": "{drive_root}/checkpoints/{name}",
|
| 49 |
+
"init_model_dir": ".",
|
| 50 |
+
"epochs": 1,
|
| 51 |
+
"batch_size": 128,
|
| 52 |
+
"learning_rate": 0.0003,
|
| 53 |
+
"warmup_steps": 300,
|
| 54 |
+
"train_split": 0.9,
|
| 55 |
+
"max_seq_length": 64,
|
| 56 |
+
"seed": 42,
|
| 57 |
+
"limit_samples": None,
|
| 58 |
+
"rebuild_vocab": False,
|
| 59 |
+
"max_vocab_size": None,
|
| 60 |
+
"resume_from_checkpoint": "auto",
|
| 61 |
+
"checkpoint_steps": 1000,
|
| 62 |
+
"save_total_limit": 3,
|
| 63 |
+
"cpu": False,
|
| 64 |
+
"no_shuffle": False,
|
| 65 |
+
"extra_args": [],
|
| 66 |
+
},
|
| 67 |
+
"export": {
|
| 68 |
+
"enabled": True,
|
| 69 |
+
"required": False,
|
| 70 |
+
"output": "{save_dir}/exports/anime_filename_parser.onnx",
|
| 71 |
+
"max_length": "{max_seq_length}",
|
| 72 |
+
"sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub",
|
| 73 |
+
"android_assets_dir": None,
|
| 74 |
+
},
|
| 75 |
+
"smoke": {
|
| 76 |
+
"enabled": True,
|
| 77 |
+
"required": True,
|
| 78 |
+
"sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub",
|
| 79 |
+
},
|
| 80 |
+
"artifacts": {
|
| 81 |
+
"manifest": "{save_dir}/colab_run_manifest.json",
|
| 82 |
+
"latest_manifest": "{drive_root}/last_run_manifest.json",
|
| 83 |
+
},
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
COMMAND_LOG: list[dict[str, Any]] = []
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class SafeFormatDict(dict):
|
| 91 |
+
def __missing__(self, key: str) -> str:
|
| 92 |
+
return "{" + key + "}"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def utc_now() -> str:
|
| 96 |
+
return dt.datetime.now(dt.timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")


def deep_merge(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]:
    merged = copy.deepcopy(dict(base))
    for key, value in override.items():
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged
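
A minimal sketch of how `deep_merge` layers a profile over the defaults. The function body is copied from the diff; the `defaults` and `profile` dictionaries are made-up values for illustration, not the repository's real configuration.

```python
from __future__ import annotations

import copy
from collections.abc import Mapping
from typing import Any


def deep_merge(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]:
    merged = copy.deepcopy(dict(base))
    for key, value in override.items():
        # Nested mappings are merged key by key; everything else is replaced.
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical values, chosen only to show the merge behaviour.
defaults = {"training": {"epochs": 1, "batch_size": 128}, "smoke": {"enabled": True}}
profile = {"training": {"epochs": 3}, "name": "demo"}

merged = deep_merge(defaults, profile)
print(merged)
```

Only the keys named in the profile change; sibling keys such as `batch_size` keep their defaults, and the inputs are left unmodified because every value is deep-copied.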


def render_templates(value: Any, context: Mapping[str, Any]) -> Any:
    if isinstance(value, str):
        return value.format_map(SafeFormatDict(context))
    if isinstance(value, list):
        return [render_templates(item, context) for item in value]
    if isinstance(value, dict):
        return {key: render_templates(item, context) for key, item in value.items()}
    return value
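
The following sketch illustrates why `SafeFormatDict` matters: a placeholder that is not yet in the context survives the first pass unchanged and can be filled by a later pass. The two helpers are copied from the diff; the `context` and `config` values are hypothetical, and the real `resolve_config` renders the `training`, `export`, `smoke`, and `artifacts` sections separately rather than as one flat dictionary.

```python
from __future__ import annotations

from typing import Any, Mapping


class SafeFormatDict(dict):
    def __missing__(self, key: str) -> str:
        # Re-emit the placeholder so a later pass can still resolve it.
        return "{" + key + "}"


def render_templates(value: Any, context: Mapping[str, Any]) -> Any:
    if isinstance(value, str):
        return value.format_map(SafeFormatDict(context))
    if isinstance(value, list):
        return [render_templates(item, context) for item in value]
    if isinstance(value, dict):
        return {key: render_templates(item, context) for key, item in value.items()}
    return value


# Hypothetical paths and names, for illustration only.
context = {"drive_root": "/content/drive/MyDrive/demo", "name": "run1"}
config = {
    "save_dir": "{drive_root}/checkpoints/{name}",
    "output": "{save_dir}/model.onnx",
    "epochs": 1,
}

first = render_templates(config, context)                # {save_dir} is still unknown here
second = render_templates(first, {**context, **first})   # now save_dir is available
print(first["output"])
print(second["output"])
```

Non-string values such as `epochs` pass through untouched, so numeric settings keep their types.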


def command_text(args: str | Sequence[Any]) -> str:
    if isinstance(args, str):
        return args
    return " ".join(shlex.quote(str(arg)) for arg in args)


def run(
    args: str | Sequence[Any],
    *,
    cwd: str | os.PathLike[str] | None = None,
    check: bool = True,
    dry_run: bool = False,
) -> int:
    text = command_text(args)
    entry: dict[str, Any] = {
        "cmd": text,
        "cwd": os.fspath(cwd) if cwd is not None else None,
        "started_at": utc_now(),
        "dry_run": dry_run,
    }
    COMMAND_LOG.append(entry)
    print(f"\n$ {text}")
    if dry_run:
        entry["returncode"] = 0
        entry["finished_at"] = utc_now()
        return 0

    proc = subprocess.Popen(
        args,
        cwd=cwd,
        shell=isinstance(args, str),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    )
    assert proc.stdout is not None
    for line in proc.stdout:
        print(line, end="")
    proc.wait()
    entry["returncode"] = proc.returncode
    entry["finished_at"] = utc_now()
    if check and proc.returncode != 0:
        raise RuntimeError(f"Command failed with exit code {proc.returncode}: {text}")
    return proc.returncode


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run AniFileBERT training in Colab")
    parser.add_argument("--config", help="JSON profile path or URL")
    parser.add_argument("--profile", help="Profile name under colab/configs without .json")
    parser.add_argument("--repo-url", help="Override repository URL")
    parser.add_argument("--repo-ref", help="Override branch, tag, or commit to checkout")
    parser.add_argument("--repo-dir", help="Override Colab repository directory")
    parser.add_argument("--drive-root", help="Override Google Drive output root")
    parser.add_argument("--save-dir", help="Override checkpoint output directory")
    parser.add_argument("--epochs", type=float, help="Override training epochs")
    parser.add_argument("--batch-size", type=int, help="Override per-device batch size")
    parser.add_argument("--learning-rate", type=float, help="Override learning rate")
    parser.add_argument("--warmup-steps", type=int, help="Override warmup steps")
    parser.add_argument("--limit-samples", type=int, help="Use only the first N dataset rows")
    parser.add_argument("--skip-install", action="store_true", help="Do not install pip or git-lfs dependencies")
    parser.add_argument("--skip-export", action="store_true", help="Do not run ONNX export")
    parser.add_argument("--skip-smoke", action="store_true", help="Do not run inference smoke check")
    parser.add_argument("--no-mount-drive", action="store_true", help="Do not mount Google Drive")
    parser.add_argument("--no-pull", action="store_true", help="Do not pull an existing checkout")
    parser.add_argument("--dry-run", action="store_true", help="Print commands and write no training outputs")
    parser.add_argument("--print-config", action="store_true", help="Print resolved config before running")
    return parser.parse_args()


def load_json_source(source: str | None, *, required: bool) -> dict[str, Any]:
    if not source:
        return {}
    if source.startswith(("http://", "https://")):
        with urllib.request.urlopen(source) as response:
            return json.loads(response.read().decode("utf-8"))

    candidates = [Path(source), Path(__file__).resolve().parent / source]
    for candidate in candidates:
        if candidate.is_file():
            return json.loads(candidate.read_text(encoding="utf-8"))
    if required:
        raise FileNotFoundError(f"Config file not found: {source}")
    return {}


def load_config(args: argparse.Namespace) -> dict[str, Any]:
    config_source = args.config
    required = bool(args.config)
    if config_source is None and args.profile:
        config_source = os.fspath(Path("colab") / "configs" / f"{args.profile}.json")
        required = True

    profile_config = load_json_source(config_source, required=required)
    config = deep_merge(DEFAULT_CONFIG, profile_config)

    if args.repo_url:
        config["repo_url"] = args.repo_url
    if args.repo_ref:
        config["repo_ref"] = args.repo_ref
    if args.repo_dir:
        config["repo_dir"] = args.repo_dir
    if args.drive_root:
        config["drive_root"] = args.drive_root
    if args.no_mount_drive:
        config["mount_drive"] = False
    if args.no_pull:
        config["pull"] = False
    if args.skip_install:
        config["install"]["requirements"] = False
        config["install"]["git_lfs"] = False
        config["install"]["extra_packages"] = []
    if args.skip_export:
        config["export"]["enabled"] = False
    if args.skip_smoke:
        config["smoke"]["enabled"] = False

    training = config["training"]
    for arg_name, key in [
        ("save_dir", "save_dir"),
        ("epochs", "epochs"),
        ("batch_size", "batch_size"),
        ("learning_rate", "learning_rate"),
        ("warmup_steps", "warmup_steps"),
        ("limit_samples", "limit_samples"),
    ]:
        value = getattr(args, arg_name)
        if value is not None:
            training[key] = value

    return resolve_config(config)


def resolve_config(config: dict[str, Any]) -> dict[str, Any]:
    context: dict[str, Any] = {
        "name": config["name"],
        "repo_url": config["repo_url"],
        "repo_ref": config.get("repo_ref") or "",
        "repo_dir": config["repo_dir"],
        "drive_root": config["drive_root"],
    }

    training = render_templates(config["training"], context)
    context.update(training)
    if not training.get("save_dir"):
        training["save_dir"] = os.path.join(config["drive_root"], "checkpoints", config["name"])
    training = render_templates(training, {**context, **training})
    context.update(training)
    context["save_dir"] = training["save_dir"]
    context["final_model_dir"] = os.path.join(training["save_dir"], "final")

    resolved = copy.deepcopy(config)
    resolved["training"] = training
    resolved["export"] = render_templates(config["export"], context)
    resolved["smoke"] = render_templates(config["smoke"], context)
    resolved["artifacts"] = render_templates(config["artifacts"], context)
    return resolved


def maybe_mount_drive(config: Mapping[str, Any]) -> None:
    if not config.get("mount_drive", True):
        print("Google Drive mount disabled.")
        return
    try:
        from google.colab import drive  # type: ignore
    except Exception:
        print("[WARN] google.colab is unavailable; skipping Drive mount.")
        return
    print("Mounting Google Drive...")
    drive.mount("/content/drive")


def install_git_lfs_if_needed(config: Mapping[str, Any], *, dry_run: bool) -> None:
    if not config.get("install", {}).get("git_lfs", True):
        return
    if shutil.which("git-lfs"):
        run(["git", "lfs", "install"], check=False, dry_run=dry_run)
        return
    if Path("/content").exists():
        print("Installing git-lfs for Hugging Face model artifacts...")
        run(["apt-get", "update"], check=False, dry_run=dry_run)
        run(["apt-get", "install", "-y", "git-lfs"], dry_run=dry_run)
        run(["git", "lfs", "install"], check=False, dry_run=dry_run)
    else:
        print("[WARN] git-lfs not found. Existing LFS pointers may not contain model weights.")


def is_git_repo(path: Path) -> bool:
    return (path / ".git").exists()


def prepare_repo(config: Mapping[str, Any], *, dry_run: bool) -> Path:
    repo_dir = Path(config["repo_dir"])
    repo_url = config["repo_url"]
    repo_ref = config.get("repo_ref")

    if not is_git_repo(repo_dir):
        if repo_dir.exists() and any(repo_dir.iterdir()):
            raise RuntimeError(f"{repo_dir} exists but is not a git checkout")
        repo_dir.parent.mkdir(parents=True, exist_ok=True)
        run(["git", "clone", "--recursive", repo_url, os.fspath(repo_dir)], dry_run=dry_run)
    else:
        print(f"Using existing repository checkout: {repo_dir}")

    if repo_ref:
        run(["git", "fetch", "--all", "--tags"], cwd=repo_dir, check=False, dry_run=dry_run)
        run(["git", "checkout", str(repo_ref)], cwd=repo_dir, dry_run=dry_run)

    if config.get("pull", True):
        run(["git", "pull", "--ff-only"], cwd=repo_dir, check=False, dry_run=dry_run)

    run(["git", "submodule", "update", "--init", "--recursive"], cwd=repo_dir, dry_run=dry_run)
    if shutil.which("git-lfs"):
        run(["git", "lfs", "pull"], cwd=repo_dir, check=False, dry_run=dry_run)

    return repo_dir


def install_python_deps(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
    install = config.get("install", {})
    if install.get("requirements", True):
        run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=repo_dir, dry_run=dry_run)
    for package in install.get("extra_packages", []):
        run([sys.executable, "-m", "pip", "install", str(package)], cwd=repo_dir, dry_run=dry_run)


def verify_runtime(repo_dir: Path, *, dry_run: bool) -> None:
    run(["nvidia-smi"], cwd=repo_dir, check=False, dry_run=dry_run)
    run(
        [
            sys.executable,
            "-c",
            "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')",
        ],
        cwd=repo_dir,
        check=False,
        dry_run=dry_run,
    )


def add_arg(cmd: list[str], flag: str, value: Any) -> None:
    if value is None or value is False:
        return
    if value is True:
        cmd.append(flag)
    else:
        cmd.extend([flag, str(value)])


def build_train_command(training: Mapping[str, Any]) -> list[str]:
    cmd = [sys.executable, "train.py"]
    for key, flag in [
        ("tokenizer", "--tokenizer"),
        ("data_file", "--data-file"),
        ("vocab_file", "--vocab-file"),
        ("save_dir", "--save-dir"),
        ("init_model_dir", "--init-model-dir"),
        ("epochs", "--epochs"),
        ("batch_size", "--batch-size"),
        ("learning_rate", "--learning-rate"),
        ("warmup_steps", "--warmup-steps"),
        ("train_split", "--train-split"),
        ("max_seq_length", "--max-seq-length"),
        ("seed", "--seed"),
        ("limit_samples", "--limit-samples"),
        ("max_vocab_size", "--max-vocab-size"),
        ("resume_from_checkpoint", "--resume-from-checkpoint"),
        ("checkpoint_steps", "--checkpoint-steps"),
        ("save_total_limit", "--save-total-limit"),
    ]:
        add_arg(cmd, flag, training.get(key))
    add_arg(cmd, "--rebuild-vocab", training.get("rebuild_vocab"))
    add_arg(cmd, "--cpu", training.get("cpu"))
    add_arg(cmd, "--no-shuffle", training.get("no_shuffle"))
    cmd.extend(str(arg) for arg in training.get("extra_args", []))
    return cmd
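
To make the flag-mapping rule concrete, here is a small sketch of how `add_arg` turns config values into command-line arguments. The helper is copied from the diff; the starting command and the values passed in are illustrative only.

```python
from __future__ import annotations

from typing import Any


def add_arg(cmd: list[str], flag: str, value: Any) -> None:
    # None and False mean "omit the flag entirely".
    if value is None or value is False:
        return
    # True means a bare switch; anything else is passed as "flag value".
    if value is True:
        cmd.append(flag)
    else:
        cmd.extend([flag, str(value)])


cmd = ["python", "train.py"]          # placeholder interpreter name for the example
add_arg(cmd, "--epochs", 1)
add_arg(cmd, "--limit-samples", None)   # skipped
add_arg(cmd, "--cpu", False)            # skipped
add_arg(cmd, "--rebuild-vocab", True)   # bare switch
add_arg(cmd, "--warmup-steps", 0)       # 0 is not False, so it is still emitted
print(cmd)
```

Because the checks use identity (`is True`, `is False`) rather than truthiness, numeric values such as `0` and `1` are passed through as ordinary arguments instead of being treated as switches.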


def run_training(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
    training = config["training"]
    if not dry_run:
        Path(training["save_dir"]).mkdir(parents=True, exist_ok=True)
    run(build_train_command(training), cwd=repo_dir, dry_run=dry_run)


def run_export(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
    export = config["export"]
    if not export.get("enabled", True):
        print("ONNX export disabled.")
        return
    cmd = [
        sys.executable,
        "export_onnx.py",
        "--model-dir",
        os.path.join(config["training"]["save_dir"], "final"),
        "--output",
        export["output"],
        "--max-length",
        str(export["max_length"]),
    ]
    add_arg(cmd, "--sample", export.get("sample"))
    add_arg(cmd, "--android-assets-dir", export.get("android_assets_dir"))
    try:
        run(cmd, cwd=repo_dir, dry_run=dry_run)
    except Exception:
        if export.get("required", False):
            raise
        print("[WARN] ONNX export failed, but export.required is false.")
        traceback.print_exc()


def run_smoke(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
    smoke = config["smoke"]
    if not smoke.get("enabled", True):
        print("Inference smoke check disabled.")
        return
    cmd = [
        sys.executable,
        "inference.py",
        "--model-dir",
        os.path.join(config["training"]["save_dir"], "final"),
        smoke["sample"],
    ]
    try:
        run(cmd, cwd=repo_dir, dry_run=dry_run)
    except Exception:
        if smoke.get("required", True):
            raise
        print("[WARN] Smoke check failed, but smoke.required is false.")
        traceback.print_exc()


def git_commit(repo_dir: Path, *, dry_run: bool) -> str | None:
    if dry_run:
        return None
    try:
        return subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir,
            text=True,
            encoding="utf-8",
            errors="replace",
        ).strip()
    except Exception:
        return None


def write_json(path: str | os.PathLike[str], data: Mapping[str, Any], *, dry_run: bool) -> None:
    print(f"Writing manifest: {path}")
    if dry_run:
        return
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")


def write_manifests(
    config: Mapping[str, Any],
    repo_dir: Path,
    *,
    status: str,
    started_at: str,
    error: str | None,
    dry_run: bool,
) -> None:
    save_dir = config["training"]["save_dir"]
    manifest = {
        "status": status,
        "name": config["name"],
        "started_at": started_at,
        "finished_at": utc_now(),
        "repo_url": config["repo_url"],
        "repo_ref": config.get("repo_ref"),
        "repo_commit": git_commit(repo_dir, dry_run=dry_run),
        "repo_dir": os.fspath(repo_dir),
        "save_dir": save_dir,
        "final_model_dir": os.path.join(save_dir, "final"),
        "onnx_output": config["export"].get("output") if config["export"].get("enabled") else None,
        "config": config,
        "commands": COMMAND_LOG,
        "error": error,
    }
    artifacts = config["artifacts"]
    write_json(artifacts["manifest"], manifest, dry_run=dry_run)
    if artifacts.get("latest_manifest"):
        write_json(artifacts["latest_manifest"], manifest, dry_run=dry_run)


def main() -> None:
    args = parse_args()
    started_at = utc_now()
    config = load_config(args)

    if args.print_config:
        print(json.dumps(config, ensure_ascii=False, indent=2))

    repo_dir = Path(config["repo_dir"])
    status = "failed"
    error: str | None = None
    try:
        maybe_mount_drive(config)
        install_git_lfs_if_needed(config, dry_run=args.dry_run)
        repo_dir = prepare_repo(config, dry_run=args.dry_run)
        install_python_deps(config, repo_dir, dry_run=args.dry_run)
        verify_runtime(repo_dir, dry_run=args.dry_run)
        run_training(config, repo_dir, dry_run=args.dry_run)
        run_export(config, repo_dir, dry_run=args.dry_run)
        run_smoke(config, repo_dir, dry_run=args.dry_run)
        status = "success"
    except Exception as exc:
        error = f"{type(exc).__name__}: {exc}"
        raise
    finally:
        write_manifests(config, repo_dir, status=status, started_at=started_at, error=error, dry_run=args.dry_run)

    print("\nDone.")
    print(f"Final model: {os.path.join(config['training']['save_dir'], 'final')}")
    print(f"Manifest: {config['artifacts']['manifest']}")


if __name__ == "__main__":
    main()
|
colab_worker.py
ADDED
|
@@ -0,0 +1,446 @@
# -*- coding: utf-8 -*-
"""Small HTTP worker for running AniFileBERT training jobs on Google Colab.

Start this inside a Colab runtime:

    python colab_worker.py

The worker exposes a token-protected local HTTP API and, by default, starts a
Cloudflare Quick Tunnel so Codex on your local machine can submit jobs.
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
import platform
import re
import secrets
import shutil
import signal
import subprocess
import sys
import threading
import time
import traceback
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse
import urllib.request


TERMINAL_STATES = {"success", "failed", "cancelled"}
TUNNEL_URL_RE = re.compile(r"https://[-a-zA-Z0-9.]+\.trycloudflare\.com")
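
As a quick illustration of what `TUNNEL_URL_RE` is meant to capture, the sketch below pulls a Quick Tunnel URL out of a log line. The regular expression is copied from the diff; the sample log line and hostname are invented, and the real `cloudflared` output format may differ.

```python
import re

TUNNEL_URL_RE = re.compile(r"https://[-a-zA-Z0-9.]+\.trycloudflare\.com")

# Hypothetical log line, for illustration only.
line = "INF |  https://example-words-here.trycloudflare.com  |"

match = TUNNEL_URL_RE.search(line)
url = match.group(0) if match else None
print(url)

# Lines without a trycloudflare.com host do not match.
no_match = TUNNEL_URL_RE.search("INF Starting tunnel")
```

Using `search` rather than `match` lets the pattern find the URL anywhere in the line, which is useful when the surrounding log decoration is not known in advance.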


def utc_timestamp() -> str:
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())


def json_dumps(data: Any) -> str:
    return json.dumps(data, ensure_ascii=False, indent=2)


def read_tail(path: Path, lines: int) -> str:
    if not path.is_file():
        return ""
    if lines <= 0:
        return path.read_text(encoding="utf-8", errors="replace")

    chunk_size = 8192
    data = b""
    with path.open("rb") as f:
        f.seek(0, os.SEEK_END)
        pos = f.tell()
        while pos > 0 and data.count(b"\n") <= lines:
            read_size = min(chunk_size, pos)
            pos -= read_size
            f.seek(pos)
            data = f.read(read_size) + data
    return b"\n".join(data.splitlines()[-lines:]).decode("utf-8", errors="replace")
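
The sketch below exercises `read_tail` on a throwaway file to show its three behaviours: the last N lines, the whole file when `lines <= 0`, and an empty string for a missing file. The function is copied from the diff; the temporary file name and its contents are made up for the example.

```python
import os
import tempfile
from pathlib import Path


def read_tail(path: Path, lines: int) -> str:
    if not path.is_file():
        return ""
    if lines <= 0:
        return path.read_text(encoding="utf-8", errors="replace")

    chunk_size = 8192
    data = b""
    with path.open("rb") as f:
        # Read backwards in chunks until enough newlines have been collected.
        f.seek(0, os.SEEK_END)
        pos = f.tell()
        while pos > 0 and data.count(b"\n") <= lines:
            read_size = min(chunk_size, pos)
            pos -= read_size
            f.seek(pos)
            data = f.read(read_size) + data
    return b"\n".join(data.splitlines()[-lines:]).decode("utf-8", errors="replace")


with tempfile.TemporaryDirectory() as tmp:
    log_path = Path(tmp) / "worker.log"  # hypothetical log file
    log_path.write_bytes(b"".join(f"line{i}\n".encode() for i in range(10)))
    last_three = read_tail(log_path, 3)
    everything = read_tail(log_path, 0)
    missing = read_tail(Path(tmp) / "absent.log", 3)

print(last_three)
```

Reading from the end in fixed-size chunks keeps the cost proportional to the requested tail rather than to the full log size, which matters for long training logs.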


def download_cloudflared(path: Path) -> Path:
    if path.is_file():
        return path

    existing = shutil.which("cloudflared")
    if existing:
        return Path(existing)

    arch = platform.machine().lower()
    if arch in {"x86_64", "amd64"}:
        suffix = "linux-amd64"
    elif arch in {"aarch64", "arm64"}:
        suffix = "linux-arm64"
    else:
        raise RuntimeError(f"Unsupported CPU architecture for cloudflared: {arch}")

    url = f"https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-{suffix}"
    print(f"Downloading cloudflared: {url}", flush=True)
    path.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, path)
    path.chmod(0o755)
    return path


class WorkerState:
    def __init__(self, repo_dir: Path, jobs_dir: Path):
        self.repo_dir = repo_dir
        self.jobs_dir = jobs_dir
        self.jobs_dir.mkdir(parents=True, exist_ok=True)
        self.jobs: dict[str, dict[str, Any]] = {}
        self.lock = threading.RLock()

    def list_jobs(self) -> list[dict[str, Any]]:
        with self.lock:
            return [self._public_job(job) for job in self.jobs.values()]

    def get_job(self, job_id: str) -> dict[str, Any] | None:
        with self.lock:
            job = self.jobs.get(job_id)
            return self._public_job(job) if job else None

    def get_job_internal(self, job_id: str) -> dict[str, Any] | None:
        with self.lock:
            return self.jobs.get(job_id)

    def active_job(self) -> dict[str, Any] | None:
        with self.lock:
            for job in self.jobs.values():
                if job["status"] not in TERMINAL_STATES:
                    return job
        return None

    def start_job(self, payload: dict[str, Any]) -> dict[str, Any]:
        with self.lock:
            active = self.active_job()
            if active is not None:
                raise RuntimeError(f"Job already running: {active['job_id']}")

            job_id = time.strftime("%Y%m%d-%H%M%S", time.gmtime()) + "-" + secrets.token_hex(3)
            job_dir = self.jobs_dir / job_id
            job_dir.mkdir(parents=True, exist_ok=True)
            log_path = job_dir / "worker.log"
            config_path: Path | None = None

            cmd = [sys.executable, "colab_train.py"]
            config = self._job_config(payload)
            config.setdefault("artifacts", {})
            config["artifacts"]["manifest"] = os.fspath(job_dir / "colab_run_manifest.json")
            config_path = job_dir / "config.json"
            config_path.write_text(json_dumps(config), encoding="utf-8")
            cmd.extend(["--config", os.fspath(config_path)])

            for arg in payload.get("args", []):
                cmd.append(str(arg))

            job = {
                "job_id": job_id,
                "status": "queued",
                "created_at": utc_timestamp(),
                "started_at": None,
                "finished_at": None,
                "returncode": None,
                "cmd": cmd,
                "cwd": os.fspath(self.repo_dir),
                "job_dir": os.fspath(job_dir),
                "log_path": os.fspath(log_path),
                "config_path": os.fspath(config_path) if config_path else None,
                "error": None,
                "process": None,
            }
            self.jobs[job_id] = job

        thread = threading.Thread(target=self._run_job, args=(job_id,), daemon=True)
        thread.start()
        return self._public_job(job)

    def _job_config(self, payload: dict[str, Any]) -> dict[str, Any]:
        if "config" in payload:
            return json.loads(json.dumps(payload["config"], ensure_ascii=False))

        profile = str(payload.get("profile", "dmhy_regex_finetune"))
        profile_path = self.repo_dir / "colab" / "configs" / f"{profile}.json"
        if not profile_path.is_file():
            raise FileNotFoundError(f"Profile not found: {profile_path}")
        return json.loads(profile_path.read_text(encoding="utf-8"))

    def cancel_job(self, job_id: str) -> dict[str, Any]:
        with self.lock:
            job = self.jobs.get(job_id)
            if job is None:
                raise KeyError(job_id)
            process: subprocess.Popen[str] | None = job.get("process")
            if job["status"] in TERMINAL_STATES:
                return self._public_job(job)
            job["status"] = "cancelled"
            job["finished_at"] = utc_timestamp()

        if process and process.poll() is None:
            try:
                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
            except Exception:
                process.terminate()
        return self.get_job(job_id) or {}

    def _run_job(self, job_id: str) -> None:
        job = self.get_job_internal(job_id)
        if job is None:
            return
        log_path = Path(job["log_path"])
        try:
            with self.lock:
                job["status"] = "running"
                job["started_at"] = utc_timestamp()

            with log_path.open("w", encoding="utf-8", errors="replace") as log:
                log.write(f"job_id={job_id}\n")
                log.write(f"cwd={job['cwd']}\n")
                log.write("$ " + " ".join(job["cmd"]) + "\n\n")
                log.flush()

                process = subprocess.Popen(
                    job["cmd"],
                    cwd=job["cwd"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    encoding="utf-8",
                    errors="replace",
                    bufsize=1,
                    # Same effect as preexec_fn=os.setsid on POSIX, but safe to use
                    # from a multi-threaded process; cancel_job signals this group.
                    start_new_session=True,
                )
                with self.lock:
                    job["process"] = process

                assert process.stdout is not None
                for line in process.stdout:
                    log.write(line)
                    log.flush()
                    print(line, end="", flush=True)
                process.wait()

            with self.lock:
                job["returncode"] = process.returncode
                if job["status"] != "cancelled":
                    job["status"] = "success" if process.returncode == 0 else "failed"
                job["finished_at"] = utc_timestamp()
                job["process"] = None
        except Exception as exc:
            with log_path.open("a", encoding="utf-8", errors="replace") as log:
                traceback.print_exc(file=log)
            with self.lock:
                job["status"] = "failed"
                job["finished_at"] = utc_timestamp()
                job["error"] = f"{type(exc).__name__}: {exc}"
                job["process"] = None

    def _public_job(self, job: dict[str, Any]) -> dict[str, Any]:
        public = {key: value for key, value in job.items() if key != "process"}
        return public
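
The cancellation path in `cancel_job` relies on the job running in its own session, so that one signal reaches the whole process tree. The sketch below shows that pattern in isolation on a POSIX system; the sleeping child process is a stand-in for a training job, and the ten-second timeout is an arbitrary choice for the example.

```python
import os
import signal
import subprocess
import sys

# Stand-in for a long-running job; start_new_session=True makes the child the
# leader of a new session and process group (POSIX only).
proc = subprocess.Popen(
    [sys.executable, "-c", "import time; time.sleep(60)"],
    start_new_session=True,
)

try:
    # Signal the whole group so grandchildren are terminated as well.
    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except Exception:
    # Fallback for platforms without process groups: stop the direct child only.
    proc.terminate()

returncode = proc.wait(timeout=10)
print(returncode)
```

On POSIX, a child killed by a signal reports the negated signal number as its return code, which is why the worker cannot rely on the return code alone and records the `cancelled` status separately.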


def make_handler(state: WorkerState, token: str):
    class Handler(BaseHTTPRequestHandler):
        server_version = "AniFileBERTColabWorker/1.0"

        def log_message(self, fmt: str, *args: Any) -> None:
            print(f"[{utc_timestamp()}] {self.address_string()} {fmt % args}", flush=True)

        def do_GET(self) -> None:
            self._handle("GET")

        def do_POST(self) -> None:
            self._handle("POST")

        def _handle(self, method: str) -> None:
            parsed = urlparse(self.path)
            path = parsed.path.rstrip("/") or "/"
            parts = [part for part in path.split("/") if part]
            try:
                if not self._authorized():
                    self._send({"error": "unauthorized"}, HTTPStatus.UNAUTHORIZED)
                    return

                if method == "GET" and path == "/health":
                    # Look the active job up once so it cannot change between check and use.
                    active = state.active_job()
                    self._send(
                        {
                            "ok": True,
                            "repo_dir": os.fspath(state.repo_dir),
                            "jobs_dir": os.fspath(state.jobs_dir),
                            "active_job": active["job_id"] if active else None,
                        }
                    )
                    return

                if method == "GET" and path == "/jobs":
                    self._send({"jobs": state.list_jobs()})
                    return

                if method == "POST" and path == "/jobs":
                    payload = self._read_json()
                    job = state.start_job(payload)
                    self._send(job, HTTPStatus.ACCEPTED)
                    return

                if len(parts) >= 2 and parts[0] == "jobs":
                    job_id = parts[1]
                    if method == "GET" and len(parts) == 2:
                        job = state.get_job(job_id)
                        if job is None:
                            self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
                        else:
                            self._send(job)
                        return

                    if method == "GET" and len(parts) == 3 and parts[2] == "logs":
                        query = parse_qs(parsed.query)
                        tail = int(query.get("tail", ["200"])[0])
                        job = state.get_job_internal(job_id)
                        if job is None:
                            self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
                        else:
                            self._send({"job_id": job_id, "log": read_tail(Path(job["log_path"]), tail)})
                        return

                    if method == "GET" and len(parts) == 3 and parts[2] == "manifest":
|
| 311 |
+
job = state.get_job_internal(job_id)
|
| 312 |
+
if job is None:
|
| 313 |
+
self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
|
| 314 |
+
else:
|
| 315 |
+
manifest = self._find_manifest(job)
|
| 316 |
+
if manifest is None:
|
| 317 |
+
self._send({"error": "manifest not found"}, HTTPStatus.NOT_FOUND)
|
| 318 |
+
else:
|
| 319 |
+
self._send(json.loads(manifest.read_text(encoding="utf-8")))
|
| 320 |
+
return
|
| 321 |
+
|
| 322 |
+
if method == "POST" and len(parts) == 3 and parts[2] == "cancel":
|
| 323 |
+
try:
|
| 324 |
+
self._send(state.cancel_job(job_id))
|
| 325 |
+
except KeyError:
|
| 326 |
+
self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
|
| 327 |
+
return
|
| 328 |
+
|
| 329 |
+
self._send({"error": "not found"}, HTTPStatus.NOT_FOUND)
|
| 330 |
+
except Exception as exc:
|
| 331 |
+
traceback.print_exc()
|
| 332 |
+
self._send({"error": f"{type(exc).__name__}: {exc}"}, HTTPStatus.INTERNAL_SERVER_ERROR)
|
| 333 |
+
|
| 334 |
+
def _authorized(self) -> bool:
|
| 335 |
+
header = self.headers.get("Authorization", "")
|
| 336 |
+
if header == f"Bearer {token}":
|
| 337 |
+
return True
|
| 338 |
+
return self.headers.get("X-Colab-Token") == token
|
| 339 |
+
|
| 340 |
+
def _read_json(self) -> dict[str, Any]:
|
| 341 |
+
length = int(self.headers.get("Content-Length", "0"))
|
| 342 |
+
if length == 0:
|
| 343 |
+
return {}
|
| 344 |
+
raw = self.rfile.read(length)
|
| 345 |
+
return json.loads(raw.decode("utf-8"))
|
| 346 |
+
|
| 347 |
+
def _find_manifest(self, job: dict[str, Any]) -> Path | None:
|
| 348 |
+
config_path = job.get("config_path")
|
| 349 |
+
if config_path and Path(config_path).is_file():
|
| 350 |
+
config = json.loads(Path(config_path).read_text(encoding="utf-8"))
|
| 351 |
+
training = config.get("training", {})
|
| 352 |
+
save_dir = training.get("save_dir")
|
| 353 |
+
if save_dir:
|
| 354 |
+
manifest = Path(save_dir) / "colab_run_manifest.json"
|
| 355 |
+
if manifest.is_file():
|
| 356 |
+
return manifest
|
| 357 |
+
job_manifest = Path(job["job_dir"]) / "colab_run_manifest.json"
|
| 358 |
+
return job_manifest if job_manifest.is_file() else None
|
| 359 |
+
|
| 360 |
+
def _send(self, data: Any, status: HTTPStatus = HTTPStatus.OK) -> None:
|
| 361 |
+
raw = json_dumps(data).encode("utf-8")
|
| 362 |
+
self.send_response(status.value)
|
| 363 |
+
self.send_header("Content-Type", "application/json; charset=utf-8")
|
| 364 |
+
self.send_header("Content-Length", str(len(raw)))
|
| 365 |
+
self.end_headers()
|
| 366 |
+
self.wfile.write(raw)
|
| 367 |
+
|
| 368 |
+
return Handler
|
| 369 |
+
|
| 370 |
+
|
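The `_authorized` method above compares the token with `==`. As a possible hardening step, and not something this commit does, the comparison could use the standard library's `hmac.compare_digest`, which is designed to reduce timing side channels. The function name `is_authorized` and the plain `dict` of headers are assumptions for illustration. The real handler reads from `self.headers`, which is case-insensitive, while a plain dict is not.

```python
import hmac
from typing import Mapping


def is_authorized(headers: Mapping[str, str], token: str) -> bool:
    """Accept 'Authorization: Bearer <token>' or 'X-Colab-Token: <token>'."""
    if not token:
        # Never treat an empty configured token as a match.
        return False
    expected = token.encode("utf-8")

    bearer = headers.get("Authorization", "")
    if bearer.startswith("Bearer "):
        candidate = bearer[len("Bearer "):].encode("utf-8")
        if hmac.compare_digest(candidate, expected):
            return True

    fallback = headers.get("X-Colab-Token", "").encode("utf-8")
    return hmac.compare_digest(fallback, expected)


print(is_authorized({"Authorization": "Bearer s3cret"}, "s3cret"))
print(is_authorized({"X-Colab-Token": "wrong"}, "s3cret"))
```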
| 371 |
+
def start_tunnel(port: int, binary_path: Path) -> subprocess.Popen[str]:
|
| 372 |
+
cloudflared = download_cloudflared(binary_path)
|
| 373 |
+
cmd = [
|
| 374 |
+
os.fspath(cloudflared),
|
| 375 |
+
"tunnel",
|
| 376 |
+
"--url",
|
| 377 |
+
f"http://127.0.0.1:{port}",
|
| 378 |
+
"--no-autoupdate",
|
| 379 |
+
]
|
| 380 |
+
proc = subprocess.Popen(
|
| 381 |
+
cmd,
|
| 382 |
+
stdout=subprocess.PIPE,
|
| 383 |
+
stderr=subprocess.STDOUT,
|
| 384 |
+
text=True,
|
| 385 |
+
encoding="utf-8",
|
| 386 |
+
errors="replace",
|
| 387 |
+
bufsize=1,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def pump() -> None:
|
| 391 |
+
assert proc.stdout is not None
|
| 392 |
+
for line in proc.stdout:
|
| 393 |
+
print(line, end="", flush=True)
|
| 394 |
+
match = TUNNEL_URL_RE.search(line)
|
| 395 |
+
if match:
|
| 396 |
+
print("\nCOLAB_WORKER_URL=" + match.group(0), flush=True)
|
| 397 |
+
|
| 398 |
+
threading.Thread(target=pump, daemon=True).start()
|
| 399 |
+
return proc
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def parse_args() -> argparse.Namespace:
|
| 403 |
+
parser = argparse.ArgumentParser(description="Start the AniFileBERT Colab worker")
|
| 404 |
+
parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host")
|
| 405 |
+
parser.add_argument("--port", type=int, default=7860, help="HTTP bind port")
|
| 406 |
+
parser.add_argument("--repo-dir", default="/content/AniFileBERT", help="AniFileBERT checkout path in Colab")
|
| 407 |
+
parser.add_argument("--jobs-dir", default="/content/drive/MyDrive/AniFileBERT/worker/jobs")
|
| 408 |
+
parser.add_argument("--token", default=os.environ.get("ANIFILEBERT_COLAB_TOKEN"))
|
| 409 |
+
parser.add_argument("--tunnel", choices=["cloudflare", "none"], default="cloudflare")
|
| 410 |
+
parser.add_argument("--cloudflared-path", default="/tmp/anifilebert-cloudflared")
|
| 411 |
+
return parser.parse_args()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def main() -> None:
|
| 415 |
+
args = parse_args()
|
| 416 |
+
token = args.token or secrets.token_urlsafe(24)
|
| 417 |
+
repo_dir = Path(args.repo_dir)
|
| 418 |
+
if not repo_dir.is_dir():
|
| 419 |
+
raise RuntimeError(f"Repo directory does not exist: {repo_dir}")
|
| 420 |
+
|
| 421 |
+
state = WorkerState(repo_dir=repo_dir, jobs_dir=Path(args.jobs_dir))
|
| 422 |
+
server = ThreadingHTTPServer((args.host, args.port), make_handler(state, token))
|
| 423 |
+
tunnel_proc: subprocess.Popen[str] | None = None
|
| 424 |
+
|
| 425 |
+
print("=" * 72)
|
| 426 |
+
print("AniFileBERT Colab worker is starting")
|
| 427 |
+
print(f"Local URL: http://{args.host}:{args.port}")
|
| 428 |
+
print(f"COLAB_WORKER_TOKEN={token}")
|
| 429 |
+
print("Keep this Colab cell running while Codex uses the worker.")
|
| 430 |
+
print("=" * 72, flush=True)
|
| 431 |
+
|
| 432 |
+
if args.tunnel == "cloudflare":
|
| 433 |
+
tunnel_proc = start_tunnel(args.port, Path(args.cloudflared_path))
|
| 434 |
+
else:
|
| 435 |
+
print("Tunnel disabled. Use the local URL from inside the Colab runtime.", flush=True)
|
| 436 |
+
|
| 437 |
+
try:
|
| 438 |
+
server.serve_forever()
|
| 439 |
+
finally:
|
| 440 |
+
server.server_close()
|
| 441 |
+
if tunnel_proc and tunnel_proc.poll() is None:
|
| 442 |
+
tunnel_proc.terminate()
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
if __name__ == "__main__":
|
| 446 |
+
main()
|
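The handler in `colab_worker.py` serves `GET /health`, `GET /jobs`, `POST /jobs`, `GET /jobs/<id>`, `GET /jobs/<id>/logs?tail=N`, `GET /jobs/<id>/manifest`, and `POST /jobs/<id>/cancel`, each requiring the token. The sketch below shows how a caller might build such requests with `urllib`. It is not the repository's `colab_client.py`. The helper `build_request`, the job id, the base URL, and the token are placeholders chosen for illustration.

```python
import json
import urllib.request
from typing import Any, Optional


def build_request(base_url: str, path: str, token: str,
                  payload: Optional[dict[str, Any]] = None) -> urllib.request.Request:
    """Build a GET (no payload) or a JSON POST (payload given) for the worker."""
    url = base_url.rstrip("/") + path
    headers = {"Authorization": f"Bearer {token}"}
    data = None
    if payload is not None:
        data = json.dumps(payload).encode("utf-8")
        headers["Content-Type"] = "application/json"
    method = "POST" if payload is not None else "GET"
    return urllib.request.Request(url, data=data, headers=headers, method=method)


# Placeholder URL and token; a real run would use the values the worker prints.
req = build_request("https://example.invalid", "/jobs/abc123/logs?tail=50", "s3cret")
print(req.get_method(), req.full_url)
```

Passing the result to `urllib.request.urlopen` would send it. An empty payload such as `payload={}` yields a POST, which is enough for the cancel route because the handler does not read a body there.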
train.py
CHANGED
|
@@ -27,7 +27,7 @@ from transformers import (
|
|
| 27 |
from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
|
| 28 |
|
| 29 |
from config import Config
|
| 30 |
-
from tokenizer import AnimeTokenizer, create_tokenizer
|
| 31 |
from model import create_model, print_model_summary, count_parameters
|
| 32 |
from dataset import AnimeDataset, align_tokens_for_tokenizer
|
| 33 |
|
|
@@ -64,8 +64,8 @@ def compute_metrics(p):
|
|
| 64 |
|
| 65 |
def parse_args() -> argparse.Namespace:
|
| 66 |
parser = argparse.ArgumentParser(description="Train anime filename parser")
|
| 67 |
-
parser.add_argument("--tokenizer", choices=["regex", "char"], default=
|
| 68 |
-
help="Tokenizer variant for A/B testing")
|
| 69 |
parser.add_argument("--data-file", default=None, help="Training JSONL file")
|
| 70 |
parser.add_argument("--vocab-file", default=None,
|
| 71 |
help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
|
|
@@ -84,11 +84,58 @@ def parse_args() -> argparse.Namespace:
|
|
| 84 |
help="Rebuild vocab from the selected data file before training")
|
| 85 |
parser.add_argument("--max-vocab-size", type=int, default=None,
|
| 86 |
help="Optional vocab cap used with --rebuild-vocab")
|
| 87 |
parser.add_argument("--cpu", action="store_true", help="Force CPU training")
|
| 88 |
parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
|
| 89 |
return parser.parse_args()
|
| 90 |
|
| 91 |
|
| 92 |
def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str:
|
| 93 |
if explicit_path:
|
| 94 |
return explicit_path
|
|
@@ -96,6 +143,79 @@ def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Op
|
|
| 96 |
return os.path.join(os.path.dirname(data_file), name)
|
| 97 |
|
| 98 |
|
| 99 |
def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str,
|
| 100 |
max_size: Optional[int] = None) -> None:
|
| 101 |
token_lists: List[List[str]] = []
|
|
@@ -115,9 +235,10 @@ def main():
|
|
| 115 |
config = Config()
|
| 116 |
if args.data_file is not None:
|
| 117 |
config.data_file = args.data_file
|
| 118 |
if args.save_dir is not None:
|
| 119 |
config.save_dir = args.save_dir
|
| 120 |
-
elif
|
| 121 |
config.save_dir = "./checkpoints_char"
|
| 122 |
if args.epochs is not None:
|
| 123 |
config.num_epochs = args.epochs
|
|
@@ -131,6 +252,8 @@ def main():
|
|
| 131 |
config.train_split = args.train_split
|
| 132 |
if args.max_seq_length is not None:
|
| 133 |
config.max_seq_length = args.max_seq_length
|
| 134 |
|
| 135 |
random.seed(args.seed)
|
| 136 |
np.random.seed(args.seed)
|
|
@@ -143,18 +266,20 @@ def main():
|
|
| 143 |
all_data = all_data[:args.limit_samples]
|
| 144 |
if not args.no_shuffle:
|
| 145 |
random.shuffle(all_data)
|
| 146 |
|
| 147 |
# Load tokenizer
|
| 148 |
print("Loading tokenizer...")
|
| 149 |
-
vocab_path = resolve_vocab_path(config.data_file,
|
| 150 |
-
tokenizer = create_tokenizer(
|
| 151 |
if args.rebuild_vocab or not os.path.isfile(vocab_path):
|
| 152 |
max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
|
| 153 |
-
print(f" Building {
|
| 154 |
build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
|
| 155 |
-
tokenizer = create_tokenizer(
|
| 156 |
-
print(f" Variant: {
|
| 157 |
print(f" Vocab size: {tokenizer.vocab_size}")
|
| 158 |
|
| 159 |
# Update config with actual vocab size
|
| 160 |
config.vocab_size = tokenizer.vocab_size
|
|
@@ -163,9 +288,22 @@ def main():
|
|
| 163 |
if args.init_model_dir:
|
| 164 |
print(f"Loading model for fine-tuning: {args.init_model_dir}")
|
| 165 |
model = BertForTokenClassification.from_pretrained(args.init_model_dir)
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
model.config.num_labels = config.num_labels
|
| 170 |
model.config.id2label = config.id2label
|
| 171 |
model.config.label2id = config.label2id
|
|
@@ -212,6 +350,8 @@ def main():
|
|
| 212 |
use_cpu = args.cpu or not torch.cuda.is_available()
|
| 213 |
use_fp16 = not use_cpu
|
| 214 |
print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
|
| 215 |
|
| 216 |
# Training arguments
|
| 217 |
training_args = TrainingArguments(
|
|
@@ -220,15 +360,16 @@ def main():
|
|
| 220 |
per_device_train_batch_size=config.batch_size,
|
| 221 |
per_device_eval_batch_size=config.batch_size,
|
| 222 |
eval_strategy="epoch",
|
| 223 |
-
save_strategy=
|
| 224 |
logging_steps=config.log_interval,
|
| 225 |
learning_rate=config.learning_rate,
|
| 226 |
weight_decay=config.weight_decay,
|
| 227 |
warmup_steps=config.warmup_steps,
|
| 228 |
use_cpu=use_cpu,
|
| 229 |
report_to="none",
|
| 230 |
-
save_total_limit=
|
| 231 |
-
load_best_model_at_end=
|
| 232 |
metric_for_best_model="f1",
|
| 233 |
greater_is_better=True,
|
| 234 |
dataloader_num_workers=config.num_workers,
|
|
@@ -250,12 +391,19 @@ def main():
|
|
| 250 |
|
| 251 |
# Train
|
| 252 |
print("Starting training...")
|
| 253 |
-
|
| 254 |
|
| 255 |
# Set proper label mappings in model config before saving
|
| 256 |
model.config.id2label = config.id2label
|
| 257 |
model.config.label2id = config.label2id
|
| 258 |
-
model.config.tokenizer_variant =
|
| 259 |
model.config.max_seq_length = config.max_seq_length
|
| 260 |
|
| 261 |
# Save final model
|
|
|
|
| 27 |
from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
|
| 28 |
|
| 29 |
from config import Config
|
| 30 |
+
from tokenizer import AnimeTokenizer, create_tokenizer, load_tokenizer
|
| 31 |
from model import create_model, print_model_summary, count_parameters
|
| 32 |
from dataset import AnimeDataset, align_tokens_for_tokenizer
|
| 33 |
|
|
|
|
| 64 |
|
| 65 |
def parse_args() -> argparse.Namespace:
|
| 66 |
parser = argparse.ArgumentParser(description="Train anime filename parser")
|
| 67 |
+
parser.add_argument("--tokenizer", choices=["regex", "char"], default=None,
|
| 68 |
+
help="Tokenizer variant for A/B testing. Defaults to dataset metadata")
|
| 69 |
parser.add_argument("--data-file", default=None, help="Training JSONL file")
|
| 70 |
parser.add_argument("--vocab-file", default=None,
|
| 71 |
help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
|
|
|
|
| 84 |
help="Rebuild vocab from the selected data file before training")
|
| 85 |
parser.add_argument("--max-vocab-size", type=int, default=None,
|
| 86 |
help="Optional vocab cap used with --rebuild-vocab")
|
| 87 |
+
parser.add_argument("--checkpoint-steps", type=int, default=None,
|
| 88 |
+
help="Save resumable checkpoints every N steps instead of only at epoch end")
|
| 89 |
+
parser.add_argument("--save-total-limit", type=int, default=2,
|
| 90 |
+
help="Maximum number of checkpoints to keep")
|
| 91 |
parser.add_argument("--cpu", action="store_true", help="Force CPU training")
|
| 92 |
parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
|
| 93 |
+
parser.add_argument("--resume-from-checkpoint", default=None,
|
| 94 |
+
help="Resume Trainer state from a checkpoint directory, or 'auto' for the latest checkpoint")
|
| 95 |
return parser.parse_args()
|
| 96 |
|
| 97 |
|
| 98 |
+
def detect_tokenizer_variant(
|
| 99 |
+
data_file: str,
|
| 100 |
+
explicit_variant: Optional[str],
|
| 101 |
+
explicit_vocab_path: Optional[str],
|
| 102 |
+
sample_size: int = 256,
|
| 103 |
+
) -> str:
|
| 104 |
+
"""Infer tokenizer variant from CLI, dataset metadata, or vocab filename."""
|
| 105 |
+
if explicit_variant:
|
| 106 |
+
return explicit_variant
|
| 107 |
+
|
| 108 |
+
variants = set()
|
| 109 |
+
char_like = 0
|
| 110 |
+
inspected = 0
|
| 111 |
+
with open(data_file, "r", encoding="utf-8") as f:
|
| 112 |
+
for line in f:
|
| 113 |
+
if inspected >= sample_size:
|
| 114 |
+
break
|
| 115 |
+
line = line.strip()
|
| 116 |
+
if not line:
|
| 117 |
+
continue
|
| 118 |
+
item = json.loads(line)
|
| 119 |
+
inspected += 1
|
| 120 |
+
variant = item.get("tokenizer_variant")
|
| 121 |
+
if variant:
|
| 122 |
+
variants.add(variant)
|
| 123 |
+
tokens = item.get("tokens", [])
|
| 124 |
+
filename = item.get("filename")
|
| 125 |
+
if filename is not None and tokens == list(filename):
|
| 126 |
+
char_like += 1
|
| 127 |
+
|
| 128 |
+
if len(variants) == 1:
|
| 129 |
+
return next(iter(variants))
|
| 130 |
+
if len(variants) > 1:
|
| 131 |
+
raise ValueError(f"Mixed tokenizer_variant values in {data_file}: {sorted(variants)}")
|
| 132 |
+
if explicit_vocab_path and ".char" in os.path.basename(explicit_vocab_path).lower():
|
| 133 |
+
return "char"
|
| 134 |
+
if inspected and char_like / inspected >= 0.95:
|
| 135 |
+
return "char"
|
| 136 |
+
return "regex"
|
| 137 |
+
|
| 138 |
+
|
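When no `tokenizer_variant` metadata or `.char` vocab filename is available, `detect_tokenizer_variant` falls back to a heuristic: a record counts as character-level when its `tokens` list equals `list(filename)`, and the variant is `"char"` once at least 95% of the inspected records match. The sketch below isolates that ratio. The helper `char_like_ratio` and the sample records are made up for illustration and are not taken from the repository's data.

```python
from typing import Dict, List


def char_like_ratio(records: List[Dict]) -> float:
    """Fraction of records whose tokens are exactly the filename's characters."""
    if not records:
        return 0.0
    hits = sum(
        1
        for item in records
        if item.get("filename") is not None
        and item.get("tokens", []) == list(item["filename"])
    )
    return hits / len(records)


# Made-up sample records.
char_records = [{"filename": "[A] B 01", "tokens": list("[A] B 01")}]
regex_records = [{"filename": "[A] B 01", "tokens": ["[", "A", "]", "B", "01"]}]

print(char_like_ratio(char_records))   # every record matches
print(char_like_ratio(regex_records))  # no record matches
```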
| 139 |
def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str:
|
| 140 |
if explicit_path:
|
| 141 |
return explicit_path
|
|
|
|
| 143 |
return os.path.join(os.path.dirname(data_file), name)
|
| 144 |
|
| 145 |
|
| 146 |
+
def latest_checkpoint(save_dir: str) -> Optional[str]:
|
| 147 |
+
if not os.path.isdir(save_dir):
|
| 148 |
+
return None
|
| 149 |
+
checkpoints = []
|
| 150 |
+
for name in os.listdir(save_dir):
|
| 151 |
+
if not name.startswith("checkpoint-"):
|
| 152 |
+
continue
|
| 153 |
+
path = os.path.join(save_dir, name)
|
| 154 |
+
if not os.path.isdir(path):
|
| 155 |
+
continue
|
| 156 |
+
try:
|
| 157 |
+
step = int(name.split("-")[-1])
|
| 158 |
+
except ValueError:
|
| 159 |
+
continue
|
| 160 |
+
checkpoints.append((step, path))
|
| 161 |
+
if not checkpoints:
|
| 162 |
+
return None
|
| 163 |
+
return max(checkpoints)[1]
|
| 164 |
+
|
| 165 |
+
|
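`latest_checkpoint` parses the step number as an integer before taking the maximum. That matters because plain string ordering would rank `checkpoint-900` above `checkpoint-1000`. The following sketch is a close variant of that logic rather than a copy: `pick_latest` is a hypothetical name, and it only accepts names of the exact form `checkpoint-<digits>`.

```python
import os
import tempfile
from typing import Optional


def pick_latest(save_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> directory with the largest numeric step."""
    best: Optional[tuple[int, str]] = None
    for name in os.listdir(save_dir):
        prefix, _, suffix = name.rpartition("-")
        path = os.path.join(save_dir, name)
        if prefix != "checkpoint" or not suffix.isdecimal() or not os.path.isdir(path):
            continue
        candidate = (int(suffix), path)
        if best is None or candidate > best:
            best = candidate
    return best[1] if best else None


with tempfile.TemporaryDirectory() as root:
    for step in (900, 1000, 50):
        os.mkdir(os.path.join(root, f"checkpoint-{step}"))
    latest = os.path.basename(pick_latest(root))
    naive = max(os.listdir(root))  # string comparison, shown only for contrast

print(latest)
print(naive)
```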
| 166 |
+
def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None:
|
| 167 |
+
variants = {item.get("tokenizer_variant") for item in data if item.get("tokenizer_variant")}
|
| 168 |
+
if variants and variants != {tokenizer_variant}:
|
| 169 |
+
raise ValueError(
|
| 170 |
+
f"Dataset tokenizer_variant {sorted(variants)} does not match selected tokenizer "
|
| 171 |
+
f"'{tokenizer_variant}'. Pass --tokenizer explicitly only when this is intentional."
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def remap_token_embeddings(
|
| 176 |
+
model: BertForTokenClassification,
|
| 177 |
+
old_vocab: Dict[str, int],
|
| 178 |
+
new_vocab: Dict[str, int],
|
| 179 |
+
pad_token_id: int,
|
| 180 |
+
) -> int:
|
| 181 |
+
"""
|
| 182 |
+
Replace the input embedding table for a changed vocabulary.
|
| 183 |
+
|
| 184 |
+
resize_token_embeddings() preserves rows by numeric ID, which is unsafe when
|
| 185 |
+
two tokenizers assign different tokens to the same ID. This remaps by token
|
| 186 |
+
string and randomly initializes tokens that do not exist in the old vocab.
|
| 187 |
+
"""
|
| 188 |
+
old_embeddings = model.get_input_embeddings()
|
| 189 |
+
old_weight = old_embeddings.weight.data
|
| 190 |
+
embedding_dim = old_weight.shape[1]
|
| 191 |
+
new_embeddings = torch.nn.Embedding(
|
| 192 |
+
len(new_vocab),
|
| 193 |
+
embedding_dim,
|
| 194 |
+
padding_idx=pad_token_id,
|
| 195 |
+
device=old_weight.device,
|
| 196 |
+
dtype=old_weight.dtype,
|
| 197 |
+
)
|
| 198 |
+
torch.nn.init.normal_(
|
| 199 |
+
new_embeddings.weight,
|
| 200 |
+
mean=0.0,
|
| 201 |
+
std=getattr(model.config, "initializer_range", 0.02),
|
| 202 |
+
)
|
| 203 |
+
if pad_token_id is not None and 0 <= pad_token_id < len(new_vocab):
|
| 204 |
+
new_embeddings.weight.data[pad_token_id].zero_()
|
| 205 |
+
|
| 206 |
+
copied = 0
|
| 207 |
+
for token, new_id in new_vocab.items():
|
| 208 |
+
old_id = old_vocab.get(token)
|
| 209 |
+
if old_id is None or old_id >= old_weight.shape[0]:
|
| 210 |
+
continue
|
| 211 |
+
new_embeddings.weight.data[new_id].copy_(old_weight[old_id])
|
| 212 |
+
copied += 1
|
| 213 |
+
|
| 214 |
+
model.set_input_embeddings(new_embeddings)
|
| 215 |
+
model.config.vocab_size = len(new_vocab)
|
| 216 |
+
return copied
|
| 217 |
+
|
| 218 |
+
|
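The docstring of `remap_token_embeddings` notes that preserving rows by numeric ID is unsafe when two vocabularies assign different tokens to the same ID. A dependency-free sketch of remapping by token string is shown below, using plain lists instead of `torch` tensors. The helper `remap_rows` and both vocabularies are invented for illustration. Unmatched tokens are left as `None` here, whereas the function in the diff initializes them randomly.

```python
from typing import Dict, List, Optional

Row = List[float]


def remap_rows(old_rows: List[Row], old_vocab: Dict[str, int],
               new_vocab: Dict[str, int]) -> List[Optional[Row]]:
    """Place each old row at the ID the new vocab assigns to the same token.

    Assumes new_vocab IDs are the contiguous range 0..len(new_vocab)-1.
    """
    new_rows: List[Optional[Row]] = [None] * len(new_vocab)
    for token, new_id in new_vocab.items():
        old_id = old_vocab.get(token)
        if old_id is not None and old_id < len(old_rows):
            new_rows[new_id] = list(old_rows[old_id])
    return new_rows


# Made-up vocabularies: "01" and "[" swap IDs, and "]" is new.
old_vocab = {"[PAD]": 0, "01": 1, "[": 2}
new_vocab = {"[PAD]": 0, "[": 1, "01": 2, "]": 3}
old_rows = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.2]]

remapped = remap_rows(old_rows, old_vocab, new_vocab)
print(remapped)
```

Copying by ID instead would have handed the row learned for `"01"` to `"["`, which is the failure mode the docstring warns about.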
| 219 |
def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str,
|
| 220 |
max_size: Optional[int] = None) -> None:
|
| 221 |
token_lists: List[List[str]] = []
|
|
|
|
| 235 |
config = Config()
|
| 236 |
if args.data_file is not None:
|
| 237 |
config.data_file = args.data_file
|
| 238 |
+
tokenizer_variant = detect_tokenizer_variant(config.data_file, args.tokenizer, args.vocab_file)
|
| 239 |
if args.save_dir is not None:
|
| 240 |
config.save_dir = args.save_dir
|
| 241 |
+
elif tokenizer_variant == "char":
|
| 242 |
config.save_dir = "./checkpoints_char"
|
| 243 |
if args.epochs is not None:
|
| 244 |
config.num_epochs = args.epochs
|
|
|
|
| 252 |
config.train_split = args.train_split
|
| 253 |
if args.max_seq_length is not None:
|
| 254 |
config.max_seq_length = args.max_seq_length
|
| 255 |
+
elif tokenizer_variant == "char":
|
| 256 |
+
config.max_seq_length = max(config.max_seq_length, 128)
|
| 257 |
|
| 258 |
random.seed(args.seed)
|
| 259 |
np.random.seed(args.seed)
|
|
|
|
| 266 |
all_data = all_data[:args.limit_samples]
|
| 267 |
if not args.no_shuffle:
|
| 268 |
random.shuffle(all_data)
|
| 269 |
+
validate_dataset_tokenizer_metadata(all_data, tokenizer_variant)
|
| 270 |
|
| 271 |
# Load tokenizer
|
| 272 |
print("Loading tokenizer...")
|
| 273 |
+
vocab_path = resolve_vocab_path(config.data_file, tokenizer_variant, args.vocab_file)
|
| 274 |
+
tokenizer = create_tokenizer(tokenizer_variant)
|
| 275 |
if args.rebuild_vocab or not os.path.isfile(vocab_path):
|
| 276 |
max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
|
| 277 |
+
print(f" Building {tokenizer_variant} vocab: {vocab_path} (max_size={max_vocab_size})")
|
| 278 |
build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
|
| 279 |
+
tokenizer = create_tokenizer(tokenizer_variant, vocab_file=vocab_path)
|
| 280 |
+
print(f" Variant: {tokenizer_variant}")
|
| 281 |
print(f" Vocab size: {tokenizer.vocab_size}")
|
| 282 |
+
print(f" Max sequence length: {config.max_seq_length}")
|
| 283 |
|
| 284 |
# Update config with actual vocab size
|
| 285 |
config.vocab_size = tokenizer.vocab_size
|
|
|
|
| 288 |
if args.init_model_dir:
|
| 289 |
print(f"Loading model for fine-tuning: {args.init_model_dir}")
|
| 290 |
model = BertForTokenClassification.from_pretrained(args.init_model_dir)
|
| 291 |
+
init_tokenizer = load_tokenizer(args.init_model_dir)
|
| 292 |
+
init_variant = getattr(init_tokenizer, "tokenizer_variant", None)
|
| 293 |
+
if init_variant != tokenizer_variant:
|
| 294 |
+
print(f" WARNING: tokenizer variant changes during fine-tune: {init_variant} -> {tokenizer_variant}")
|
| 295 |
+
print(" Token embeddings will be remapped by token string; unmatched tokens are newly initialized.")
|
| 296 |
+
if model.config.vocab_size != config.vocab_size or init_tokenizer.get_vocab() != tokenizer.get_vocab():
|
| 297 |
+
copied = remap_token_embeddings(
|
| 298 |
+
model=model,
|
| 299 |
+
old_vocab=init_tokenizer.get_vocab(),
|
| 300 |
+
new_vocab=tokenizer.get_vocab(),
|
| 301 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 302 |
+
)
|
| 303 |
+
print(
|
| 304 |
+
f" Remapped token embeddings: copied {copied:,}/{config.vocab_size:,} "
|
| 305 |
+
f"tokens from init checkpoint"
|
| 306 |
+
)
|
| 307 |
model.config.num_labels = config.num_labels
|
| 308 |
model.config.id2label = config.id2label
|
| 309 |
model.config.label2id = config.label2id
|
|
|
|
| 350 |
use_cpu = args.cpu or not torch.cuda.is_available()
|
| 351 |
use_fp16 = not use_cpu
|
| 352 |
print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
|
| 353 |
+
save_strategy = "steps" if args.checkpoint_steps else "epoch"
|
| 354 |
+
load_best_model_at_end = args.checkpoint_steps is None
|
| 355 |
|
| 356 |
# Training arguments
|
| 357 |
training_args = TrainingArguments(
|
|
|
|
| 360 |
per_device_train_batch_size=config.batch_size,
|
| 361 |
per_device_eval_batch_size=config.batch_size,
|
| 362 |
eval_strategy="epoch",
|
| 363 |
+
save_strategy=save_strategy,
|
| 364 |
+
save_steps=args.checkpoint_steps,
|
| 365 |
logging_steps=config.log_interval,
|
| 366 |
learning_rate=config.learning_rate,
|
| 367 |
weight_decay=config.weight_decay,
|
| 368 |
warmup_steps=config.warmup_steps,
|
| 369 |
use_cpu=use_cpu,
|
| 370 |
report_to="none",
|
| 371 |
+
save_total_limit=args.save_total_limit,
|
| 372 |
+
load_best_model_at_end=load_best_model_at_end,
|
| 373 |
metric_for_best_model="f1",
|
| 374 |
greater_is_better=True,
|
| 375 |
dataloader_num_workers=config.num_workers,
|
|
|
|
| 391 |
|
| 392 |
# Train
|
| 393 |
print("Starting training...")
|
| 394 |
+
resume_from_checkpoint = args.resume_from_checkpoint
|
| 395 |
+
if resume_from_checkpoint == "auto":
|
| 396 |
+
resume_from_checkpoint = latest_checkpoint(config.save_dir)
|
| 397 |
+
if resume_from_checkpoint:
|
| 398 |
+
print(f"Resuming from latest checkpoint: {resume_from_checkpoint}")
|
| 399 |
+
else:
|
| 400 |
+
print("No checkpoint found; starting a fresh training run.")
|
| 401 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
| 402 |
|
| 403 |
# Set proper label mappings in model config before saving
|
| 404 |
model.config.id2label = config.id2label
|
| 405 |
model.config.label2id = config.label2id
|
| 406 |
+
model.config.tokenizer_variant = tokenizer_variant
|
| 407 |
model.config.max_seq_length = config.max_seq_length
|
| 408 |
|
| 409 |
# Save final model
|