local _0428_base = import 'nl2code-base.libsonnet';

local _data_path = 'data/spider-en/';
local _output_from = true;  // decoder may copy tokens from the FROM clause
local _fs = 2;              // factorize_sketch level for the grammar

// Encoder and decoder preprocessors must read/write the same cache
// directory; keep the path in exactly one place so they cannot drift apart.
local _save_path =
  _data_path
  + 'BART-large-nl2code-1115,output_from=%s,fs=%d,emb=bart,cvlink'
    % [_output_from, _fs];

// Spider (English) config: BART-large encoder + relational transformer on
// top, NL2Code grammar-based decoder. `args` supplies the sweepable
// hyperparameters (bs, lr, bert_lr, end_lr, att, num_layers, ...).
function(args) _0428_base(output_from=_output_from, data_path=_data_path) + {
  // Scientific-notation renderings of the learning rates, used only to
  // build a unique, human-readable model_name.
  local lr_s = '%0.1e' % args.lr,
  local bert_lr_s = '%0.1e' % args.bert_lr,
  local end_lr_s = if args.end_lr == 0 then '0e0' else '%0.1e' % args.end_lr,

  // BART-large hidden size; the decoder attends over encodings of this width.
  local base_bert_enc_size = 1024,
  local enc_size = base_bert_enc_size,

  model_name: 'bs=%(bs)d,lr=%(lr)s,bert_lr=%(bert_lr)s,end_lr=%(end_lr)s,att=%(att)d'
              % (args + {
                   lr: lr_s,
                   bert_lr: bert_lr_s,
                   end_lr: end_lr_s,
                 }),

  model+: {
    encoder+: {
      name: 'spider-bart',
      // Fields hidden with `::` null out base-config options that do not
      // apply to the BART encoder (it brings its own embeddings/encoders).
      batch_encs_update:: null,
      question_encoder:: null,
      column_encoder:: null,
      table_encoder:: null,
      dropout:: null,
      update_config+: {
        name: 'relational_transformer',
        num_layers: args.num_layers,
        num_heads: 8,
        sc_link: args.sc_link,  // schema-content linking
        cv_link: args.cv_link,  // cell-value linking
      },
      summarize_header: args.summarize_header,
      use_column_type: args.use_column_type,
      bart_version: args.bart_version,
      pretrained_checkpoint: args.pretrained_checkpoint,
      top_k_learnable:: null,
      word_emb_size:: null,
    },

    encoder_preproc+: {
      word_emb:: null,
      min_freq:: null,
      max_count:: null,
      db_path: _data_path + 'database',
      compute_sc_link: args.sc_link,
      compute_cv_link: args.cv_link,
      fix_issue_16_primary_keys: true,
      bart_version: args.bart_version,
      pretrained_checkpoint: args.pretrained_checkpoint,
      count_tokens_in_word_emb_for_vocab:: null,
      save_path: _save_path,
    },

    decoder_preproc+: {
      grammar+: {
        end_with_from: args.end_with_from,
        clause_order: args.clause_order,
        infer_from_conditions: true,
        factorize_sketch: _fs,
      },
      save_path: _save_path,
      // The decoder preprocessor shares the cache dir but not the
      // linking/db options; hide the inherited ones.
      compute_sc_link:: null,
      compute_cv_link:: null,
      db_path:: null,
      fix_issue_16_primary_keys:: null,
      bart_version:: null,
      pretrained_checkpoint:: null,
    },

    decoder+: {
      name: 'NL2Code',
      dropout: 0.20687225956012834,
      desc_attn: 'mha',  // multi-head attention over the description
      enc_recurrent_size: enc_size,
      recurrent_size: args.decoder_hidden_size,
      loss_type: 'softmax',
      use_align_mat: args.use_align_mat,
      use_align_loss: args.use_align_loss,
    },
  },

  train+: {
    batch_size: args.bs,
    num_batch_accumulated: args.num_batch_accumulated,
    clip_grad: 1,
    // One seed (`att`) drives model/data/init seeds, so a single attempt
    // id fully determines the run.
    model_seed: args.att,
    data_seed: args.att,
    init_seed: args.att,
  },

  // NOTE(review): lr/bert_lr start at 0.0 — presumably the group scheduler
  // below drives the actual per-group rates via start_lrs; confirm against
  // the trainer before changing.
  optimizer: {
    name: 'bertAdamw',
    lr: 0.0,
    bert_lr: 0.0,
  },

  lr_scheduler+: {
    name: 'bert_warmup_polynomial_group',
    start_lrs: [args.lr, args.bert_lr],  // [non-BERT params, BERT params]
    end_lr: args.end_lr,
    num_warmup_steps: $.train.max_steps / 8,
  },
}