ModerRAS commited on
Commit
e458112
·
1 Parent(s): beb8c7e

Add Codex Colab training workflow

Browse files
AGENTS.md CHANGED
@@ -67,6 +67,66 @@ Export for Android:
67
  python export_onnx.py --model-dir checkpoints/dmhy-finetune/final --android-assets-dir ../../scraper/src/main/assets/anime_parser
68
  ```
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  ## Validation Expectations
71
 
72
  - For parser or tokenizer changes, run `python inference.py --model-dir . ...`
 
67
  python export_onnx.py --model-dir checkpoints/dmhy-finetune/final --android-assets-dir ../../scraper/src/main/assets/anime_parser
68
  ```
69
 
70
+ ## Codex-Controlled Colab Training
71
+
72
+ Free Colab cannot be treated as an always-on remote machine. Use it as a
73
+ short-lived GPU worker only after the user manually opens a Colab runtime and
74
+ starts the worker cell. Do not assume Codex can wake Colab by itself.
75
+
76
+ Before relying on the Colab flow, make sure the Colab helper files have been
77
+ pushed to the Hugging Face model repo, or the user has uploaded them manually:
78
+ `colab_worker.py`, `colab_client.py`, `colab_train.py`, and `colab/`.
79
+
80
+ Ask the user to start a Colab GPU runtime with:
81
+
82
+ ```python
83
+ from google.colab import drive
84
+ drive.mount("/content/drive")
85
+
86
+ !git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT /content/AniFileBERT || true
87
+ %cd /content/AniFileBERT
88
+ !git pull --ff-only || true
89
+ !git submodule update --init --recursive
90
+ !python colab_worker.py
91
+ ```
92
+
93
+ The worker prints `COLAB_WORKER_URL=...` and `COLAB_WORKER_TOKEN=...`. After
94
+ the user provides those values, set them for local commands:
95
+
96
+ ```powershell
97
+ $env:ANIFILEBERT_COLAB_URL="https://...trycloudflare.com"
98
+ $env:ANIFILEBERT_COLAB_TOKEN="..."
99
+ python colab_client.py health
100
+ ```
101
+
102
+ Submit the default regex fine-tune:
103
+
104
+ ```powershell
105
+ python colab_client.py submit --profile dmhy_regex_finetune --wait
106
+ ```
107
+
108
+ Submit the character tokenizer run only when intentional:
109
+
110
+ ```powershell
111
+ python colab_client.py submit --profile dmhy_char_train --wait
112
+ ```
113
+
114
+ Useful follow-up commands:
115
+
116
+ ```powershell
117
+ python colab_client.py jobs
118
+ python colab_client.py status <job-id>
119
+ python colab_client.py logs <job-id> --tail 200
120
+ python colab_client.py manifest <job-id>
121
+ python colab_client.py cancel <job-id>
122
+ ```
123
+
124
+ The default Colab profiles save checkpoints to Google Drive every 1000 steps
125
+ and resume with `resume_from_checkpoint: "auto"`, so if free Colab disconnects,
126
+ ask the user to restart the worker and submit the same profile again. Artifacts
127
+ land under `MyDrive/AniFileBERT/checkpoints/<profile-name>/`, and worker logs
128
+ land under `MyDrive/AniFileBERT/worker/jobs/<job-id>/`.
129
+
130
  ## Validation Expectations
131
 
132
  - For parser or tokenizer changes, run `python inference.py --model-dir . ...`
README.md CHANGED
@@ -199,9 +199,17 @@ python export_onnx.py --model-dir checkpoints/dmhy-finetune/final --output expor
199
 
200
  ## Google Colab Training
201
 
202
- Upload and run [`colab_train.py`](colab_train.py) in a Colab GPU runtime.
203
- It will mount Google Drive, clone both repos, install dependencies, and run
204
- the full training pipeline. Checkpoints are saved to your Drive automatically.
 
 
 
 
 
 
 
 
205
 
206
  ## Repository Layout
207
 
 
199
 
200
  ## Google Colab Training
201
 
202
+ For Codex-controlled short Colab sessions, see [`colab/README.md`](colab/README.md).
203
+ Free Colab still has to be started manually, but once `colab_worker.py` is
204
+ running Codex can submit jobs through `colab_client.py`, tail logs, and inspect
205
+ status. Checkpoints live on Google Drive and default profiles resume from the
206
+ latest checkpoint automatically.
207
+
208
+ Manual one-shot runs are also supported:
209
+
210
+ ```bash
211
+ python colab_train.py --profile dmhy_regex_finetune
212
+ ```
213
 
214
  ## Repository Layout
215
 
