"""Training entry point for the mecari morphological analysis model (GATv2)."""

import os

# Must be set before libraries that read it are imported; silences the HuggingFace
# tokenizers fork/parallelism warning in DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
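
# Example invocations (a sketch; the script path/name is an assumption, adjust to
# wherever this file lives in the repo):
#
#   python train.py --config path/to/config.yaml
#   python train.py -m gatv2 --batch-size 32 --steps 8600 --no-wandb
#   python train.py --resume gatv2_20250806_162945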

import argparse
import random
from datetime import datetime
from importlib import import_module
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

from mecari.config.config import get_model_config, override_config, save_config
from mecari.data.data_module import DataModule
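
# The unified config is expected to look roughly like this (a sketch inferred from
# how keys are read below; see mecari.config.config for the actual defaults):
#
#   model:
#     type: gatv2
#     hidden_dim: ...
#     num_classes: ...
#     num_heads: ...
#   training:
#     annotations_dir: ...
#     batch_size: ...
#     num_workers: ...
#     learning_rate: ...
#     seed: ...
#     patience: ...
#     use_wandb: ...
#     project_name: ...
#     log_every_n_steps: ...
#     val_check_interval: ...
#     gradient_clip_val: ...
#   features:
#     lexical_feature_dim: 100000
#   edge_features:
#     use_bidirectional_edges: true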


def set_seed(seed: int = 42, deterministic: bool = True) -> None:
    """Set random seeds for reproducibility.

    Args:
        seed: Random seed value.
        deterministic: If True, enforce deterministic behavior (slower).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    pl.seed_everything(seed)


def get_config_sections(config: dict) -> dict:
    """Extract structured sections from a unified config dict."""
    return {
        "model": config["model"],
        "training": config["training"],
        "features": config.get("features", {}),
        "edge": config.get("edge_features", {}),
    }


def calculate_feature_dim(config: dict) -> int:
    """Return feature dimension from config (lexical features by default)."""
    features_cfg = config.get("features", {})

    lexical_dim = features_cfg.get("lexical_feature_dim", 100000)
    return lexical_dim


def create_data_module(config: dict) -> DataModule:
    """Create DataModule from config (lexical-only pipeline)."""
    features_cfg = config.get("features", {})
    training_cfg = config["training"]
    edge_cfg = config.get("edge_features", {})

    lexical_feature_dim = features_cfg.get("lexical_feature_dim", 100000)

    return DataModule(
        annotations_dir=training_cfg["annotations_dir"],
        batch_size=training_cfg["batch_size"],
        num_workers=training_cfg["num_workers"],
        max_files=training_cfg.get("max_files"),
        use_bidirectional_edges=edge_cfg.get("use_bidirectional_edges", True),
        annotations_override_dir=training_cfg.get("annotations_override_dir"),
        lexical_feature_dim=lexical_feature_dim,
    )
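
# Quick smoke test of the data pipeline (a sketch; assumes a config whose
# training.annotations_dir points at real annotation files and that DataModule
# follows the standard LightningDataModule interface):
#
#   dm = create_data_module(get_model_config("gatv2"))
#   dm.setup("fit")
#   batch = next(iter(dm.train_dataloader()))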


def setup_loggers(config: dict, experiment_name: str):
    """Configure optional loggers (e.g., Weights & Biases)."""
    import subprocess

    from pytorch_lightning.loggers import WandbLogger

    loggers = []

    if config["training"]["use_wandb"]:
        try:
            tags = []
            try:
                branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True).strip()
                tags.append(f"branch:{branch}")
                commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
                tags.append(f"commit:{commit}")
            except Exception:
                # Git metadata is optional; ignore failures (e.g., not a git checkout).
                pass

            wandb_logger = WandbLogger(
                project=config["training"]["project_name"],
                name=experiment_name,
                save_dir=f"experiments/{experiment_name}",
                save_code=True,
                log_model=False,
                tags=tags,
            )
            loggers.append(wandb_logger)
            print("✓ Added WandB logger (metrics only)")
        except Exception as e:
            print(f"WandbLogger initialization error: {e}")
    else:
        print("WandB logging disabled")

    if not loggers:
        # Returning False tells the Lightning Trainer to disable logging entirely.
        loggers = False

    return loggers


def create_trainer(config: dict, callbacks: list, loggers, deterministic: bool) -> pl.Trainer:
    """Create a PyTorch Lightning Trainer.

    Required training-config keys: log_every_n_steps, val_check_interval,
    gradient_clip_val; gradient_clip_algorithm and accumulate_grad_batches are
    applied only if present.
    """
    if torch.cuda.is_available():
        accelerator = "gpu"
        devices = 1
    else:
        accelerator = "cpu"
        devices = 1

    max_steps = config["training"].get("max_steps", 8600)
    # No epoch limit; training stops at max_steps (or when early stopping fires).
    max_epochs = -1

    trainer_kwargs = {
        "max_epochs": max_epochs,
        "max_steps": max_steps,
        "callbacks": callbacks,
        "logger": loggers,
        "accelerator": accelerator,
        "devices": devices,
        "log_every_n_steps": config["training"]["log_every_n_steps"],
        "val_check_interval": config["training"]["val_check_interval"],
        "gradient_clip_val": config["training"]["gradient_clip_val"],
        "enable_checkpointing": True,
        "enable_progress_bar": True,
        "limit_train_batches": 1.0,
        "limit_val_batches": 1.0,
        "limit_test_batches": 1.0,
        "limit_predict_batches": 1.0,
        "fast_dev_run": False,
        "deterministic": deterministic,
        "benchmark": not deterministic,
        # 16-bit mixed precision; assumes training runs on a GPU.
        "precision": "16-mixed",
    }

    if "gradient_clip_algorithm" in config["training"]:
        trainer_kwargs["gradient_clip_algorithm"] = config["training"]["gradient_clip_algorithm"]

    if "accumulate_grad_batches" in config["training"]:
        trainer_kwargs["accumulate_grad_batches"] = config["training"]["accumulate_grad_batches"]

    return pl.Trainer(**trainer_kwargs)


def create_model_and_datamodule(config: dict, feature_dim: int, data_module: Optional[DataModule] = None):
    """Create model and ensure DataModule is available (lexical-only)."""
    cfg = get_config_sections(config)
    model_cfg = cfg["model"]
    training_cfg = cfg["training"]
    features_cfg = cfg["features"]

    if data_module is None:
        data_module = create_data_module(config)

    # Note: feature_dim is currently unused here; the model reads
    # lexical_feature_dim from the config instead.
    common_params = {
        "hidden_dim": model_cfg["hidden_dim"],
        "num_classes": model_cfg["num_classes"],
        "learning_rate": training_cfg["learning_rate"],
        "lexical_feature_dim": features_cfg.get("lexical_feature_dim", 100000),
    }

    if model_cfg["type"] == "gatv2":
        MecariGATv2 = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
        model = MecariGATv2(
            **common_params,
            num_heads=model_cfg["num_heads"],
            share_weights=model_cfg.get("share_weights", False),
            dropout=model_cfg.get("dropout", 0.1),
            attn_dropout=model_cfg.get("attn_dropout", model_cfg.get("attention_dropout", 0.1)),
            add_self_loops_flag=model_cfg.get("add_self_loops", True),
            edge_dropout=model_cfg.get("edge_dropout", 0.0),
            norm=model_cfg.get("norm", "layer"),
        )
    else:
        raise ValueError(f"Unsupported model type: {model_cfg['type']}")

    return model, data_module
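
# Minimal end-to-end wiring (a sketch mirroring main(); assumes a valid config):
#
#   config = get_model_config("gatv2")
#   dm = create_data_module(config)
#   model, dm = create_model_and_datamodule(config, calculate_feature_dim(config), dm)
#   trainer = create_trainer(config, callbacks=[], loggers=False, deterministic=True)
#   trainer.fit(model, dm)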


def main():
    parser = argparse.ArgumentParser(description="Train the morphological analysis model")
    parser.add_argument(
        "--model",
        "-m",
        choices=["gatv2"],
        default="gatv2",
        help="Model type (only gatv2 supported). If a config is provided, config.model.type takes precedence.",
    )
    parser.add_argument("--config", "-c", help="Path to config file (overrides model type if present)")
    parser.add_argument("--batch-size", "-b", type=int, help="Batch size")
    parser.add_argument("--steps", "-s", type=int, help="Max training steps")
    parser.add_argument("--lr", type=float, help="Learning rate")
    parser.add_argument("--hidden-dim", type=int, help="Hidden dimension size")
    parser.add_argument("--patience", type=int, help="Early stopping patience")
    parser.add_argument("--weight-decay", type=float, help="Weight decay")
    parser.add_argument("--no-wandb", action="store_true", help="Disable Weights & Biases logging")
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument("--no-deterministic", action="store_true", help="Disable deterministic mode for speed")
    parser.add_argument("--resume", type=str, help="Experiment name to resume (e.g., gatv2_20250806_162945)")
    args = parser.parse_args()

    if args.config:
        from mecari.config.config import load_config

        config = load_config(args.config)
        if "model" in config and "type" in config["model"]:
            args.model = config["model"]["type"]
    else:
        config = get_model_config(args.model)

    overrides = {}

    # Use explicit None checks so that 0/0.0 values passed on the command line are honored.
    training_overrides = {}
    if args.batch_size is not None:
        training_overrides["batch_size"] = args.batch_size
    if args.steps is not None:
        training_overrides["max_steps"] = args.steps
    if args.lr is not None:
        training_overrides["learning_rate"] = args.lr
    if args.no_wandb:
        training_overrides["use_wandb"] = False
    if args.patience is not None:
        training_overrides["patience"] = args.patience
    if args.seed is not None:
        training_overrides["seed"] = args.seed
    if args.no_deterministic:
        training_overrides["deterministic"] = False

    if training_overrides:
        overrides["training"] = training_overrides

    if args.hidden_dim is not None:
        overrides["model"] = {"hidden_dim": args.hidden_dim}

    if args.weight_decay is not None:
        overrides.setdefault("training", {})
        overrides["training"]["optimizer"] = {"weight_decay": args.weight_decay}
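
    # For example, `--batch-size 16 --hidden-dim 256 --weight-decay 0.01` builds
    # (illustrative values):
    #   {"training": {"batch_size": 16, "optimizer": {"weight_decay": 0.01}},
    #    "model": {"hidden_dim": 256}}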

    if overrides:
        config = override_config(config, overrides)

    deterministic = config["training"].get("deterministic", True)
    set_seed(config["training"]["seed"], deterministic=deterministic)

    if not deterministic:
        print("⚡ Performance mode: deterministic=False (reproducibility not guaranteed)")

    resume_from_checkpoint = None
    experiment_name = None
    if args.resume:
        experiment_path = os.path.join("experiments", args.resume)
        if os.path.exists(experiment_path):
            checkpoint_dir = os.path.join(experiment_path, "checkpoints")
            if os.path.exists(checkpoint_dir):
                checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
                if checkpoints:
                    checkpoints.sort()
                    resume_from_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])
                    print(f"Resuming training from: {resume_from_checkpoint}")
                    experiment_name = args.resume

                    config_path = os.path.join(experiment_path, "config.yaml")
                    if os.path.exists(config_path):
                        from mecari.config.config import load_config

                        config = load_config(config_path)
                        print(f"Restored config from: {config_path}")
                else:
                    print(f"Warning: No checkpoints found in: {checkpoint_dir}")
            else:
                print(f"Warning: Checkpoint directory not found: {checkpoint_dir}")
        else:
            print(f"Warning: Experiment directory not found: {experiment_path}")

    if experiment_name is None:
        # New run, or a --resume target that could not be used: start a fresh experiment.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        experiment_name = f"{config['model']['type']}_{timestamp}"

    print(f"Experiment: {experiment_name}")
    print(f"Model: {config['model']['type'].upper()}")
    print("Lexical features: enabled (default)")

    if torch.cuda.is_available():
        print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
        print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        print("💻 Using CPU")

    data_module = create_data_module(config)

    feature_dim = calculate_feature_dim(config)

    model, _ = create_model_and_datamodule(config, feature_dim, data_module)

    model.training_config = config["training"]

    experiment_dir = f"experiments/{experiment_name}"
    if resume_from_checkpoint is None:
        # Only fresh runs create the directory and snapshot the config; resumed runs
        # reuse the existing experiment directory.
        os.makedirs(experiment_dir, exist_ok=True)
        save_config(config, f"{experiment_dir}/config.yaml")

    # Both callbacks monitor "val_error_epoch", which the model is expected to log
    # during validation.
    checkpoint_callback_error = ModelCheckpoint(
        dirpath=f"experiments/{experiment_name}/checkpoints",
        filename=f"{config['model']['type']}-{{epoch:02d}}-{{val_error_epoch:.3f}}",
        monitor="val_error_epoch",
        mode="min",
        save_top_k=1,
        save_last=True,
    )

    early_stopping = EarlyStopping(
        monitor="val_error_epoch", mode="min", patience=config["training"]["patience"], verbose=True, strict=False
    )

    loggers = setup_loggers(config, experiment_name)

    callbacks = [checkpoint_callback_error, early_stopping]
    try:
        if loggers:
            lr_monitor = LearningRateMonitor(logging_interval="step")
            callbacks.append(lr_monitor)
    except Exception:
        pass
    trainer = create_trainer(config, callbacks, loggers, deterministic)

    print("Starting training...")

    try:
        if resume_from_checkpoint:
            trainer.fit(model, data_module, ckpt_path=resume_from_checkpoint)
        else:
            trainer.fit(model, data_module)
        training_status = "completed"

        if data_module.test_dataset:
            print("Evaluating on test data...")
            trainer.test(model, data_module)
        print("Training complete!")
    except KeyboardInterrupt:
        print("\nTraining interrupted...")
        training_status = "interrupted"
    except Exception as e:
        print(f"\nError during training: {e}")
        import traceback

        traceback.print_exc()
        training_status = "error"
print(f"Experiment: {experiment_name}") |
|
|
print(f"Experiment dir: experiments/{experiment_name}") |
|
|
|
|
|
print("\n=== Saved models ===") |
|
|
|
|
|
if checkpoint_callback_error.best_model_path: |
|
|
best_error = ( |
|
|
float(checkpoint_callback_error.best_model_score) |
|
|
if checkpoint_callback_error.best_model_score is not None |
|
|
else 1.0 |
|
|
) |
|
|
print(f" Best val_error: {best_error:.6f}") |
|
|
print(f" → {os.path.basename(checkpoint_callback_error.best_model_path)}") |
|
|
|
|
|
print(f"\nFinal epoch: {trainer.current_epoch}") |
|
|
print(f"Training status: {training_status}") |
|
|
|
|
|
|
|
|


if __name__ == "__main__":
    main()