| local env = import "../env.jsonnet"; |
| local base = import "basic.jsonnet"; |
|
|
# Passed straight through as the model's `debug` field.
local debug = false;

# Re-training inputs, read through env.str / env.json with the defaults shown:
#   pretrained_path - archive the word embedding, span extractor and span finder
#                     are loaded from;
#   rt_lr           - learning rate given to the _span_finder / _span_extractor
#                     parameter groups.
local pretrained_path = env.str('PRETRAINED_PATH', 'cache/fn/best'),
      rt_lr = env.json('RT_LR', 5e-5);

# Device list is inherited from basic.jsonnet.
local cuda_devices = base.cuda_devices;
|
|
{
  // Object-local helpers; locals are not emitted in the manifested JSON.
  local n_gpu = std.length(cuda_devices),

  // Loads the sub-module found at `modulePath` inside the archive at
  // `pretrained_path`, with `freeze` disabled.
  local fromArchive(modulePath) = {
    _pretrained: {
      archive_file: pretrained_path,
      module_path: modulePath,
      freeze: false,
    },
  },

  // Data: readers, splits and loaders come unchanged from basic.jsonnet;
  // the vocabulary is built from the "train" split only.
  dataset_reader: base.dataset_reader,
  train_data_path: base.train_data_path,
  validation_data_path: base.validation_data_path,
  test_data_path: base.test_data_path,
  datasets_for_vocab_creation: ['train'],
  data_loader: base.data_loader,
  validation_data_loader: base.validation_data_loader,

  model: {
    type: 'span',

    // These three modules are loaded from the pretrained archive.
    word_embedding: fromArchive('word_embedding'),
    span_extractor: fromArchive('_span_extractor'),
    span_finder: fromArchive('_span_finder'),

    // The typing head is not loaded from the archive; only its hidden_dims
    // are shared with basic.jsonnet.
    span_typing: { type: 'mlp', hidden_dims: base.model.span_typing.hidden_dims },

    metrics: [{ type: 'srl' }],
    typing_loss_factor: base.model.typing_loss_factor,
    label_dim: base.model.label_dim,
    max_decoding_spans: 128,
    max_recursion_depth: 2,
    debug: debug,
  },

  trainer: {
    local bt = base.trainer,

    num_epochs: bt.num_epochs,
    patience: bt.patience,
    num_gradient_accumulation_steps: bt.num_gradient_accumulation_steps,
    validation_metric: '+arg-c_f',

    // Only pin a single device here; multi-device runs use `distributed` below.
    [if n_gpu == 1 then 'cuda_device']: cuda_devices[0],

    optimizer: {
      type: 'transformer',
      base: { type: 'adam', lr: bt.optimizer.base.lr },
      embeddings_lr: 0.0,
      encoder_lr: 1e-5,
      pooler_lr: 1e-5,
      layer_fix: bt.optimizer.layer_fix,
      // Reused span modules are trained at rt_lr; the list order
      // (_span_finder, then _span_extractor) matches the original config.
      parameter_groups: [
        [[prefix + '.*'], { lr: rt_lr }]
        for prefix in ['_span_finder', '_span_extractor']
      ],
    },
  },

  // More than one device -> distributed training over all of them.
  [if n_gpu > 1 then 'distributed']: { cuda_devices: cuda_devices },
  // Test-set evaluation is only requested for exactly one device.
  [if n_gpu == 1 then 'evaluate_on_test']: true,
}
|
|