colab/README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codex + Colab Training
2
+
3
+ Free Colab cannot be used as an always-on remote machine. The practical setup is:
4
+
5
+ 1. Open a Colab GPU runtime when you want to train.
6
+ 2. Start the lightweight worker in one cell.
7
+ 3. Give Codex the printed worker URL and token.
8
+ 4. Codex submits jobs while that Colab session is alive.
9
+ 5. Checkpoints and manifests stay on Google Drive, so the next session can resume.
10
+
11
+ ## Start a Colab Session
12
+
13
+ Run this in a Colab code cell:
14
+
15
+ ```python
16
+ from google.colab import drive
17
+ drive.mount("/content/drive")
18
+
19
+ !git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT /content/AniFileBERT || true
20
+ %cd /content/AniFileBERT
21
+ !git pull --ff-only || true
22
+ !git submodule update --init --recursive
23
+ !python colab_worker.py
24
+ ```
25
+
26
+ The cell prints:
27
+
28
+ ```text
29
+ COLAB_WORKER_URL=https://...trycloudflare.com
30
+ COLAB_WORKER_TOKEN=...
31
+ ```
32
+
33
+ Keep that cell running. If Colab disconnects, start it again; default profiles
34
+ save every 1000 steps and resume from the latest Drive checkpoint because they
35
+ use `checkpoint_steps: 1000` and `resume_from_checkpoint: "auto"`.
36
+
37
+ ## Let Codex Submit a Job
38
+
39
+ On the local machine:
40
+
41
+ ```powershell
42
+ $env:ANIFILEBERT_COLAB_URL="https://...trycloudflare.com"
43
+ $env:ANIFILEBERT_COLAB_TOKEN="..."
44
+ python colab_client.py health
45
+ python colab_client.py submit --profile dmhy_regex_finetune --wait
46
+ ```
47
+
48
+ Codex can run the same commands from this repository after you provide the URL
49
+ and token.
50
+
51
+ ## Profiles
52
+
53
+ - `colab/configs/dmhy_regex_finetune.json`: default regex tokenizer fine-tune
54
+ from the published root checkpoint.
55
+ - `colab/configs/dmhy_char_train.json`: character tokenizer training from
56
+ scratch.
57
+
58
+ You can submit a local edited profile instead of a remote profile:
59
+
60
+ ```powershell
61
+ python colab_client.py submit --config colab/configs/dmhy_regex_finetune.json --wait
62
+ ```
63
+
64
+ The worker writes per-job logs under:
65
+
66
+ ```text
67
+ MyDrive/AniFileBERT/worker/jobs/<job-id>/
68
+ ```
69
+
70
+ The training runner writes:
71
+
72
+ ```text
73
+ MyDrive/AniFileBERT/checkpoints/<profile-name>/
74
+ MyDrive/AniFileBERT/last_run_manifest.json
75
+ ```
colab/configs/dmhy_char_train.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "dmhy-char-train",
3
+ "repo_url": "https://huggingface.co/ModerRAS/AniFileBERT",
4
+ "repo_ref": "main",
5
+ "repo_dir": "/content/AniFileBERT",
6
+ "drive_root": "/content/drive/MyDrive/AniFileBERT",
7
+ "mount_drive": true,
8
+ "pull": true,
9
+ "install": {
10
+ "requirements": true,
11
+ "git_lfs": true,
12
+ "extra_packages": []
13
+ },
14
+ "training": {
15
+ "tokenizer": "char",
16
+ "data_file": "datasets/AnimeName/dmhy_weak_char.jsonl",
17
+ "vocab_file": "datasets/AnimeName/vocab.char.json",
18
+ "save_dir": "{drive_root}/checkpoints/{name}",
19
+ "init_model_dir": null,
20
+ "epochs": 1,
21
+ "batch_size": 128,
22
+ "learning_rate": 0.0003,
23
+ "warmup_steps": 300,
24
+ "train_split": 0.9,
25
+ "max_seq_length": 128,
26
+ "seed": 42,
27
+ "resume_from_checkpoint": "auto",
28
+ "checkpoint_steps": 1000,
29
+ "save_total_limit": 3
30
+ },
31
+ "export": {
32
+ "enabled": true,
33
+ "required": false,
34
+ "output": "{save_dir}/exports/anime_filename_parser.onnx",
35
+ "max_length": "{max_seq_length}"
36
+ },
37
+ "smoke": {
38
+ "enabled": true,
39
+ "required": true,
40
+ "sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub"
41
+ }
42
+ }
colab/configs/dmhy_regex_finetune.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "dmhy-regex-finetune",
3
+ "repo_url": "https://huggingface.co/ModerRAS/AniFileBERT",
4
+ "repo_ref": "main",
5
+ "repo_dir": "/content/AniFileBERT",
6
+ "drive_root": "/content/drive/MyDrive/AniFileBERT",
7
+ "mount_drive": true,
8
+ "pull": true,
9
+ "install": {
10
+ "requirements": true,
11
+ "git_lfs": true,
12
+ "extra_packages": []
13
+ },
14
+ "training": {
15
+ "tokenizer": "regex",
16
+ "data_file": "datasets/AnimeName/dmhy_weak.jsonl",
17
+ "vocab_file": "datasets/AnimeName/vocab.json",
18
+ "save_dir": "{drive_root}/checkpoints/{name}",
19
+ "init_model_dir": ".",
20
+ "epochs": 1,
21
+ "batch_size": 128,
22
+ "learning_rate": 0.0003,
23
+ "warmup_steps": 300,
24
+ "train_split": 0.9,
25
+ "max_seq_length": 64,
26
+ "seed": 42,
27
+ "resume_from_checkpoint": "auto",
28
+ "checkpoint_steps": 1000,
29
+ "save_total_limit": 3
30
+ },
31
+ "export": {
32
+ "enabled": true,
33
+ "required": false,
34
+ "output": "{save_dir}/exports/anime_filename_parser.onnx",
35
+ "max_length": "{max_seq_length}"
36
+ },
37
+ "smoke": {
38
+ "enabled": true,
39
+ "required": true,
40
+ "sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub"
41
+ }
42
+ }
colab/start_worker.ipynb ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "metadata": {},
22
+ "source": [
23
+ "# AniFileBERT Colab Worker\n",
24
+ "\n",
25
+ "Run the next cell in a GPU runtime. Keep it running while Codex submits training jobs. If free Colab disconnects, open the notebook again and rerun the cell; default profiles resume from the latest Drive checkpoint."
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "from google.colab import drive\n",
35
+ "drive.mount('/content/drive')\n",
36
+ "\n",
37
+ "!git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT /content/AniFileBERT || true\n",
38
+ "%cd /content/AniFileBERT\n",
39
+ "!git pull --ff-only || true\n",
40
+ "!git submodule update --init --recursive\n",
41
+ "!python colab_worker.py\n"
42
+ ]
43
+ }
44
+ ]
45
+ }
colab_client.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Local client for controlling an active AniFileBERT Colab worker.
3
+
4
+ The worker still has to be started manually in Colab, but once it prints a
5
+ public URL and token this client lets Codex submit training jobs, tail logs, and
6
+ inspect status from the local workspace.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import json
13
+ import os
14
+ from pathlib import Path
15
+ import sys
16
+ import time
17
+ from typing import Any
18
+ import urllib.error
19
+ import urllib.parse
20
+ import urllib.request
21
+
22
+
23
+ TERMINAL_STATES = {"success", "failed", "cancelled"}
24
+
25
+
26
+ def load_json(path: str) -> Any:
27
+ return json.loads(Path(path).read_text(encoding="utf-8"))
28
+
29
+
30
+ class ColabClient:
31
+ def __init__(self, base_url: str, token: str, timeout: int = 30):
32
+ self.base_url = base_url.rstrip("/")
33
+ self.token = token
34
+ self.timeout = timeout
35
+
36
+ def request(self, method: str, path: str, payload: Any | None = None) -> Any:
37
+ url = self.base_url + path
38
+ data = None
39
+ headers = {"Authorization": f"Bearer {self.token}"}
40
+ if payload is not None:
41
+ data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
42
+ headers["Content-Type"] = "application/json; charset=utf-8"
43
+
44
+ req = urllib.request.Request(url, data=data, headers=headers, method=method)
45
+ try:
46
+ with urllib.request.urlopen(req, timeout=self.timeout) as response:
47
+ return json.loads(response.read().decode("utf-8"))
48
+ except urllib.error.HTTPError as exc:
49
+ body = exc.read().decode("utf-8", errors="replace")
50
+ raise RuntimeError(f"{method} {url} failed: HTTP {exc.code}: {body}") from exc
51
+
52
+ def health(self) -> Any:
53
+ return self.request("GET", "/health")
54
+
55
+ def submit(self, payload: dict[str, Any]) -> Any:
56
+ return self.request("POST", "/jobs", payload)
57
+
58
+ def jobs(self) -> Any:
59
+ return self.request("GET", "/jobs")
60
+
61
+ def status(self, job_id: str) -> Any:
62
+ return self.request("GET", f"/jobs/{job_id}")
63
+
64
+ def logs(self, job_id: str, tail: int) -> Any:
65
+ query = urllib.parse.urlencode({"tail": tail})
66
+ return self.request("GET", f"/jobs/{job_id}/logs?{query}")
67
+
68
+ def manifest(self, job_id: str) -> Any:
69
+ return self.request("GET", f"/jobs/{job_id}/manifest")
70
+
71
+ def cancel(self, job_id: str) -> Any:
72
+ return self.request("POST", f"/jobs/{job_id}/cancel", {})
73
+
74
+
75
+ def print_json(data: Any) -> None:
76
+ print(json.dumps(data, ensure_ascii=False, indent=2))
77
+
78
+
79
+ def require_connection(args: argparse.Namespace) -> ColabClient:
80
+ url = args.url or os.environ.get("ANIFILEBERT_COLAB_URL")
81
+ token = args.token or os.environ.get("ANIFILEBERT_COLAB_TOKEN")
82
+ if not url or not token:
83
+ raise SystemExit(
84
+ "Set ANIFILEBERT_COLAB_URL and ANIFILEBERT_COLAB_TOKEN, "
85
+ "or pass --url and --token."
86
+ )
87
+ return ColabClient(url, token, timeout=args.timeout)
88
+
89
+
90
+ def build_submit_payload(args: argparse.Namespace) -> dict[str, Any]:
91
+ payload: dict[str, Any] = {}
92
+ if args.config:
93
+ payload["config"] = load_json(args.config)
94
+ if args.profile:
95
+ payload["profile"] = args.profile
96
+ extra_args = list(args.args or []) + list(args.extra_args or [])
97
+ if extra_args:
98
+ payload["args"] = extra_args
99
+ if not payload:
100
+ payload["profile"] = "dmhy_regex_finetune"
101
+ return payload
102
+
103
+
104
+ def wait_for_job(client: ColabClient, job_id: str, poll: int, tail: int) -> dict[str, Any]:
105
+ last_status = None
106
+ while True:
107
+ status = client.status(job_id)
108
+ if status.get("status") != last_status:
109
+ print_json(status)
110
+ last_status = status.get("status")
111
+ logs = client.logs(job_id, tail=tail)
112
+ log_text = logs.get("log", "")
113
+ if log_text:
114
+ print("\n--- log tail ---")
115
+ print(log_text.rstrip())
116
+ if status.get("status") in TERMINAL_STATES:
117
+ return status
118
+ time.sleep(poll)
119
+
120
+
121
+ def parse_args() -> argparse.Namespace:
122
+ parser = argparse.ArgumentParser(description="Control an active AniFileBERT Colab worker")
123
+ parser.add_argument("--url", help="Worker URL, or ANIFILEBERT_COLAB_URL")
124
+ parser.add_argument("--token", help="Worker token, or ANIFILEBERT_COLAB_TOKEN")
125
+ parser.add_argument("--timeout", type=int, default=30)
126
+
127
+ subparsers = parser.add_subparsers(dest="command", required=True)
128
+
129
+ subparsers.add_parser("health", help="Check worker health")
130
+ subparsers.add_parser("jobs", help="List known jobs")
131
+
132
+ submit = subparsers.add_parser("submit", help="Submit a training job")
133
+ submit.add_argument("--config", help="Local JSON config to send to the worker")
134
+ submit.add_argument("--profile", help="Remote profile name under colab/configs")
135
+ submit.add_argument("--arg", dest="args", action="append", default=[], help="Extra arg for colab_train.py")
136
+ submit.add_argument("--wait", action="store_true", help="Poll until the job finishes")
137
+ submit.add_argument("--poll", type=int, default=60, help="Polling interval in seconds")
138
+ submit.add_argument("--tail", type=int, default=80, help="Log lines to show while waiting")
139
+ submit.add_argument("extra_args", nargs=argparse.REMAINDER,
140
+ help="Arguments after -- are passed to colab_train.py")
141
+
142
+ status = subparsers.add_parser("status", help="Show job status")
143
+ status.add_argument("job_id")
144
+
145
+ logs = subparsers.add_parser("logs", help="Show job logs")
146
+ logs.add_argument("job_id")
147
+ logs.add_argument("--tail", type=int, default=200)
148
+
149
+ manifest = subparsers.add_parser("manifest", help="Show job manifest")
150
+ manifest.add_argument("job_id")
151
+
152
+ cancel = subparsers.add_parser("cancel", help="Cancel a running job")
153
+ cancel.add_argument("job_id")
154
+
155
+ return parser.parse_args()
156
+
157
+
158
+ def main() -> None:
159
+ args = parse_args()
160
+ client = require_connection(args)
161
+
162
+ if args.command == "health":
163
+ print_json(client.health())
164
+ elif args.command == "jobs":
165
+ print_json(client.jobs())
166
+ elif args.command == "submit":
167
+ job = client.submit(build_submit_payload(args))
168
+ print_json(job)
169
+ if args.wait:
170
+ final_status = wait_for_job(client, job["job_id"], poll=args.poll, tail=args.tail)
171
+ if final_status.get("status") != "success":
172
+ sys.exit(1)
173
+ elif args.command == "status":
174
+ print_json(client.status(args.job_id))
175
+ elif args.command == "logs":
176
+ print(client.logs(args.job_id, args.tail).get("log", ""), end="")
177
+ elif args.command == "manifest":
178
+ print_json(client.manifest(args.job_id))
179
+ elif args.command == "cancel":
180
+ print_json(client.cancel(args.job_id))
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
colab_train.py CHANGED
@@ -1,139 +1,543 @@
1
  # -*- coding: utf-8 -*-
2
- """AniFileBERT Google Colab Training Script
3
- =============================================
4
-
5
- How to use:
6
- 1. Open https://colab.research.google.com/
7
- 2. File → Upload notebook → select this file, OR
8
- Copy the entire content into a new code cell
9
- 3. Runtime Change runtime type T4 GPU
10
- 4. Run all
11
-
12
- What it does:
13
- - Mounts Google Drive (for persistent checkpoints)
14
- - Clones AniFileBERT repo + AnimeName dataset submodule
15
- - Installs PyTorch + Transformers dependencies
16
- - Runs training: train a character-token model with the full DMHY vocab
17
- - Saves final model to Drive
18
-
19
- Output:
20
- - Checkpoints saved to: MyDrive/AniFileBERT/checkpoints/
21
- - Final model at: MyDrive/AniFileBERT/checkpoints/dmhy-weak-char/final/
22
  """
23
 
 
 
 
 
 
 
24
  import os
25
- import sys
 
 
26
  import subprocess
27
- import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- def run(cmd, echo=True):
31
- """Run a shell command and print output in real time."""
32
- if echo:
33
- print(f"\n$ {cmd}")
34
  proc = subprocess.Popen(
35
- cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
36
- text=True, bufsize=1
 
 
 
 
 
 
 
37
  )
 
38
  for line in proc.stdout:
39
  print(line, end="")
40
  proc.wait()
41
- if proc.returncode != 0:
42
- raise RuntimeError(f"Command failed (exit code {proc.returncode}): {cmd}")
 
 
43
  return proc.returncode
44
 
45
 
