File size: 23,790 Bytes
230508d 87b2fa6 230508d 87b2fa6 230508d 87b2fa6 230508d b103659 230508d 87b2fa6 fc9d7f7 230508d fc9d7f7 230508d 87b2fa6 230508d b103659 87b2fa6 230508d 87b2fa6 07c93ac 190085d 87b2fa6 230508d b103659 230508d fc9d7f7 230508d 87b2fa6 230508d 87b2fa6 230508d 87b2fa6 230508d 87b2fa6 230508d 87b2fa6 230508d 87b2fa6 86ec60c 87b2fa6 230508d 87b2fa6 b103659 230508d fc9d7f7 230508d 2aee25d 230508d 07c93ac 230508d 87b2fa6 230508d 87b2fa6 230508d 87b2fa6 230508d b103659 fc9d7f7 230508d 87b2fa6 230508d fc9d7f7 230508d fc9d7f7 230508d 87b2fa6 fc9d7f7 230508d 190085d 230508d 190085d fc9d7f7 190085d 230508d 190085d 230508d 190085d 230508d fc9d7f7 230508d 190085d 230508d fc9d7f7 230508d fc9d7f7 230508d b103659 230508d 190085d 230508d 07c93ac 87b2fa6 07c93ac b103659 a050f72 b103659 a050f72 190085d 86ec60c 190085d 07c93ac 230508d 190085d 230508d 190085d 230508d 190085d 230508d 87b2fa6 230508d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 | #!/usr/bin/env python3
"""Train small, base, and large PAWN models simultaneously on shared data.
All three models see the exact same batches in the same order, eliminating
data generation overhead and ensuring comparable training conditions.
Usage:
uv run python scripts/train_all.py --local-checkpoints
uv run python scripts/train_all.py --hf-repo thomas-schweich/pawn-{variant}
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import signal
import sys
import time
from pathlib import Path
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from pawn.config import CLMConfig, TrainingConfig
from pawn.model import PAWNCLM, clm_loss
from pawn.data import CLMDataset, create_validation_set
from pawn.gpu import configure_gpu
from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
from pawn.logging import MetricsLogger, random_slug
# ---------------------------------------------------------------------------
# Per-model state
# ---------------------------------------------------------------------------
class ModelSlot:
    """Holds everything needed to train and checkpoint one model variant.

    A slot owns the model, its optimizer / LR scheduler / grad scaler, a
    metrics logger (which creates the run directory), the checkpoint
    directory, and a single-worker thread pool used to push checkpoints and
    metrics to HuggingFace without blocking the training loop.
    """

    def __init__(
        self,
        name: str,
        model_cfg: CLMConfig,
        train_cfg: TrainingConfig,
        device: str,
        hf_repo: str | None,
        shm_checkpoints: bool = False,
        slug: str = "",
    ):
        """Build the model, optimizer, scheduler, logger and checkpoint dir.

        Args:
            name: Variant name; used in log output, the logger suffix, the
                /dev/shm checkpoint path and the push-thread name.
            model_cfg: Model configuration passed to ``PAWNCLM``.
            train_cfg: Training hyperparameters (lr, weight decay, warmup,
                total steps, AMP flag, grad-clip norm, log dir, batch size).
            device: Torch device string the model and batches are moved to.
            hf_repo: HuggingFace repo to push to, or ``None`` to disable pushes.
            shm_checkpoints: Write checkpoints under ``/dev/shm`` instead of
                the run directory.
            slug: Run slug forwarded to the metrics logger.
        """
        self.name = name
        self.slug = slug
        self.model_cfg = model_cfg
        self.train_cfg = train_cfg
        self.device = device
        self.hf_repo = hf_repo
        self.shm_checkpoints = shm_checkpoints

        self.model = PAWNCLM(model_cfg).to(device)
        param_count = sum(p.numel() for p in self.model.parameters())
        print(f" {name}: {param_count:,} params ({model_cfg.d_model}d/{model_cfg.n_layers}L)")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
        )

        from pawn.trainer import CosineWithWarmup
        self.scheduler = CosineWithWarmup(
            self.optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.total_steps,
        )
        # The AMP helpers take a device *type* ("cuda"), not an indexed
        # device string such as "cuda:0".
        self.scaler = torch.amp.GradScaler(self._device_type, enabled=train_cfg.use_amp)

        # Logger (creates run directory)
        self.logger = MetricsLogger(
            train_cfg.log_dir, run_prefix="run", device=device,
            slug=slug, suffix=name,
        )
        self.run_dir = str(self.logger.run_dir)
        self.jsonl_path = str(self.logger.metrics_path)

        # Checkpoint directory: /dev/shm if requested, else under run_dir
        if shm_checkpoints:
            self.checkpoint_dir = f"/dev/shm/pawn_checkpoints/{name}"
        else:
            self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None

        self.global_step = 0
        self.best_val_step = 0
        self.best_val_loss = float("inf")
        self.patience_counter = 0
        self.stopped = False

        # Background HF push (one thread per slot, so pushes don't block training)
        from concurrent.futures import ThreadPoolExecutor
        self._hf_push_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"hf-{name}")
        self._hf_push_future = None

        self.logger.log_config(
            model=model_cfg.__dict__,
            training=train_cfg.__dict__,
            param_count=param_count,
            formulation="clm",
            multi_model=True,
            variant=name,
        )
        self.logger.write_config_json(
            model=model_cfg.__dict__,
            training=train_cfg.__dict__,
            param_count=param_count,
            formulation="clm",
            multi_model=True,
            variant=name,
        )

    @property
    def _device_type(self) -> str:
        """Device type of ``self.device`` (e.g. ``"cuda"`` for ``"cuda:1"``)."""
        return torch.device(self.device).type

    @staticmethod
    def _to_python_scalar(value):
        """Convert a single-element tensor to a Python number; pass others through."""
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return value

    @staticmethod
    def _checkpoint_step(path: Path) -> int | None:
        """Parse the step number from a ``step_<N>`` directory name, if possible."""
        try:
            return int(path.name.split("_", 1)[1])
        except (IndexError, ValueError):
            return None

    def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Forward + backward. Returns raw GPU tensor metrics (no .item() sync)."""
        self.model.train()
        input_ids = batch["input_ids"].to(self.device)
        targets = batch["targets"].to(self.device)
        loss_mask = batch["loss_mask"].to(self.device)
        with torch.amp.autocast(self._device_type, enabled=self.train_cfg.use_amp):
            loss, metrics = self.model.forward_train(input_ids, loss_mask, targets)
        self.scaler.scale(loss).backward()
        return metrics

    def optimizer_step(self) -> float:
        """Unscale, clip, step the optimizer and scheduler; return the grad norm.

        Gradients are cleared afterwards. The ``.item()`` call on the gradient
        norm synchronizes with the device.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.train_cfg.max_grad_norm
        ).item()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.scheduler.step()
        return grad_norm

    def _unwrapped_model(self):
        """Return the unwrapped model (strips torch.compile wrapper)."""
        m = self.model
        while hasattr(m, '_orig_mod'):
            m = m._orig_mod
        return m

    def save_checkpoint(self) -> None:
        """Save a checkpoint for the current step and queue an HF push if enabled."""
        path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
        save_pretrain_checkpoint(
            path, self._unwrapped_model(), self.optimizer, self.scheduler, self.scaler,
            self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
        )
        print(f" [{self.name}] Checkpoint saved: {path}")
        if self.hf_repo and self.hf_branch:
            self._push_to_hf_async(path, self.global_step)

    def _prune_shm_checkpoints(self, pushed_path: str) -> None:
        """Delete checkpoints made obsolete by a successful push of ``pushed_path``.

        Keeps the checkpoint just pushed and the one for ``best_val_step``
        (for post-training evals). Checkpoints *newer* than the pushed one are
        also kept: the training thread may have written them while this upload
        was still in flight, and they have not been pushed yet.
        """
        pushed = Path(pushed_path)
        pushed_step = self._checkpoint_step(pushed)
        keep = {pushed.name, f"step_{self.best_val_step:08d}"}
        for old in sorted(Path(self.checkpoint_dir).glob("step_*")):
            if old.name in keep:
                continue
            old_step = self._checkpoint_step(old)
            if pushed_step is not None and old_step is not None and old_step > pushed_step:
                continue
            shutil.rmtree(old, ignore_errors=True)

    def _push_to_hf_async(self, ckpt_path: str, step: int) -> None:
        """Push checkpoint to HuggingFace in a background thread."""
        # Wait for any previous push to finish before starting a new one.
        # _push() reports its own failures, so this only blocks; it does not
        # re-raise ordinary push errors.
        if self._hf_push_future is not None:
            self._hf_push_future.result()

        def _push():
            try:
                push_checkpoint_to_hf(
                    ckpt_path, self.hf_repo, self.hf_branch,
                    metrics_path=self.jsonl_path, step=step,
                )
                print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")
                # On /dev/shm, clean up old checkpoints after successful push.
                if self.shm_checkpoints:
                    self._prune_shm_checkpoints(ckpt_path)
            except Exception as e:
                print(f" [{self.name}] WARNING: HF push failed: {e}")

        self._hf_push_future = self._hf_push_pool.submit(_push)

    def push_metrics_to_hf(self) -> None:
        """Push just metrics.jsonl to HF (lightweight, no checkpoint)."""
        if not self.hf_repo or not self.hf_branch:
            return

        def _push_metrics():
            try:
                from huggingface_hub import HfApi
                api = HfApi()
                api.create_branch(self.hf_repo, repo_type="model",
                                  branch=self.hf_branch, exist_ok=True)
                api.upload_file(
                    path_or_fileobj=self.jsonl_path,
                    path_in_repo="metrics.jsonl",
                    repo_id=self.hf_repo,
                    repo_type="model",
                    revision=self.hf_branch,
                    commit_message=f"Metrics through step {self.global_step}",
                )
            except Exception as e:
                print(f" [{self.name}] WARNING: metrics push failed: {e}")

        # Fire and forget on the push pool (queued behind any checkpoint push)
        self._hf_push_pool.submit(_push_metrics)

    def wait_for_push(self) -> None:
        """Block until any in-flight HF checkpoint push completes."""
        if self._hf_push_future is not None:
            self._hf_push_future.result()
            self._hf_push_future = None

    @torch.no_grad()
    def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
        """Evaluate on ``val_data`` in chunks of ``train_cfg.batch_size``.

        Returns the per-batch mean of every metric from ``clm_loss``, keyed
        as ``val/<metric>``. Each batch carries equal weight, including a
        smaller final batch. Single-element tensor metrics are converted to
        Python numbers once at the end, so callers do not end up holding
        device tensors.
        """
        self.model.eval()
        n = val_data["input_ids"].shape[0]
        batch_size = self.train_cfg.batch_size
        total_metrics: dict = {}
        n_batches = 0
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            input_ids = val_data["input_ids"][start:end].to(self.device)
            targets = val_data["targets"][start:end].to(self.device)
            loss_mask = val_data["loss_mask"][start:end].to(self.device)
            with torch.amp.autocast(self._device_type, enabled=self.train_cfg.use_amp):
                logits, _ = self.model(input_ids, loss_mask)
                _, metrics = clm_loss(logits, targets, loss_mask)
            for k, v in metrics.items():
                total_metrics[k] = total_metrics.get(k, 0.0) + v
            n_batches += 1
        return {
            f"val/{k}": self._to_python_scalar(v / n_batches)
            for k, v in total_metrics.items()
        }

    def close(self) -> None:
        """Wait for pending pushes, shut down the push pool and close the logger."""
        self.wait_for_push()
        self._hf_push_pool.shutdown(wait=True)
        self.logger.close()
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Exactly one of ``--hf-repo`` / ``--local-checkpoints`` must be supplied;
    every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Train small/base/large PAWN models simultaneously",
    )

    # Runtime and schedule
    parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
    parser.add_argument("--total-steps", type=int, default=100_000,
                        help="Total training steps")
    parser.add_argument("--batch-size", type=int, default=256,
                        help="Batch size (shared across models)")
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
    parser.add_argument("--log-dir", type=str, default="logs", help="Log directory")
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--eval-interval", type=int, default=500)
    parser.add_argument("--checkpoint-interval", type=int, default=5000)
    parser.add_argument("--discard-ply-limit", action="store_true")
    parser.add_argument(
        "--patience", type=int, default=10,
        help="Stop if no val loss improvement for N eval intervals (0=disabled)",
    )
    parser.add_argument("--wandb", action="store_true")

    # Checkpoint destination: exactly one of these is required
    destination = parser.add_mutually_exclusive_group(required=True)
    destination.add_argument(
        "--hf-repo", type=str, default=None,
        help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn",
    )
    destination.add_argument("--local-checkpoints", action="store_true")

    parser.add_argument(
        "--shm-checkpoints", action="store_true",
        help="Write checkpoints to /dev/shm (RAM-backed, instant writes). "
             "Requires --hf-repo since /dev/shm is volatile.",
    )

    # Post-training evaluation
    parser.add_argument(
        "--run-evals", action="store_true",
        help="Run probes, diagnostics, and Lichess eval after training completes",
    )
    parser.add_argument(
        "--lichess-pgn", type=str, default=None,
        help="Path to Lichess PGN file for eval (required with --run-evals)",
    )
    parser.add_argument(
        "--publish-results", action="store_true",
        help="Push eval results to HuggingFace (requires --hf-repo and --run-evals)",
    )

    return parser.parse_args()
def _run_post_training_evals(slots: list[ModelSlot], args):
    """Run probes, diagnostics, and Lichess eval on one checkpoint per variant.

    For each slot the checkpoint saved at ``slot.best_val_step`` is used when
    it exists on disk; otherwise the latest ``step_*`` checkpoint in the
    slot's checkpoint directory is used and reported as such. Results are
    written to ``eval_results.json`` in the slot's run directory and, when
    ``--publish-results`` is set and the slot has an HF repo, uploaded there.

    Args:
        slots: Trained model slots to evaluate.
        args: Parsed CLI arguments (reads ``device``, ``lichess_pgn`` and
            ``publish_results``).
    """
    import tempfile
    from pawn.checkpoint import load_backbone_weights
    from pawn.eval_suite.probes import extract_probe_data, train_all_probes
    from pawn.eval_suite.corpus import generate_corpus, load_corpus
    from pawn.eval_suite.diagnostics import extract_diagnostic_positions, evaluate_diagnostic_positions

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    for slot in slots:
        print(f"\n--- Evaluating {slot.name} ---")

        # Use tracked best val step (kept on /dev/shm if shm_checkpoints)
        best_step = slot.best_val_step
        best_loss = slot.best_val_loss
        ckpt_path = os.path.join(slot.checkpoint_dir, f"step_{best_step:08d}")
        using_best = os.path.isdir(ckpt_path)
        if not using_best:
            # Fall back to latest
            ckpts = sorted(Path(slot.checkpoint_dir).glob("step_*"))
            ckpt_path = str(ckpts[-1]) if ckpts else None
        if not ckpt_path:
            print(f" No checkpoint found, skipping")
            continue

        ckpt_name = os.path.basename(ckpt_path)
        if using_best:
            print(f" Best checkpoint: {ckpt_path} (val_loss={best_loss:.4f})")
        else:
            # Do not label the fallback as "best": no checkpoint exists for
            # the best validation step, so its val_loss does not apply here.
            print(f" No checkpoint for best step {best_step}; "
                  f"using latest checkpoint: {ckpt_path}")

        # Load model (unwrapped)
        state_dict, _ = load_backbone_weights(ckpt_path)
        model = PAWNCLM(slot.model_cfg).to(device)
        model.load_state_dict(state_dict)
        model.eval()

        # Record which checkpoint the numbers below were computed from.
        results = {"checkpoint": ckpt_name}

        # 1. Probes
        print(" Running probes...")
        train_data = extract_probe_data(2048, 256, seed=12345)
        val_data = extract_probe_data(512, 256, seed=54321)
        probe_results = train_all_probes(
            model, train_data, val_data, device=device,
            per_layer=True, n_epochs=20, verbose=True,
        )
        results["probes"] = probe_results
        del train_data, val_data

        # 2. Diagnostics
        print(" Running diagnostics...")
        with tempfile.TemporaryDirectory() as tmpdir:
            corpus_path = generate_corpus(tmpdir, n_games=2048, max_ply=255, seed=99999, batch_size=2048)
            corpus = load_corpus(corpus_path)
            positions = extract_diagnostic_positions(corpus, min_per_category=200, max_per_category=1000)
            diag_results = evaluate_diagnostic_positions(model, positions, corpus, device=device)
        results["diagnostics"] = diag_results

        # 3. Lichess eval (if PGN provided)
        if args.lichess_pgn:
            print(" Running Lichess eval...")
            from pawn.eval_suite.lichess import prepare_lichess_corpus, evaluate_on_lichess
            lichess_data = prepare_lichess_corpus(args.lichess_pgn, max_games_per_band=1000)
            lichess_results = evaluate_on_lichess(model, lichess_data, device=device)
            results["lichess"] = lichess_results

        # Save results
        results_path = os.path.join(slot.run_dir, "eval_results.json")
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, default=str)
        print(f" Results saved: {results_path}")

        # Publish to HF
        if args.publish_results and slot.hf_repo and slot.hf_branch:
            from huggingface_hub import HfApi
            api = HfApi()
            if using_best:
                commit_message = f"Eval results (best step {best_step})"
            else:
                commit_message = f"Eval results (latest checkpoint {ckpt_name})"
            try:
                api.upload_file(
                    path_or_fileobj=results_path,
                    path_in_repo="eval_results.json",
                    repo_id=slot.hf_repo,
                    repo_type="model",
                    revision=slot.hf_branch,
                    commit_message=commit_message,
                )
                print(f" Published to {slot.hf_repo}@{slot.hf_branch}")
            except Exception as e:
                print(f" WARNING: HF publish failed: {e}")

        # Free this variant's weights before loading the next one.
        del model, state_dict
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def main():
    """Train the small, base and large variants on one shared batch stream.

    Each batch from the shared loader is fed to every still-active slot.
    Logging, evaluation and checkpointing happen at their configured
    intervals, and each slot can stop early on its own once validation loss
    has not improved for ``--patience`` evals. SIGTERM/SIGINT request a
    shutdown that is honoured at the end of the current step, after saving
    checkpoints. Post-training evals run afterwards if ``--run-evals`` is set.
    """
    args = parse_args()

    if args.shm_checkpoints and not args.hf_repo:
        print("ERROR: --shm-checkpoints requires --hf-repo (HF is the only durable store)")
        sys.exit(1)
    if args.publish_results and not (args.hf_repo and args.run_evals):
        # Previously ignored without any message; keep it non-fatal but visible.
        print("WARNING: --publish-results has no effect without --hf-repo and --run-evals")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Compare the device *type* so indexed devices such as "cuda:0" are also
    # configured.
    if torch.device(device).type == "cuda":
        gpu_cfg = configure_gpu()
        import pawn.model as model_module
        if gpu_cfg.get("sdpa_backend"):
            model_module.SDPA_BACKEND = gpu_cfg["sdpa_backend"]

    # Build per-variant configs (shared training hyperparams, different model sizes)
    variants = {
        "small": CLMConfig.small(),
        "base": CLMConfig.base(),
        "large": CLMConfig.large(),
    }

    slug = random_slug()
    print(f"=== Multi-Model Training [{slug}] ===")
    print(f"Device: {device}")
    print(f"Batch size: {args.batch_size}")
    print(f"Total steps: {args.total_steps}")
    if args.shm_checkpoints:
        print("Checkpoints: /dev/shm (volatile, HF push is durable store)")
    print()

    # Linear LR scaling: lr = base_lr * (batch_size / base_batch_size)
    base_batch_size = 256
    base_lr = TrainingConfig.lr
    scaled_lr = base_lr * (args.batch_size / base_batch_size)
    print(f"LR: {scaled_lr:.2e} (scaled from {base_lr:.2e} for batch {args.batch_size})")

    slots: list[ModelSlot] = []
    for name, model_cfg in variants.items():
        train_cfg = TrainingConfig()
        train_cfg.lr = scaled_lr
        train_cfg.total_steps = args.total_steps
        train_cfg.batch_size = args.batch_size
        train_cfg.num_workers = args.num_workers
        train_cfg.device = device
        train_cfg.log_dir = args.log_dir
        train_cfg.log_interval = args.log_interval
        train_cfg.eval_interval = args.eval_interval
        train_cfg.checkpoint_interval = args.checkpoint_interval
        train_cfg.discard_ply_limit = args.discard_ply_limit
        train_cfg.use_wandb = args.wandb
        hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
        slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo,
                               shm_checkpoints=args.shm_checkpoints, slug=slug))

    # Shared dataset and validation set
    max_ply = 256
    dataset = CLMDataset(
        args.batch_size, max_ply, base_seed=42,
        discard_ply_limit=args.discard_ply_limit,
    )
    print("\nGenerating shared validation set...")
    val_data = create_validation_set(512, max_ply, seed=(2**63) - 1,
                                     discard_ply_limit=args.discard_ply_limit)

    # Compile models
    if device != "cpu":
        for slot in slots:
            try:
                slot.model = torch.compile(slot.model, mode="default")
                print(f" [{slot.name}] torch.compile enabled")
            except Exception:
                print(f" [{slot.name}] torch.compile not available")

    loader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=args.num_workers,
        pin_memory=(device != "cpu"),
        persistent_workers=(args.num_workers > 0),
        prefetch_factor=1 if args.num_workers > 0 else None,
    )

    # Signal handling: only record the request; the loop acts on it at the
    # end of the current step so checkpoints are written from a clean state.
    _shutdown_requested = False
    _shutdown_signal = None

    def _graceful_exit(signum, frame):
        nonlocal _shutdown_requested, _shutdown_signal
        _shutdown_requested = True
        _shutdown_signal = signum

    signal.signal(signal.SIGTERM, _graceful_exit)
    signal.signal(signal.SIGINT, _graceful_exit)

    # Step at which each slot last wrote a checkpoint. Several triggers
    # (interval, final step, shutdown) can fire on the same step; without
    # this the same checkpoint would be written and pushed twice.
    last_saved_step: dict[str, int] = {}

    def _save_once(slot):
        if last_saved_step.get(slot.name) == slot.global_step:
            return
        slot.save_checkpoint()
        last_saved_step[slot.name] = slot.global_step

    # Training loop
    global_step = 0
    step_start = time.perf_counter()
    print(f"\nStarting training from step 0", flush=True)
    for slot in slots:
        print(f" [{slot.name}] JSONL: {slot.jsonl_path}", flush=True)
    print()

    active_slots = list(slots)  # slots still training

    try:
        for batch in loader:
            # Forward + backward for active models only (no .item() sync)
            all_metrics: dict[str, dict[str, torch.Tensor]] = {}
            for slot in active_slots:
                metrics = slot.train_step(batch)
                all_metrics[slot.name] = metrics

            # Optimizer step for active models
            all_grad_norms: dict[str, float] = {}
            for slot in active_slots:
                gn = slot.optimizer_step()
                all_grad_norms[slot.name] = gn

            global_step += 1
            for slot in slots:
                slot.global_step = global_step

            # Monotonic clock; guard against a zero interval in the division.
            step_time = time.perf_counter() - step_start
            games_per_sec = args.batch_size / max(step_time, 1e-9)

            # Logging — .item() sync only at log intervals
            if global_step % args.log_interval == 0:
                active_names = ", ".join(s.name for s in active_slots)
                print(f"step {global_step:>7d} | {games_per_sec:.0f} g/s | {step_time:.2f}s | active: {active_names}", flush=True)
                for slot in active_slots:
                    m = all_metrics[slot.name]
                    loss_val = m['loss'].item()
                    acc_val = m['accuracy'].item()
                    gn = all_grad_norms[slot.name]
                    lr = slot.scheduler.get_lr()
                    print(f" {slot.name:>5s}: loss {loss_val:.4f} | acc {acc_val:.3f} | "
                          f"lr {lr:.2e} | gn {gn:.2f}", flush=True)
                    slot.logger.log_train(
                        step=global_step,
                        lr=lr, grad_norm=gn,
                        step_time=step_time, games_per_sec=games_per_sec,
                        **{"train/loss": loss_val, "train/accuracy": acc_val},
                    )

            # Eval
            if global_step % args.eval_interval == 0:
                for slot in active_slots:
                    val_metrics = slot.evaluate(val_data)
                    print(f" {slot.name:>5s} val: loss {val_metrics['val/loss']:.4f} | "
                          f"acc {val_metrics['val/accuracy']:.3f}", flush=True)

                    # Track best for eval, /dev/shm cleanup, and patience
                    vl = val_metrics["val/loss"]
                    if vl < slot.best_val_loss:
                        slot.best_val_loss = vl
                        slot.best_val_step = global_step
                        slot.patience_counter = 0
                    else:
                        slot.patience_counter += 1

                    slot.logger.log_val(
                        step=global_step,
                        patience=slot.patience_counter,
                        best_val_loss=slot.best_val_loss,
                        best_val_step=slot.best_val_step,
                        **val_metrics,
                    )

                    # Per-model early stopping
                    if args.patience > 0 and slot.patience_counter >= args.patience:
                        print(f" [{slot.name}] Early stopping — no improvement "
                              f"for {args.patience} evals (best step {slot.best_val_step})")
                        slot.stopped = True
                        _save_once(slot)

                active_slots = [s for s in active_slots if not s.stopped]

                # Push metrics to HF after eval (lightweight, background)
                for slot in slots:
                    slot.push_metrics_to_hf()

                if not active_slots:
                    print(f"\nAll models stopped at step {global_step}")
                    break

            # Checkpoint
            if global_step % args.checkpoint_interval == 0:
                for slot in active_slots:
                    _save_once(slot)

            # Done?
            if global_step >= args.total_steps:
                print(f"\nTraining complete at step {global_step}")
                for slot in active_slots:
                    _save_once(slot)
                break

            # Graceful shutdown
            if _shutdown_requested:
                print(f"\nShutdown requested (signal {_shutdown_signal}), "
                      f"saving checkpoints at step {global_step}...")
                for slot in active_slots:
                    _save_once(slot)
                break

            step_start = time.perf_counter()
    finally:
        # Cleanup: runs even if the loop raised, so pending pushes are
        # awaited and loggers are closed.
        for slot in slots:
            slot.close()

    # Post-training evals
    if args.run_evals:
        print("\n" + "=" * 60)
        print("POST-TRAINING EVALUATION")
        print("=" * 60)
        _run_post_training_evals(slots, args)

    print("\nAll done.")
if __name__ == "__main__":
    # Choose the multiprocessing start method before anything else runs:
    # try "forkserver" first and fall back to "spawn" if it is rejected
    # with a ValueError.
    start_method = "forkserver"
    try:
        mp.set_start_method(start_method, force=True)
    except ValueError:
        start_method = "spawn"
        mp.set_start_method(start_method, force=True)
    main()
|