WJAD / src /wjad /train /__init__.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
raw
history blame contribute delete
493 Bytes
"""训练相关:多任务损失合并、Trainer、调度器。"""
# Re-export the public training API so callers can import these names
# directly from this package instead of reaching into the submodules.
from .multitask import (
    GradNormBalancer,
    PCGradCombiner,
    MultiTaskOptimizer,
)
from .schedule import build_scheduler
from .trainer import (
    Trainer,
    TrainerConfig,
    compute_all_losses,
    MAIN_TASK_KEYS,
    AUX_TASK_KEYS,
)

# Names exposed by ``from <package> import *``; kept in sync with the
# imports above, grouped by the submodule each name comes from.
__all__ = [
    # .multitask
    "GradNormBalancer",
    "PCGradCombiner",
    "MultiTaskOptimizer",
    # .schedule
    "build_scheduler",
    # .trainer
    "Trainer",
    "TrainerConfig",
    "compute_all_losses",
    "MAIN_TASK_KEYS",
    "AUX_TASK_KEYS",
]