46
- # ── 1. Mount Google Drive ──────────────────────────────────────
47
- print("=" * 60)
48
- print("STEP 1: Mount Google Drive")
49
- print("=" * 60)
50
- from google.colab import drive
51
- drive.mount("/content/drive")
52
-
53
- DRIVE_ROOT = "/content/drive/MyDrive/AniFileBERT"
54
- os.makedirs(DRIVE_ROOT, exist_ok=True)
55
- print(f"Checkpoints will be saved to: {DRIVE_ROOT}")
56
-
57
- # ── 2. Clone repositories ──────────────────────────────────────
58
- print("\n" + "=" * 60)
59
- print("STEP 2: Clone AniFileBERT repository")
60
- print("=" * 60)
61
-
62
- REPO_DIR = "/content/AniFileBERT"
63
- if not os.path.isdir(REPO_DIR):
64
- os.chdir("/content")
65
- run("git clone --recursive https://huggingface.co/ModerRAS/AniFileBERT")
66
- else:
67
- print("Repository already exists, pulling latest...")
68
- os.chdir(REPO_DIR)
69
- run("git pull")
70
- run("git submodule update --init --recursive")
71
-
72
- os.chdir(REPO_DIR)
73
-
74
- # ── 3. Install dependencies ────────────────────────────────────
75
- print("\n" + "=" * 60)
76
- print("STEP 3: Install dependencies")
77
- print("=" * 60)
78
- # Colab comes with PyTorch + CUDA pre-installed. Just install the extras.
79
- run("pip install transformers accelerate seqeval onnx onnxruntime onnxscript")
80
-
81
- # ── 4. Verify GPU ──────────────────────────────────────────────
82
- print("\n" + "=" * 60)
83
- print("STEP 4: Verify GPU")
84
- print("=" * 60)
85
- run("nvidia-smi 2>/dev/null || echo 'No GPU found — training will be slow on CPU'")
86
- # Single-quote the shell command to avoid bash expanding {torch...}
87
- run("python -c 'import torch; print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")'")
88
-
89
- # ── 5. Verify vocab ────────────────────────────────────────────
90
- print("\n" + "=" * 60)
91
- print("STEP 5: Verify vocabulary")
92
- print("=" * 60)
93
- run("python -c 'import json; v=json.load(open(\"vocab.char.json\", encoding=\"utf-8\")); print(f\"Character vocab size: {len(v)} tokens\")'")
94
-
95
- # ── 6. Run training ────────────────────────────────────────────
96
- print("\n" + "=" * 60)
97
- print("STEP 6: Train model")
98
- print("=" * 60)
99
-
100
- # The full DMHY character vocab is only 6199 tokens and covers every character
101
- # occurrence in dmhy_weak_char.jsonl.
102
- SAVE_DIR = os.path.join(DRIVE_ROOT, "checkpoints", "dmhy-weak-char")
103
-
104
- run(
105
- f"python train.py "
106
- f"--tokenizer char "
107
- f"--data-file datasets/AnimeName/dmhy_weak_char.jsonl "
108
- f"--vocab-file vocab.char.json "
109
- f"--save-dir {SAVE_DIR} "
110
- f"--epochs 5 --batch-size 128 "
111
- f"--learning-rate 0.0003 --warmup-steps 300 "
112
- f"--max-seq-length 128 "
113
- f"--seed 42 "
114
- f"--no-shuffle"
115
- )
116
-
117
- # ── 7. Export ONNX (optional) ──────────────────────────────────
118
- print("\n" + "=" * 60)
119
- print("STEP 7: Export ONNX (optional — skip if it fails)")
120
- print("=" * 60)
121
- ONNX_OUT = os.path.join(SAVE_DIR, "..", "anime_filename_parser.onnx")
122
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  run(
124
- f"python export_onnx.py "
125
- f"--model-dir {SAVE_DIR}/final "
126
- f"--output {ONNX_OUT}"
 
 
 
 
 
127
  )
128
- except Exception as e:
129
- print(f"[WARN] ONNX export skipped: {e}")
130
-
131
- # ── 8. Summary ─────────────────────────────────────────────────
132
- print("\n" + "=" * 60)
133
- print("DONE!")
134
- print("=" * 60)
135
- print(f"\nCheckpoints: {SAVE_DIR}/")
136
- print(f"Final model: {SAVE_DIR}/final/")
137
- print(f"ONNX export: {ONNX_OUT}")
138
- print(f"\nAll files are on Google Drive — they persist across Colab sessions.")
139
- print(f"You can also download them from the Drive web UI.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
+ """Codex-friendly Google Colab runner for AniFileBERT training.
3
+
4
+ Typical Colab usage:
5
+
6
+ python colab_train.py --config colab/configs/dmhy_regex_finetune.json
7
+
8
+ This script keeps the Colab side reproducible by putting run parameters in JSON
9
+ profiles. It can clone/update the repo, mount Drive, install dependencies,
10
+ train, optionally export ONNX, run an inference smoke check, and write a run
11
+ manifest that Codex can inspect later.
 
 
 
 
 
 
 
 
 
 
12
  """
13
 
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import copy
18
+ import datetime as dt
19
+ import json
20
  import os
21
+ from pathlib import Path
22
+ import shlex
23
+ import shutil
24
  import subprocess
25
+ import sys
26
+ import traceback
27
+ from typing import Any, Mapping, Sequence
28
+ import urllib.request
29
+
30
+
31
+ DEFAULT_CONFIG: dict[str, Any] = {
32
+ "name": "dmhy-regex-finetune",
33
+ "repo_url": "https://huggingface.co/ModerRAS/AniFileBERT",
34
+ "repo_ref": "main",
35
+ "repo_dir": "/content/AniFileBERT",
36
+ "drive_root": "/content/drive/MyDrive/AniFileBERT",
37
+ "mount_drive": True,
38
+ "pull": True,
39
+ "install": {
40
+ "requirements": True,
41
+ "git_lfs": True,
42
+ "extra_packages": [],
43
+ },
44
+ "training": {
45
+ "tokenizer": "regex",
46
+ "data_file": "datasets/AnimeName/dmhy_weak.jsonl",
47
+ "vocab_file": "datasets/AnimeName/vocab.json",
48
+ "save_dir": "{drive_root}/checkpoints/{name}",
49
+ "init_model_dir": ".",
50
+ "epochs": 1,
51
+ "batch_size": 128,
52
+ "learning_rate": 0.0003,
53
+ "warmup_steps": 300,
54
+ "train_split": 0.9,
55
+ "max_seq_length": 64,
56
+ "seed": 42,
57
+ "limit_samples": None,
58
+ "rebuild_vocab": False,
59
+ "max_vocab_size": None,
60
+ "resume_from_checkpoint": "auto",
61
+ "checkpoint_steps": 1000,
62
+ "save_total_limit": 3,
63
+ "cpu": False,
64
+ "no_shuffle": False,
65
+ "extra_args": [],
66
+ },
67
+ "export": {
68
+ "enabled": True,
69
+ "required": False,
70
+ "output": "{save_dir}/exports/anime_filename_parser.onnx",
71
+ "max_length": "{max_seq_length}",
72
+ "sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub",
73
+ "android_assets_dir": None,
74
+ },
75
+ "smoke": {
76
+ "enabled": True,
77
+ "required": True,
78
+ "sample": "Witch.Hat.Atelier.S01E07.1080p.NF.WEB-DL.JPN.AAC2.0.H.264.MSubs-ToonsHub",
79
+ },
80
+ "artifacts": {
81
+ "manifest": "{save_dir}/colab_run_manifest.json",
82
+ "latest_manifest": "{drive_root}/last_run_manifest.json",
83
+ },
84
+ }
85
+
86
+
87
+ COMMAND_LOG: list[dict[str, Any]] = []
88
+
89
+
90
+ class SafeFormatDict(dict):
91
+ def __missing__(self, key: str) -> str:
92
+ return "{" + key + "}"
93
+
94
+
95
+ def utc_now() -> str:
96
+ return dt.datetime.now(dt.timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
97
+
98
+
99
+ def deep_merge(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]:
100
+ merged = copy.deepcopy(dict(base))
101
+ for key, value in override.items():
102
+ if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
103
+ merged[key] = deep_merge(merged[key], value)
104
+ else:
105
+ merged[key] = copy.deepcopy(value)
106
+ return merged
107
+
108
+
109
+ def render_templates(value: Any, context: Mapping[str, Any]) -> Any:
110
+ if isinstance(value, str):
111
+ return value.format_map(SafeFormatDict(context))
112
+ if isinstance(value, list):
113
+ return [render_templates(item, context) for item in value]
114
+ if isinstance(value, dict):
115
+ return {key: render_templates(item, context) for key, item in value.items()}
116
+ return value
117
+
118
 
119
+ def command_text(args: str | Sequence[Any]) -> str:
120
+ if isinstance(args, str):
121
+ return args
122
+ return " ".join(shlex.quote(str(arg)) for arg in args)
123
+
124
+
125
+ def run(
126
+ args: str | Sequence[Any],
127
+ *,
128
+ cwd: str | os.PathLike[str] | None = None,
129
+ check: bool = True,
130
+ dry_run: bool = False,
131
+ ) -> int:
132
+ text = command_text(args)
133
+ entry: dict[str, Any] = {
134
+ "cmd": text,
135
+ "cwd": os.fspath(cwd) if cwd is not None else None,
136
+ "started_at": utc_now(),
137
+ "dry_run": dry_run,
138
+ }
139
+ COMMAND_LOG.append(entry)
140
+ print(f"\n$ {text}")
141
+ if dry_run:
142
+ entry["returncode"] = 0
143
+ entry["finished_at"] = utc_now()
144
+ return 0
145
 
 
 
 
 
146
  proc = subprocess.Popen(
147
+ args,
148
+ cwd=cwd,
149
+ shell=isinstance(args, str),
150
+ stdout=subprocess.PIPE,
151
+ stderr=subprocess.STDOUT,
152
+ text=True,
153
+ encoding="utf-8",
154
+ errors="replace",
155
+ bufsize=1,
156
  )
157
+ assert proc.stdout is not None
158
  for line in proc.stdout:
159
  print(line, end="")
160
  proc.wait()
161
+ entry["returncode"] = proc.returncode
162
+ entry["finished_at"] = utc_now()
163
+ if check and proc.returncode != 0:
164
+ raise RuntimeError(f"Command failed with exit code {proc.returncode}: {text}")
165
  return proc.returncode
166
 
167
 
168
+ def parse_args() -> argparse.Namespace:
169
+ parser = argparse.ArgumentParser(description="Run AniFileBERT training in Colab")
170
+ parser.add_argument("--config", help="JSON profile path or URL")
171
+ parser.add_argument("--profile", help="Profile name under colab/configs without .json")
172
+ parser.add_argument("--repo-url", help="Override repository URL")
173
+ parser.add_argument("--repo-ref", help="Override branch, tag, or commit to checkout")
174
+ parser.add_argument("--repo-dir", help="Override Colab repository directory")
175
+ parser.add_argument("--drive-root", help="Override Google Drive output root")
176
+ parser.add_argument("--save-dir", help="Override checkpoint output directory")
177
+ parser.add_argument("--epochs", type=float, help="Override training epochs")
178
+ parser.add_argument("--batch-size", type=int, help="Override per-device batch size")
179
+ parser.add_argument("--learning-rate", type=float, help="Override learning rate")
180
+ parser.add_argument("--warmup-steps", type=int, help="Override warmup steps")
181
+ parser.add_argument("--limit-samples", type=int, help="Use only the first N dataset rows")
182
+ parser.add_argument("--skip-install", action="store_true", help="Do not install pip or git-lfs dependencies")
183
+ parser.add_argument("--skip-export", action="store_true", help="Do not run ONNX export")
184
+ parser.add_argument("--skip-smoke", action="store_true", help="Do not run inference smoke check")
185
+ parser.add_argument("--no-mount-drive", action="store_true", help="Do not mount Google Drive")
186
+ parser.add_argument("--no-pull", action="store_true", help="Do not pull an existing checkout")
187
+ parser.add_argument("--dry-run", action="store_true", help="Print commands and write no training outputs")
188
+ parser.add_argument("--print-config", action="store_true", help="Print resolved config before running")
189
+ return parser.parse_args()
190
+
191
+
192
+ def load_json_source(source: str | None, *, required: bool) -> dict[str, Any]:
193
+ if not source:
194
+ return {}
195
+ if source.startswith(("http://", "https://")):
196
+ with urllib.request.urlopen(source) as response:
197
+ return json.loads(response.read().decode("utf-8"))
198
+
199
+ candidates = [Path(source), Path(__file__).resolve().parent / source]
200
+ for candidate in candidates:
201
+ if candidate.is_file():
202
+ return json.loads(candidate.read_text(encoding="utf-8"))
203
+ if required:
204
+ raise FileNotFoundError(f"Config file not found: {source}")
205
+ return {}
206
+
207
+
208
+ def load_config(args: argparse.Namespace) -> dict[str, Any]:
209
+ config_source = args.config
210
+ required = bool(args.config)
211
+ if config_source is None and args.profile:
212
+ config_source = os.fspath(Path("colab") / "configs" / f"{args.profile}.json")
213
+ required = True
214
+
215
+ profile_config = load_json_source(config_source, required=required)
216
+ config = deep_merge(DEFAULT_CONFIG, profile_config)
217
+
218
+ if args.repo_url:
219
+ config["repo_url"] = args.repo_url
220
+ if args.repo_ref:
221
+ config["repo_ref"] = args.repo_ref
222
+ if args.repo_dir:
223
+ config["repo_dir"] = args.repo_dir
224
+ if args.drive_root:
225
+ config["drive_root"] = args.drive_root
226
+ if args.no_mount_drive:
227
+ config["mount_drive"] = False
228
+ if args.no_pull:
229
+ config["pull"] = False
230
+ if args.skip_install:
231
+ config["install"]["requirements"] = False
232
+ config["install"]["git_lfs"] = False
233
+ config["install"]["extra_packages"] = []
234
+ if args.skip_export:
235
+ config["export"]["enabled"] = False
236
+ if args.skip_smoke:
237
+ config["smoke"]["enabled"] = False
238
+
239
+ training = config["training"]
240
+ for arg_name, key in [
241
+ ("save_dir", "save_dir"),
242
+ ("epochs", "epochs"),
243
+ ("batch_size", "batch_size"),
244
+ ("learning_rate", "learning_rate"),
245
+ ("warmup_steps", "warmup_steps"),
246
+ ("limit_samples", "limit_samples"),
247
+ ]:
248
+ value = getattr(args, arg_name)
249
+ if value is not None:
250
+ training[key] = value
251
+
252
+ return resolve_config(config)
253
+
254
+
255
+ def resolve_config(config: dict[str, Any]) -> dict[str, Any]:
256
+ context: dict[str, Any] = {
257
+ "name": config["name"],
258
+ "repo_url": config["repo_url"],
259
+ "repo_ref": config.get("repo_ref") or "",
260
+ "repo_dir": config["repo_dir"],
261
+ "drive_root": config["drive_root"],
262
+ }
263
+
264
+ training = render_templates(config["training"], context)
265
+ context.update(training)
266
+ if not training.get("save_dir"):
267
+ training["save_dir"] = os.path.join(config["drive_root"], "checkpoints", config["name"])
268
+ training = render_templates(training, {**context, **training})
269
+ context.update(training)
270
+ context["save_dir"] = training["save_dir"]
271
+ context["final_model_dir"] = os.path.join(training["save_dir"], "final")
272
+
273
+ resolved = copy.deepcopy(config)
274
+ resolved["training"] = training
275
+ resolved["export"] = render_templates(config["export"], context)
276
+ resolved["smoke"] = render_templates(config["smoke"], context)
277
+ resolved["artifacts"] = render_templates(config["artifacts"], context)
278
+ return resolved
279
+
280
+
281
+ def maybe_mount_drive(config: Mapping[str, Any]) -> None:
282
+ if not config.get("mount_drive", True):
283
+ print("Google Drive mount disabled.")
284
+ return
285
+ try:
286
+ from google.colab import drive # type: ignore
287
+ except Exception:
288
+ print("[WARN] google.colab is unavailable; skipping Drive mount.")
289
+ return
290
+ print("Mounting Google Drive...")
291
+ drive.mount("/content/drive")
292
+
293
+
294
+ def install_git_lfs_if_needed(config: Mapping[str, Any], *, dry_run: bool) -> None:
295
+ if not config.get("install", {}).get("git_lfs", True):
296
+ return
297
+ if shutil.which("git-lfs"):
298
+ run(["git", "lfs", "install"], check=False, dry_run=dry_run)
299
+ return
300
+ if Path("/content").exists():
301
+ print("Installing git-lfs for Hugging Face model artifacts...")
302
+ run(["apt-get", "update"], check=False, dry_run=dry_run)
303
+ run(["apt-get", "install", "-y", "git-lfs"], dry_run=dry_run)
304
+ run(["git", "lfs", "install"], check=False, dry_run=dry_run)
305
+ else:
306
+ print("[WARN] git-lfs not found. Existing LFS pointers may not contain model weights.")
307
+
308
+
309
+ def is_git_repo(path: Path) -> bool:
310
+ return (path / ".git").exists()
311
+
312
+
313
+ def prepare_repo(config: Mapping[str, Any], *, dry_run: bool) -> Path:
314
+ repo_dir = Path(config["repo_dir"])
315
+ repo_url = config["repo_url"]
316
+ repo_ref = config.get("repo_ref")
317
+
318
+ if not is_git_repo(repo_dir):
319
+ if repo_dir.exists() and any(repo_dir.iterdir()):
320
+ raise RuntimeError(f"{repo_dir} exists but is not a git checkout")
321
+ repo_dir.parent.mkdir(parents=True, exist_ok=True)
322
+ run(["git", "clone", "--recursive", repo_url, os.fspath(repo_dir)], dry_run=dry_run)
323
+ else:
324
+ print(f"Using existing repository checkout: {repo_dir}")
325
+
326
+ if repo_ref:
327
+ run(["git", "fetch", "--all", "--tags"], cwd=repo_dir, check=False, dry_run=dry_run)
328
+ run(["git", "checkout", str(repo_ref)], cwd=repo_dir, dry_run=dry_run)
329
+
330
+ if config.get("pull", True):
331
+ run(["git", "pull", "--ff-only"], cwd=repo_dir, check=False, dry_run=dry_run)
332
+
333
+ run(["git", "submodule", "update", "--init", "--recursive"], cwd=repo_dir, dry_run=dry_run)
334
+ if shutil.which("git-lfs"):
335
+ run(["git", "lfs", "pull"], cwd=repo_dir, check=False, dry_run=dry_run)
336
+
337
+ return repo_dir
338
+
339
+
340
+ def install_python_deps(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
341
+ install = config.get("install", {})
342
+ if install.get("requirements", True):
343
+ run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=repo_dir, dry_run=dry_run)
344
+ for package in install.get("extra_packages", []):
345
+ run([sys.executable, "-m", "pip", "install", str(package)], cwd=repo_dir, dry_run=dry_run)
346
+
347
+
348
+ def verify_runtime(repo_dir: Path, *, dry_run: bool) -> None:
349
+ run(["nvidia-smi"], cwd=repo_dir, check=False, dry_run=dry_run)
350
  run(
351
+ [
352
+ sys.executable,
353
+ "-c",
354
+ "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')",
355
+ ],
356
+ cwd=repo_dir,
357
+ check=False,
358
+ dry_run=dry_run,
359
  )
360
+
361
+
362
+ def add_arg(cmd: list[str], flag: str, value: Any) -> None:
363
+ if value is None or value is False:
364
+ return
365
+ if value is True:
366
+ cmd.append(flag)
367
+ else:
368
+ cmd.extend([flag, str(value)])
369
+
370
+
371
+ def build_train_command(training: Mapping[str, Any]) -> list[str]:
372
+ cmd = [sys.executable, "train.py"]
373
+ for key, flag in [
374
+ ("tokenizer", "--tokenizer"),
375
+ ("data_file", "--data-file"),
376
+ ("vocab_file", "--vocab-file"),
377
+ ("save_dir", "--save-dir"),
378
+ ("init_model_dir", "--init-model-dir"),
379
+ ("epochs", "--epochs"),
380
+ ("batch_size", "--batch-size"),
381
+ ("learning_rate", "--learning-rate"),
382
+ ("warmup_steps", "--warmup-steps"),
383
+ ("train_split", "--train-split"),
384
+ ("max_seq_length", "--max-seq-length"),
385
+ ("seed", "--seed"),
386
+ ("limit_samples", "--limit-samples"),
387
+ ("max_vocab_size", "--max-vocab-size"),
388
+ ("resume_from_checkpoint", "--resume-from-checkpoint"),
389
+ ("checkpoint_steps", "--checkpoint-steps"),
390
+ ("save_total_limit", "--save-total-limit"),
391
+ ]:
392
+ add_arg(cmd, flag, training.get(key))
393
+ add_arg(cmd, "--rebuild-vocab", training.get("rebuild_vocab"))
394
+ add_arg(cmd, "--cpu", training.get("cpu"))
395
+ add_arg(cmd, "--no-shuffle", training.get("no_shuffle"))
396
+ cmd.extend(str(arg) for arg in training.get("extra_args", []))
397
+ return cmd
398
+
399
+
400
+ def run_training(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
401
+ training = config["training"]
402
+ if not dry_run:
403
+ Path(training["save_dir"]).mkdir(parents=True, exist_ok=True)
404
+ run(build_train_command(training), cwd=repo_dir, dry_run=dry_run)
405
+
406
+
407
+ def run_export(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
408
+ export = config["export"]
409
+ if not export.get("enabled", True):
410
+ print("ONNX export disabled.")
411
+ return
412
+ cmd = [
413
+ sys.executable,
414
+ "export_onnx.py",
415
+ "--model-dir",
416
+ os.path.join(config["training"]["save_dir"], "final"),
417
+ "--output",
418
+ export["output"],
419
+ "--max-length",
420
+ str(export["max_length"]),
421
+ ]
422
+ add_arg(cmd, "--sample", export.get("sample"))
423
+ add_arg(cmd, "--android-assets-dir", export.get("android_assets_dir"))
424
+ try:
425
+ run(cmd, cwd=repo_dir, dry_run=dry_run)
426
+ except Exception:
427
+ if export.get("required", False):
428
+ raise
429
+ print("[WARN] ONNX export failed, but export.required is false.")
430
+ traceback.print_exc()
431
+
432
+
433
+ def run_smoke(config: Mapping[str, Any], repo_dir: Path, *, dry_run: bool) -> None:
434
+ smoke = config["smoke"]
435
+ if not smoke.get("enabled", True):
436
+ print("Inference smoke check disabled.")
437
+ return
438
+ cmd = [
439
+ sys.executable,
440
+ "inference.py",
441
+ "--model-dir",
442
+ os.path.join(config["training"]["save_dir"], "final"),
443
+ smoke["sample"],
444
+ ]
445
+ try:
446
+ run(cmd, cwd=repo_dir, dry_run=dry_run)
447
+ except Exception:
448
+ if smoke.get("required", True):
449
+ raise
450
+ print("[WARN] Smoke check failed, but smoke.required is false.")
451
+ traceback.print_exc()
452
+
453
+
454
+ def git_commit(repo_dir: Path, *, dry_run: bool) -> str | None:
455
+ if dry_run:
456
+ return None
457
+ try:
458
+ return subprocess.check_output(
459
+ ["git", "rev-parse", "HEAD"],
460
+ cwd=repo_dir,
461
+ text=True,
462
+ encoding="utf-8",
463
+ errors="replace",
464
+ ).strip()
465
+ except Exception:
466
+ return None
467
+
468
+
469
+ def write_json(path: str | os.PathLike[str], data: Mapping[str, Any], *, dry_run: bool) -> None:
470
+ print(f"Writing manifest: {path}")
471
+ if dry_run:
472
+ return
473
+ output_path = Path(path)
474
+ output_path.parent.mkdir(parents=True, exist_ok=True)
475
+ output_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
476
+
477
+
478
+ def write_manifests(
479
+ config: Mapping[str, Any],
480
+ repo_dir: Path,
481
+ *,
482
+ status: str,
483
+ started_at: str,
484
+ error: str | None,
485
+ dry_run: bool,
486
+ ) -> None:
487
+ save_dir = config["training"]["save_dir"]
488
+ manifest = {
489
+ "status": status,
490
+ "name": config["name"],
491
+ "started_at": started_at,
492
+ "finished_at": utc_now(),
493
+ "repo_url": config["repo_url"],
494
+ "repo_ref": config.get("repo_ref"),
495
+ "repo_commit": git_commit(repo_dir, dry_run=dry_run),
496
+ "repo_dir": os.fspath(repo_dir),
497
+ "save_dir": save_dir,
498
+ "final_model_dir": os.path.join(save_dir, "final"),
499
+ "onnx_output": config["export"].get("output") if config["export"].get("enabled") else None,
500
+ "config": config,
501
+ "commands": COMMAND_LOG,
502
+ "error": error,
503
+ }
504
+ artifacts = config["artifacts"]
505
+ write_json(artifacts["manifest"], manifest, dry_run=dry_run)
506
+ if artifacts.get("latest_manifest"):
507
+ write_json(artifacts["latest_manifest"], manifest, dry_run=dry_run)
508
+
509
+
510
+ def main() -> None:
511
+ args = parse_args()
512
+ started_at = utc_now()
513
+ config = load_config(args)
514
+
515
+ if args.print_config:
516
+ print(json.dumps(config, ensure_ascii=False, indent=2))
517
+
518
+ repo_dir = Path(config["repo_dir"])
519
+ status = "failed"
520
+ error: str | None = None
521
+ try:
522
+ maybe_mount_drive(config)
523
+ install_git_lfs_if_needed(config, dry_run=args.dry_run)
524
+ repo_dir = prepare_repo(config, dry_run=args.dry_run)
525
+ install_python_deps(config, repo_dir, dry_run=args.dry_run)
526
+ verify_runtime(repo_dir, dry_run=args.dry_run)
527
+ run_training(config, repo_dir, dry_run=args.dry_run)
528
+ run_export(config, repo_dir, dry_run=args.dry_run)
529
+ run_smoke(config, repo_dir, dry_run=args.dry_run)
530
+ status = "success"
531
+ except Exception as exc:
532
+ error = f"{type(exc).__name__}: {exc}"
533
+ raise
534
+ finally:
535
+ write_manifests(config, repo_dir, status=status, started_at=started_at, error=error, dry_run=args.dry_run)
536
+
537
+ print("\nDone.")
538
+ print(f"Final model: {os.path.join(config['training']['save_dir'], 'final')}")
539
+ print(f"Manifest: {config['artifacts']['manifest']}")
540
+
541
+
542
+ if __name__ == "__main__":
543
+ main()
colab_worker.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Small HTTP worker for running AniFileBERT training jobs on Google Colab.
3
+
4
+ Start this inside a Colab runtime:
5
+
6
+ python colab_worker.py
7
+
8
+ The worker exposes a token-protected local HTTP API and, by default, starts a
9
+ Cloudflare Quick Tunnel so Codex on your local machine can submit jobs.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ from pathlib import Path
18
+ import platform
19
+ import re
20
+ import secrets
21
+ import shutil
22
+ import signal
23
+ import subprocess
24
+ import sys
25
+ import threading
26
+ import time
27
+ import traceback
28
+ from http import HTTPStatus
29
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
30
+ from typing import Any
31
+ from urllib.parse import parse_qs, urlparse
32
+ import urllib.request
33
+
34
+
35
+ TERMINAL_STATES = {"success", "failed", "cancelled"}
36
+ TUNNEL_URL_RE = re.compile(r"https://[-a-zA-Z0-9.]+\.trycloudflare\.com")
37
+
38
+
39
+ def utc_timestamp() -> str:
40
+ return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
41
+
42
+
43
+ def json_dumps(data: Any) -> str:
44
+ return json.dumps(data, ensure_ascii=False, indent=2)
45
+
46
+
47
+ def read_tail(path: Path, lines: int) -> str:
48
+ if not path.is_file():
49
+ return ""
50
+ if lines <= 0:
51
+ return path.read_text(encoding="utf-8", errors="replace")
52
+
53
+ chunk_size = 8192
54
+ data = b""
55
+ with path.open("rb") as f:
56
+ f.seek(0, os.SEEK_END)
57
+ pos = f.tell()
58
+ while pos > 0 and data.count(b"\n") <= lines:
59
+ read_size = min(chunk_size, pos)
60
+ pos -= read_size
61
+ f.seek(pos)
62
+ data = f.read(read_size) + data
63
+ return b"\n".join(data.splitlines()[-lines:]).decode("utf-8", errors="replace")
64
+
65
+
66
+ def download_cloudflared(path: Path) -> Path:
67
+ if path.is_file():
68
+ return path
69
+
70
+ existing = shutil.which("cloudflared")
71
+ if existing:
72
+ return Path(existing)
73
+
74
+ arch = platform.machine().lower()
75
+ if arch in {"x86_64", "amd64"}:
76
+ suffix = "linux-amd64"
77
+ elif arch in {"aarch64", "arm64"}:
78
+ suffix = "linux-arm64"
79
+ else:
80
+ raise RuntimeError(f"Unsupported CPU architecture for cloudflared: {arch}")
81
+
82
+ url = f"https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-{suffix}"
83
+ print(f"Downloading cloudflared: {url}", flush=True)
84
+ path.parent.mkdir(parents=True, exist_ok=True)
85
+ urllib.request.urlretrieve(url, path)
86
+ path.chmod(0o755)
87
+ return path
88
+
89
+
90
+ class WorkerState:
91
+ def __init__(self, repo_dir: Path, jobs_dir: Path):
92
+ self.repo_dir = repo_dir
93
+ self.jobs_dir = jobs_dir
94
+ self.jobs_dir.mkdir(parents=True, exist_ok=True)
95
+ self.jobs: dict[str, dict[str, Any]] = {}
96
+ self.lock = threading.RLock()
97
+
98
+ def list_jobs(self) -> list[dict[str, Any]]:
99
+ with self.lock:
100
+ return [self._public_job(job) for job in self.jobs.values()]
101
+
102
+ def get_job(self, job_id: str) -> dict[str, Any] | None:
103
+ with self.lock:
104
+ job = self.jobs.get(job_id)
105
+ return self._public_job(job) if job else None
106
+
107
+ def get_job_internal(self, job_id: str) -> dict[str, Any] | None:
108
+ with self.lock:
109
+ return self.jobs.get(job_id)
110
+
111
+ def active_job(self) -> dict[str, Any] | None:
112
+ with self.lock:
113
+ for job in self.jobs.values():
114
+ if job["status"] not in TERMINAL_STATES:
115
+ return job
116
+ return None
117
+
118
+ def start_job(self, payload: dict[str, Any]) -> dict[str, Any]:
119
+ with self.lock:
120
+ active = self.active_job()
121
+ if active is not None:
122
+ raise RuntimeError(f"Job already running: {active['job_id']}")
123
+
124
+ job_id = time.strftime("%Y%m%d-%H%M%S", time.gmtime()) + "-" + secrets.token_hex(3)
125
+ job_dir = self.jobs_dir / job_id
126
+ job_dir.mkdir(parents=True, exist_ok=True)
127
+ log_path = job_dir / "worker.log"
128
+ config_path: Path | None = None
129
+
130
+ cmd = [sys.executable, "colab_train.py"]
131
+ config = self._job_config(payload)
132
+ config.setdefault("artifacts", {})
133
+ config["artifacts"]["manifest"] = os.fspath(job_dir / "colab_run_manifest.json")
134
+ config_path = job_dir / "config.json"
135
+ config_path.write_text(json_dumps(config), encoding="utf-8")
136
+ cmd.extend(["--config", os.fspath(config_path)])
137
+
138
+ for arg in payload.get("args", []):
139
+ cmd.append(str(arg))
140
+
141
+ job = {
142
+ "job_id": job_id,
143
+ "status": "queued",
144
+ "created_at": utc_timestamp(),
145
+ "started_at": None,
146
+ "finished_at": None,
147
+ "returncode": None,
148
+ "cmd": cmd,
149
+ "cwd": os.fspath(self.repo_dir),
150
+ "job_dir": os.fspath(job_dir),
151
+ "log_path": os.fspath(log_path),
152
+ "config_path": os.fspath(config_path) if config_path else None,
153
+ "error": None,
154
+ "process": None,
155
+ }
156
+ self.jobs[job_id] = job
157
+
158
+ thread = threading.Thread(target=self._run_job, args=(job_id,), daemon=True)
159
+ thread.start()
160
+ return self._public_job(job)
161
+
162
+ def _job_config(self, payload: dict[str, Any]) -> dict[str, Any]:
163
+ if "config" in payload:
164
+ return json.loads(json.dumps(payload["config"], ensure_ascii=False))
165
+
166
+ profile = str(payload.get("profile", "dmhy_regex_finetune"))
167
+ profile_path = self.repo_dir / "colab" / "configs" / f"{profile}.json"
168
+ if not profile_path.is_file():
169
+ raise FileNotFoundError(f"Profile not found: {profile_path}")
170
+ return json.loads(profile_path.read_text(encoding="utf-8"))
171
+
172
+ def cancel_job(self, job_id: str) -> dict[str, Any]:
173
+ with self.lock:
174
+ job = self.jobs.get(job_id)
175
+ if job is None:
176
+ raise KeyError(job_id)
177
+ process: subprocess.Popen[str] | None = job.get("process")
178
+ if job["status"] in TERMINAL_STATES:
179
+ return self._public_job(job)
180
+ job["status"] = "cancelled"
181
+ job["finished_at"] = utc_timestamp()
182
+
183
+ if process and process.poll() is None:
184
+ try:
185
+ os.killpg(os.getpgid(process.pid), signal.SIGTERM)
186
+ except Exception:
187
+ process.terminate()
188
+ return self.get_job(job_id) or {}
189
+
190
+ def _run_job(self, job_id: str) -> None:
191
+ job = self.get_job_internal(job_id)
192
+ if job is None:
193
+ return
194
+ log_path = Path(job["log_path"])
195
+ try:
196
+ with self.lock:
197
+ job["status"] = "running"
198
+ job["started_at"] = utc_timestamp()
199
+
200
+ with log_path.open("w", encoding="utf-8", errors="replace") as log:
201
+ log.write(f"job_id={job_id}\n")
202
+ log.write(f"cwd={job['cwd']}\n")
203
+ log.write("$ " + " ".join(job["cmd"]) + "\n\n")
204
+ log.flush()
205
+
206
+ process = subprocess.Popen(
207
+ job["cmd"],
208
+ cwd=job["cwd"],
209
+ stdout=subprocess.PIPE,
210
+ stderr=subprocess.STDOUT,
211
+ text=True,
212
+ encoding="utf-8",
213
+ errors="replace",
214
+ bufsize=1,
215
+ preexec_fn=os.setsid if hasattr(os, "setsid") else None,
216
+ )
217
+ with self.lock:
218
+ job["process"] = process
219
+
220
+ assert process.stdout is not None
221
+ for line in process.stdout:
222
+ log.write(line)
223
+ log.flush()
224
+ print(line, end="", flush=True)
225
+ process.wait()
226
+
227
+ with self.lock:
228
+ job["returncode"] = process.returncode
229
+ if job["status"] != "cancelled":
230
+ job["status"] = "success" if process.returncode == 0 else "failed"
231
+ job["finished_at"] = utc_timestamp()
232
+ job["process"] = None
233
+ except Exception as exc:
234
+ with log_path.open("a", encoding="utf-8", errors="replace") as log:
235
+ traceback.print_exc(file=log)
236
+ with self.lock:
237
+ job["status"] = "failed"
238
+ job["finished_at"] = utc_timestamp()
239
+ job["error"] = f"{type(exc).__name__}: {exc}"
240
+ job["process"] = None
241
+
242
+ def _public_job(self, job: dict[str, Any]) -> dict[str, Any]:
243
+ public = {key: value for key, value in job.items() if key != "process"}
244
+ return public
245
+
246
+
247
+ def make_handler(state: WorkerState, token: str):
248
+ class Handler(BaseHTTPRequestHandler):
249
+ server_version = "AniFileBERTColabWorker/1.0"
250
+
251
+ def log_message(self, fmt: str, *args: Any) -> None:
252
+ print(f"[{utc_timestamp()}] {self.address_string()} {fmt % args}", flush=True)
253
+
254
+ def do_GET(self) -> None:
255
+ self._handle("GET")
256
+
257
+ def do_POST(self) -> None:
258
+ self._handle("POST")
259
+
260
+ def _handle(self, method: str) -> None:
261
+ parsed = urlparse(self.path)
262
+ path = parsed.path.rstrip("/") or "/"
263
+ parts = [part for part in path.split("/") if part]
264
+ try:
265
+ if not self._authorized():
266
+ self._send({"error": "unauthorized"}, HTTPStatus.UNAUTHORIZED)
267
+ return
268
+
269
+ if method == "GET" and path == "/health":
270
+ self._send(
271
+ {
272
+ "ok": True,
273
+ "repo_dir": os.fspath(state.repo_dir),
274
+ "jobs_dir": os.fspath(state.jobs_dir),
275
+ "active_job": state.active_job()["job_id"] if state.active_job() else None,
276
+ }
277
+ )
278
+ return
279
+
280
+ if method == "GET" and path == "/jobs":
281
+ self._send({"jobs": state.list_jobs()})
282
+ return
283
+
284
+ if method == "POST" and path == "/jobs":
285
+ payload = self._read_json()
286
+ job = state.start_job(payload)
287
+ self._send(job, HTTPStatus.ACCEPTED)
288
+ return
289
+
290
+ if len(parts) >= 2 and parts[0] == "jobs":
291
+ job_id = parts[1]
292
+ if method == "GET" and len(parts) == 2:
293
+ job = state.get_job(job_id)
294
+ if job is None:
295
+ self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
296
+ else:
297
+ self._send(job)
298
+ return
299
+
300
+ if method == "GET" and len(parts) == 3 and parts[2] == "logs":
301
+ query = parse_qs(parsed.query)
302
+ tail = int(query.get("tail", ["200"])[0])
303
+ job = state.get_job_internal(job_id)
304
+ if job is None:
305
+ self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
306
+ else:
307
+ self._send({"job_id": job_id, "log": read_tail(Path(job["log_path"]), tail)})
308
+ return
309
+
310
+ if method == "GET" and len(parts) == 3 and parts[2] == "manifest":
311
+ job = state.get_job_internal(job_id)
312
+ if job is None:
313
+ self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
314
+ else:
315
+ manifest = self._find_manifest(job)
316
+ if manifest is None:
317
+ self._send({"error": "manifest not found"}, HTTPStatus.NOT_FOUND)
318
+ else:
319
+ self._send(json.loads(manifest.read_text(encoding="utf-8")))
320
+ return
321
+
322
+ if method == "POST" and len(parts) == 3 and parts[2] == "cancel":
323
+ try:
324
+ self._send(state.cancel_job(job_id))
325
+ except KeyError:
326
+ self._send({"error": "job not found"}, HTTPStatus.NOT_FOUND)
327
+ return
328
+
329
+ self._send({"error": "not found"}, HTTPStatus.NOT_FOUND)
330
+ except Exception as exc:
331
+ traceback.print_exc()
332
+ self._send({"error": f"{type(exc).__name__}: {exc}"}, HTTPStatus.INTERNAL_SERVER_ERROR)
333
+
334
+ def _authorized(self) -> bool:
335
+ header = self.headers.get("Authorization", "")
336
+ if header == f"Bearer {token}":
337
+ return True
338
+ return self.headers.get("X-Colab-Token") == token
339
+
340
+ def _read_json(self) -> dict[str, Any]:
341
+ length = int(self.headers.get("Content-Length", "0"))
342
+ if length == 0:
343
+ return {}
344
+ raw = self.rfile.read(length)
345
+ return json.loads(raw.decode("utf-8"))
346
+
347
+ def _find_manifest(self, job: dict[str, Any]) -> Path | None:
348
+ config_path = job.get("config_path")
349
+ if config_path and Path(config_path).is_file():
350
+ config = json.loads(Path(config_path).read_text(encoding="utf-8"))
351
+ training = config.get("training", {})
352
+ save_dir = training.get("save_dir")
353
+ if save_dir:
354
+ manifest = Path(save_dir) / "colab_run_manifest.json"
355
+ if manifest.is_file():
356
+ return manifest
357
+ job_manifest = Path(job["job_dir"]) / "colab_run_manifest.json"
358
+ return job_manifest if job_manifest.is_file() else None
359
+
360
+ def _send(self, data: Any, status: HTTPStatus = HTTPStatus.OK) -> None:
361
+ raw = json_dumps(data).encode("utf-8")
362
+ self.send_response(status.value)
363
+ self.send_header("Content-Type", "application/json; charset=utf-8")
364
+ self.send_header("Content-Length", str(len(raw)))
365
+ self.end_headers()
366
+ self.wfile.write(raw)
367
+
368
+ return Handler
369
+
370
+
371
+ def start_tunnel(port: int, binary_path: Path) -> subprocess.Popen[str]:
372
+ cloudflared = download_cloudflared(binary_path)
373
+ cmd = [
374
+ os.fspath(cloudflared),
375
+ "tunnel",
376
+ "--url",
377
+ f"http://127.0.0.1:{port}",
378
+ "--no-autoupdate",
379
+ ]
380
+ proc = subprocess.Popen(
381
+ cmd,
382
+ stdout=subprocess.PIPE,
383
+ stderr=subprocess.STDOUT,
384
+ text=True,
385
+ encoding="utf-8",
386
+ errors="replace",
387
+ bufsize=1,
388
+ )
389
+
390
+ def pump() -> None:
391
+ assert proc.stdout is not None
392
+ for line in proc.stdout:
393
+ print(line, end="", flush=True)
394
+ match = TUNNEL_URL_RE.search(line)
395
+ if match:
396
+ print("\nCOLAB_WORKER_URL=" + match.group(0), flush=True)
397
+
398
+ threading.Thread(target=pump, daemon=True).start()
399
+ return proc
400
+
401
+
402
+ def parse_args() -> argparse.Namespace:
403
+ parser = argparse.ArgumentParser(description="Start the AniFileBERT Colab worker")
404
+ parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host")
405
+ parser.add_argument("--port", type=int, default=7860, help="HTTP bind port")
406
+ parser.add_argument("--repo-dir", default="/content/AniFileBERT", help="AniFileBERT checkout path in Colab")
407
+ parser.add_argument("--jobs-dir", default="/content/drive/MyDrive/AniFileBERT/worker/jobs")
408
+ parser.add_argument("--token", default=os.environ.get("ANIFILEBERT_COLAB_TOKEN"))
409
+ parser.add_argument("--tunnel", choices=["cloudflare", "none"], default="cloudflare")
410
+ parser.add_argument("--cloudflared-path", default="/tmp/anifilebert-cloudflared")
411
+ return parser.parse_args()
412
+
413
+
414
+ def main() -> None:
415
+ args = parse_args()
416
+ token = args.token or secrets.token_urlsafe(24)
417
+ repo_dir = Path(args.repo_dir)
418
+ if not repo_dir.is_dir():
419
+ raise RuntimeError(f"Repo directory does not exist: {repo_dir}")
420
+
421
+ state = WorkerState(repo_dir=repo_dir, jobs_dir=Path(args.jobs_dir))
422
+ server = ThreadingHTTPServer((args.host, args.port), make_handler(state, token))
423
+ tunnel_proc: subprocess.Popen[str] | None = None
424
+
425
+ print("=" * 72)
426
+ print("AniFileBERT Colab worker is starting")
427
+ print(f"Local URL: http://{args.host}:{args.port}")
428
+ print(f"COLAB_WORKER_TOKEN={token}")
429
+ print("Keep this Colab cell running while Codex uses the worker.")
430
+ print("=" * 72, flush=True)
431
+
432
+ if args.tunnel == "cloudflare":
433
+ tunnel_proc = start_tunnel(args.port, Path(args.cloudflared_path))
434
+ else:
435
+ print("Tunnel disabled. Use the local URL from inside the Colab runtime.", flush=True)
436
+
437
+ try:
438
+ server.serve_forever()
439
+ finally:
440
+ server.server_close()
441
+ if tunnel_proc and tunnel_proc.poll() is None:
442
+ tunnel_proc.terminate()
443
+
444
+
445
+ if __name__ == "__main__":
446
+ main()
train.py CHANGED
@@ -27,7 +27,7 @@ from transformers import (
27
  from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
28
 
29
  from config import Config
30
- from tokenizer import AnimeTokenizer, create_tokenizer
31
  from model import create_model, print_model_summary, count_parameters
32
  from dataset import AnimeDataset, align_tokens_for_tokenizer
33
 
@@ -64,8 +64,8 @@ def compute_metrics(p):
64
 
65
  def parse_args() -> argparse.Namespace:
66
  parser = argparse.ArgumentParser(description="Train anime filename parser")
67
- parser.add_argument("--tokenizer", choices=["regex", "char"], default="regex",
68
- help="Tokenizer variant for A/B testing")
69
  parser.add_argument("--data-file", default=None, help="Training JSONL file")
70
  parser.add_argument("--vocab-file", default=None,
71
  help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
@@ -84,11 +84,58 @@ def parse_args() -> argparse.Namespace:
84
  help="Rebuild vocab from the selected data file before training")
85
  parser.add_argument("--max-vocab-size", type=int, default=None,
86
  help="Optional vocab cap used with --rebuild-vocab")
 
 
 
 
87
  parser.add_argument("--cpu", action="store_true", help="Force CPU training")
88
  parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
 
 
89
  return parser.parse_args()
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str:
93
  if explicit_path:
94
  return explicit_path
@@ -96,6 +143,79 @@ def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Op
96
  return os.path.join(os.path.dirname(data_file), name)
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str,
100
  max_size: Optional[int] = None) -> None:
101
  token_lists: List[List[str]] = []
@@ -115,9 +235,10 @@ def main():
115
  config = Config()
116
  if args.data_file is not None:
117
  config.data_file = args.data_file
 
118
  if args.save_dir is not None:
119
  config.save_dir = args.save_dir
120
- elif args.tokenizer == "char":
121
  config.save_dir = "./checkpoints_char"
122
  if args.epochs is not None:
123
  config.num_epochs = args.epochs
@@ -131,6 +252,8 @@ def main():
131
  config.train_split = args.train_split
132
  if args.max_seq_length is not None:
133
  config.max_seq_length = args.max_seq_length
 
 
134
 
135
  random.seed(args.seed)
136
  np.random.seed(args.seed)
@@ -143,18 +266,20 @@ def main():
143
  all_data = all_data[:args.limit_samples]
144
  if not args.no_shuffle:
145
  random.shuffle(all_data)
 
146
 
147
  # Load tokenizer
148
  print("Loading tokenizer...")
149
- vocab_path = resolve_vocab_path(config.data_file, args.tokenizer, args.vocab_file)
150
- tokenizer = create_tokenizer(args.tokenizer)
151
  if args.rebuild_vocab or not os.path.isfile(vocab_path):
152
  max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
153
- print(f" Building {args.tokenizer} vocab: {vocab_path} (max_size={max_vocab_size})")
154
  build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
155
- tokenizer = create_tokenizer(args.tokenizer, vocab_file=vocab_path)
156
- print(f" Variant: {args.tokenizer}")
157
  print(f" Vocab size: {tokenizer.vocab_size}")
 
158
 
159
  # Update config with actual vocab size
160
  config.vocab_size = tokenizer.vocab_size
@@ -163,9 +288,22 @@ def main():
163
  if args.init_model_dir:
164
  print(f"Loading model for fine-tuning: {args.init_model_dir}")
165
  model = BertForTokenClassification.from_pretrained(args.init_model_dir)
166
- if model.config.vocab_size != config.vocab_size:
167
- print(f" Resizing token embeddings: {model.config.vocab_size} -> {config.vocab_size}")
168
- model.resize_token_embeddings(config.vocab_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  model.config.num_labels = config.num_labels
170
  model.config.id2label = config.id2label
171
  model.config.label2id = config.label2id
@@ -212,6 +350,8 @@ def main():
212
  use_cpu = args.cpu or not torch.cuda.is_available()
213
  use_fp16 = not use_cpu
214
  print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
 
 
215
 
216
  # Training arguments
217
  training_args = TrainingArguments(
@@ -220,15 +360,16 @@ def main():
220
  per_device_train_batch_size=config.batch_size,
221
  per_device_eval_batch_size=config.batch_size,
222
  eval_strategy="epoch",
223
- save_strategy="epoch",
 
224
  logging_steps=config.log_interval,
225
  learning_rate=config.learning_rate,
226
  weight_decay=config.weight_decay,
227
  warmup_steps=config.warmup_steps,
228
  use_cpu=use_cpu,
229
  report_to="none",
230
- save_total_limit=2,
231
- load_best_model_at_end=True,
232
  metric_for_best_model="f1",
233
  greater_is_better=True,
234
  dataloader_num_workers=config.num_workers,
@@ -250,12 +391,19 @@ def main():
250
 
251
  # Train
252
  print("Starting training...")
253
- trainer.train()
 
 
 
 
 
 
 
254
 
255
  # Set proper label mappings in model config before saving
256
  model.config.id2label = config.id2label
257
  model.config.label2id = config.label2id
258
- model.config.tokenizer_variant = args.tokenizer
259
  model.config.max_seq_length = config.max_seq_length
260
 
261
  # Save final model
 
27
  from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
28
 
29
  from config import Config
30
+ from tokenizer import AnimeTokenizer, create_tokenizer, load_tokenizer
31
  from model import create_model, print_model_summary, count_parameters
32
  from dataset import AnimeDataset, align_tokens_for_tokenizer
33
 
 
64
 
65
  def parse_args() -> argparse.Namespace:
66
  parser = argparse.ArgumentParser(description="Train anime filename parser")
67
+ parser.add_argument("--tokenizer", choices=["regex", "char"], default=None,
68
+ help="Tokenizer variant for A/B testing. Defaults to dataset metadata")
69
  parser.add_argument("--data-file", default=None, help="Training JSONL file")
70
  parser.add_argument("--vocab-file", default=None,
71
  help="Tokenizer vocab JSON. Defaults to data/vocab.json or data/vocab.char.json")
 
84
  help="Rebuild vocab from the selected data file before training")
85
  parser.add_argument("--max-vocab-size", type=int, default=None,
86
  help="Optional vocab cap used with --rebuild-vocab")
87
+ parser.add_argument("--checkpoint-steps", type=int, default=None,
88
+ help="Save resumable checkpoints every N steps instead of only at epoch end")
89
+ parser.add_argument("--save-total-limit", type=int, default=2,
90
+ help="Maximum number of checkpoints to keep")
91
  parser.add_argument("--cpu", action="store_true", help="Force CPU training")
92
  parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle before train/eval split")
93
+ parser.add_argument("--resume-from-checkpoint", default=None,
94
+ help="Resume Trainer state from a checkpoint directory, or 'auto' for the latest checkpoint")
95
  return parser.parse_args()
96
 
97
 
98
+ def detect_tokenizer_variant(
99
+ data_file: str,
100
+ explicit_variant: Optional[str],
101
+ explicit_vocab_path: Optional[str],
102
+ sample_size: int = 256,
103
+ ) -> str:
104
+ """Infer tokenizer variant from CLI, dataset metadata, or vocab filename."""
105
+ if explicit_variant:
106
+ return explicit_variant
107
+
108
+ variants = set()
109
+ char_like = 0
110
+ inspected = 0
111
+ with open(data_file, "r", encoding="utf-8") as f:
112
+ for line in f:
113
+ if inspected >= sample_size:
114
+ break
115
+ line = line.strip()
116
+ if not line:
117
+ continue
118
+ item = json.loads(line)
119
+ inspected += 1
120
+ variant = item.get("tokenizer_variant")
121
+ if variant:
122
+ variants.add(variant)
123
+ tokens = item.get("tokens", [])
124
+ filename = item.get("filename")
125
+ if filename is not None and tokens == list(filename):
126
+ char_like += 1
127
+
128
+ if len(variants) == 1:
129
+ return next(iter(variants))
130
+ if len(variants) > 1:
131
+ raise ValueError(f"Mixed tokenizer_variant values in {data_file}: {sorted(variants)}")
132
+ if explicit_vocab_path and ".char" in os.path.basename(explicit_vocab_path).lower():
133
+ return "char"
134
+ if inspected and char_like / inspected >= 0.95:
135
+ return "char"
136
+ return "regex"
137
+
138
+
139
  def resolve_vocab_path(data_file: str, tokenizer_variant: str, explicit_path: Optional[str]) -> str:
140
  if explicit_path:
141
  return explicit_path
 
143
  return os.path.join(os.path.dirname(data_file), name)
144
 
145
 
146
+ def latest_checkpoint(save_dir: str) -> Optional[str]:
147
+ if not os.path.isdir(save_dir):
148
+ return None
149
+ checkpoints = []
150
+ for name in os.listdir(save_dir):
151
+ if not name.startswith("checkpoint-"):
152
+ continue
153
+ path = os.path.join(save_dir, name)
154
+ if not os.path.isdir(path):
155
+ continue
156
+ try:
157
+ step = int(name.split("-")[-1])
158
+ except ValueError:
159
+ continue
160
+ checkpoints.append((step, path))
161
+ if not checkpoints:
162
+ return None
163
+ return max(checkpoints)[1]
164
+
165
+
166
+ def validate_dataset_tokenizer_metadata(data: List[Dict], tokenizer_variant: str) -> None:
167
+ variants = {item.get("tokenizer_variant") for item in data if item.get("tokenizer_variant")}
168
+ if variants and variants != {tokenizer_variant}:
169
+ raise ValueError(
170
+ f"Dataset tokenizer_variant {sorted(variants)} does not match selected tokenizer "
171
+ f"'{tokenizer_variant}'. Pass --tokenizer explicitly only when this is intentional."
172
+ )
173
+
174
+
175
+ def remap_token_embeddings(
176
+ model: BertForTokenClassification,
177
+ old_vocab: Dict[str, int],
178
+ new_vocab: Dict[str, int],
179
+ pad_token_id: int,
180
+ ) -> int:
181
+ """
182
+ Replace the input embedding table for a changed vocabulary.
183
+
184
+ resize_token_embeddings() preserves rows by numeric ID, which is unsafe when
185
+ two tokenizers assign different tokens to the same ID. This remaps by token
186
+ string and randomly initializes tokens that do not exist in the old vocab.
187
+ """
188
+ old_embeddings = model.get_input_embeddings()
189
+ old_weight = old_embeddings.weight.data
190
+ embedding_dim = old_weight.shape[1]
191
+ new_embeddings = torch.nn.Embedding(
192
+ len(new_vocab),
193
+ embedding_dim,
194
+ padding_idx=pad_token_id,
195
+ device=old_weight.device,
196
+ dtype=old_weight.dtype,
197
+ )
198
+ torch.nn.init.normal_(
199
+ new_embeddings.weight,
200
+ mean=0.0,
201
+ std=getattr(model.config, "initializer_range", 0.02),
202
+ )
203
+ if pad_token_id is not None and 0 <= pad_token_id < len(new_vocab):
204
+ new_embeddings.weight.data[pad_token_id].zero_()
205
+
206
+ copied = 0
207
+ for token, new_id in new_vocab.items():
208
+ old_id = old_vocab.get(token)
209
+ if old_id is None or old_id >= old_weight.shape[0]:
210
+ continue
211
+ new_embeddings.weight.data[new_id].copy_(old_weight[old_id])
212
+ copied += 1
213
+
214
+ model.set_input_embeddings(new_embeddings)
215
+ model.config.vocab_size = len(new_vocab)
216
+ return copied
217
+
218
+
219
  def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_path: str,
220
  max_size: Optional[int] = None) -> None:
221
  token_lists: List[List[str]] = []
 
235
  config = Config()
236
  if args.data_file is not None:
237
  config.data_file = args.data_file
238
+ tokenizer_variant = detect_tokenizer_variant(config.data_file, args.tokenizer, args.vocab_file)
239
  if args.save_dir is not None:
240
  config.save_dir = args.save_dir
241
+ elif tokenizer_variant == "char":
242
  config.save_dir = "./checkpoints_char"
243
  if args.epochs is not None:
244
  config.num_epochs = args.epochs
 
252
  config.train_split = args.train_split
253
  if args.max_seq_length is not None:
254
  config.max_seq_length = args.max_seq_length
255
+ elif tokenizer_variant == "char":
256
+ config.max_seq_length = max(config.max_seq_length, 128)
257
 
258
  random.seed(args.seed)
259
  np.random.seed(args.seed)
 
266
  all_data = all_data[:args.limit_samples]
267
  if not args.no_shuffle:
268
  random.shuffle(all_data)
269
+ validate_dataset_tokenizer_metadata(all_data, tokenizer_variant)
270
 
271
  # Load tokenizer
272
  print("Loading tokenizer...")
273
+ vocab_path = resolve_vocab_path(config.data_file, tokenizer_variant, args.vocab_file)
274
+ tokenizer = create_tokenizer(tokenizer_variant)
275
  if args.rebuild_vocab or not os.path.isfile(vocab_path):
276
  max_vocab_size = args.max_vocab_size if args.max_vocab_size is not None else config.vocab_size
277
+ print(f" Building {tokenizer_variant} vocab: {vocab_path} (max_size={max_vocab_size})")
278
  build_vocab_from_data(all_data, tokenizer, vocab_path, max_size=max_vocab_size)
279
+ tokenizer = create_tokenizer(tokenizer_variant, vocab_file=vocab_path)
280
+ print(f" Variant: {tokenizer_variant}")
281
  print(f" Vocab size: {tokenizer.vocab_size}")
282
+ print(f" Max sequence length: {config.max_seq_length}")
283
 
284
  # Update config with actual vocab size
285
  config.vocab_size = tokenizer.vocab_size
 
288
  if args.init_model_dir:
289
  print(f"Loading model for fine-tuning: {args.init_model_dir}")
290
  model = BertForTokenClassification.from_pretrained(args.init_model_dir)
291
+ init_tokenizer = load_tokenizer(args.init_model_dir)
292
+ init_variant = getattr(init_tokenizer, "tokenizer_variant", None)
293
+ if init_variant != tokenizer_variant:
294
+ print(f" WARNING: tokenizer variant changes during fine-tune: {init_variant} -> {tokenizer_variant}")
295
+ print(" Token embeddings will be remapped by token string; unmatched tokens are newly initialized.")
296
+ if model.config.vocab_size != config.vocab_size or init_tokenizer.get_vocab() != tokenizer.get_vocab():
297
+ copied = remap_token_embeddings(
298
+ model=model,
299
+ old_vocab=init_tokenizer.get_vocab(),
300
+ new_vocab=tokenizer.get_vocab(),
301
+ pad_token_id=tokenizer.pad_token_id,
302
+ )
303
+ print(
304
+ f" Remapped token embeddings: copied {copied:,}/{config.vocab_size:,} "
305
+ f"tokens from init checkpoint"
306
+ )
307
  model.config.num_labels = config.num_labels
308
  model.config.id2label = config.id2label
309
  model.config.label2id = config.label2id
 
350
  use_cpu = args.cpu or not torch.cuda.is_available()
351
  use_fp16 = not use_cpu
352
  print(f" Device: {'CPU' if use_cpu else 'CUDA'}")
353
+ save_strategy = "steps" if args.checkpoint_steps else "epoch"
354
+ load_best_model_at_end = args.checkpoint_steps is None
355
 
356
  # Training arguments
357
  training_args = TrainingArguments(
 
360
  per_device_train_batch_size=config.batch_size,
361
  per_device_eval_batch_size=config.batch_size,
362
  eval_strategy="epoch",
363
+ save_strategy=save_strategy,
364
+ save_steps=args.checkpoint_steps,
365
  logging_steps=config.log_interval,
366
  learning_rate=config.learning_rate,
367
  weight_decay=config.weight_decay,
368
  warmup_steps=config.warmup_steps,
369
  use_cpu=use_cpu,
370
  report_to="none",
371
+ save_total_limit=args.save_total_limit,
372
+ load_best_model_at_end=load_best_model_at_end,
373
  metric_for_best_model="f1",
374
  greater_is_better=True,
375
  dataloader_num_workers=config.num_workers,
 
391
 
392
  # Train
393
  print("Starting training...")
394
+ resume_from_checkpoint = args.resume_from_checkpoint
395
+ if resume_from_checkpoint == "auto":
396
+ resume_from_checkpoint = latest_checkpoint(config.save_dir)
397
+ if resume_from_checkpoint:
398
+ print(f"Resuming from latest checkpoint: {resume_from_checkpoint}")
399
+ else:
400
+ print("No checkpoint found; starting a fresh training run.")
401
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
402
 
403
  # Set proper label mappings in model config before saving
404
  model.config.id2label = config.id2label
405
  model.config.label2id = config.label2id
406
+ model.config.tokenizer_variant = tokenizer_variant
407
  model.config.max_seq_length = config.max_seq_length
408
 
409
  # Save final model