Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from transformers import RobertaTokenizer, BertTokenizer | |
| from tqdm import tqdm # for our progress bar | |
| from transformers import AdamW | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from models.spatial_bert_model import SpatialBertModel | |
| from models.spatial_bert_model import SpatialBertConfig | |
| from models.spatial_bert_model import SpatialBertForMaskedLM | |
| from datasets.osm_sample_loader import PbfMapDataset | |
| from transformers.models.bert.modeling_bert import BertForMaskedLM | |
| import numpy as np | |
| import argparse | |
| import pdb | |
# When True, the training DataLoader is not shuffled and the loss is printed
# at every step.
DEBUG = False


def _resolve_model_save_dir(args):
    """Return the checkpoint directory for this run, creating it if missing.

    When ``args.model_save_dir`` is None, a path is derived from the learning
    rate, separator/context/freeze flags, BERT variant and batch size;
    otherwise the caller-supplied directory is used as is.
    """
    if args.model_save_dir is not None:
        model_save_dir = args.model_save_dir
    else:
        sep_pathstr = '_sep' if args.sep_between_neighbors else '_nosep'
        freeze_pathstr = '_freeze' if args.freeze_backbone else '_nofreeze'
        context_pathstr = '_nocontext' if args.no_spatial_distance else '_withcontext'
        model_save_dir = (
            '/data2/zekun/spatial_bert_weights/mlm_mem_lr'
            + '{:.0e}'.format(args.lr)
            + sep_pathstr
            + context_pathstr
            + '/'
            + args.bert_option
            + freeze_pathstr
            + '_mlm_mem_london_california_bsize'
            + str(args.batch_size)
        )

    # Create the directory in both cases so torch.save() cannot fail on a
    # missing folder; exist_ok avoids the check-then-create race.
    os.makedirs(model_save_dir, exist_ok=True)
    return model_save_dir


def _build_model_and_tokenizer(bert_option, use_spatial_distance_embedding):
    """Build a SpatialBertForMaskedLM initialised from pretrained BERT weights.

    Returns a ``(model, tokenizer)`` pair. Raises NotImplementedError for an
    unsupported ``bert_option``.
    """
    if bert_option == 'bert-base':
        pretrained_name = 'bert-base-uncased'
        config = SpatialBertConfig(
            use_spatial_distance_embedding=use_spatial_distance_embedding)
    elif bert_option == 'bert-large':
        pretrained_name = 'bert-large-uncased'
        config = SpatialBertConfig(
            use_spatial_distance_embedding=use_spatial_distance_embedding,
            hidden_size=1024,
            intermediate_size=4096,
            num_attention_heads=16,
            num_hidden_layers=24)
    else:
        raise NotImplementedError

    bert_model = BertForMaskedLM.from_pretrained(pretrained_name)
    tokenizer = BertTokenizer.from_pretrained(pretrained_name)

    model = SpatialBertForMaskedLM(config)
    # strict=False: state-dict keys that do not match between the two models
    # are skipped instead of raising.
    # load sentence position embedding weights as well
    model.load_state_dict(bert_model.state_dict(), strict=False)
    return model, tokenizer


def _freeze_all_but_head_and_top_layers(model):
    """Freeze every parameter except ``model.cls`` and encoder layers 21-23."""
    for param in model.parameters():
        param.requires_grad = False

    trainable_modules = [model.cls]
    trainable_modules.extend(
        model.bert.encoder.layer[layer_idx] for layer_idx in (21, 22, 23))
    for module in trainable_modules:
        for param in module.parameters():
            param.requires_grad = True


def training(args):
    """Train SpatialBertForMaskedLM on the London and California OSM data.

    Args:
        args: namespace providing ``num_workers``, ``batch_size``, ``epochs``,
            ``lr``, ``save_interval``, ``max_token_len``,
            ``distance_norm_factor``, ``spatial_dist_fill``, ``with_type``,
            ``sep_between_neighbors``, ``freeze_backbone``, ``bert_option``,
            ``no_spatial_distance`` and ``model_save_dir``.

    Raises:
        AssertionError: if ``args.bert_option`` is neither ``'bert-base'``
            nor ``'bert-large'``.

    Side effects:
        Creates the checkpoint directory if needed and writes a ``.pth``
        state dict every ``save_interval`` iterations and at the end of each
        epoch. Input data is read from hard-coded relative paths under
        ``../data/sql_output/``.
    """
    num_workers = args.num_workers
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr  # 1e-7 # 5e-5
    save_interval = args.save_interval
    max_token_len = args.max_token_len
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone
    bert_option = args.bert_option
    if_no_spatial_distance = args.no_spatial_distance

    assert bert_option in ['bert-base', 'bert-large']

    london_file_path = '../data/sql_output/osm-point-london.json'
    california_file_path = '../data/sql_output/osm-point-california.json'

    model_save_dir = _resolve_model_save_dir(args)
    print('model_save_dir', model_save_dir)
    print('\n')

    model, tokenizer = _build_model_and_tokenizer(
        bert_option, not if_no_spatial_distance)

    if freeze_backbone:
        if bert_option == 'bert-large':
            print('freezing backbone weights')
            _freeze_all_but_head_and_top_layers(model)
        else:
            # Freezing is only implemented for the 24-layer model; say so
            # instead of silently training everything.
            print('warning: freeze_backbone is only implemented for bert-large; '
                  'all weights will be trained')

    # One dataset per region, built with identical settings, then concatenated.
    region_datasets = [
        PbfMapDataset(data_file_path=data_file_path,
                      tokenizer=tokenizer,
                      max_token_len=max_token_len,
                      distance_norm_factor=distance_norm_factor,
                      spatial_dist_fill=spatial_dist_fill,
                      with_type=with_type,
                      sep_between_neighbors=sep_between_neighbors,
                      label_encoder=None,
                      mode=None)
        for data_file_path in (london_file_path, california_file_path)
    ]
    train_dataset = torch.utils.data.ConcatDataset(region_datasets)

    # Keep sample order fixed while debugging; shuffle otherwise.
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=num_workers, shuffle=not DEBUG,
                              pin_memory=True, drop_last=True)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.train()

    # initialize optimizer
    # NOTE(review): AdamW here is the class imported from transformers at the
    # top of the file; recent transformers releases deprecate it in favour of
    # torch.optim.AdamW - confirm before upgrading.
    optim = AdamW(model.parameters(), lr=lr)

    print('start training...')

    for epoch in range(epochs):
        # setup loop with TQDM and dataloader
        loop = tqdm(train_loader, leave=True)
        step = 0
        for batch in loop:
            # reset gradients accumulated by the previous step
            optim.zero_grad()

            # pull all tensor batches required for training
            input_ids = batch['masked_input'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_list_x = batch['norm_lng_list'].to(device)
            position_list_y = batch['norm_lat_list'].to(device)
            sent_position_ids = batch['sent_position_ids'].to(device)
            # presumably the unmasked token ids used as MLM targets -
            # TODO confirm against PbfMapDataset
            labels = batch['pseudo_sentence'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask,
                            sent_position_ids=sent_position_ids,
                            position_list_x=position_list_x,
                            position_list_y=position_list_y, labels=labels)

            loss = outputs.loss
            loss.backward()
            optim.step()

            # Read the scalar once and reuse it for logging and file naming.
            loss_value = loss.item()

            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix({'loss': loss_value})

            if DEBUG:
                print('ep' + str(epoch) + '_' + '_iter' + str(step).zfill(5), loss_value)

            step += 1

            # Checkpoint periodically and once more on the epoch's last batch.
            if step % save_interval == 0 or step == loop.total:
                checkpoint_name = ('mlm_mem_keeppos_ep' + str(epoch)
                                   + '_iter' + str(step).zfill(5)
                                   + '_' + '{:.4f}'.format(loss_value) + '.pth')
                save_path = os.path.join(model_save_dir, checkpoint_name)
                torch.save(model.state_dict(), save_path)
                print('saving model checkpoint to', save_path)
def main():
    """Parse command-line options and launch masked-LM training.

    Creates ``--model_save_dir`` when one is given, then hands the parsed
    namespace to ``training``. An unsupported ``--bert_option`` is rejected
    by argparse (exit status 2) before anything else runs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5,
                        help='number of DataLoader workers')
    parser.add_argument('--batch_size', type=int, default=12,
                        help='training batch size')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of training epochs')
    parser.add_argument('--save_interval', type=int, default=2000,
                        help='save a checkpoint every N iterations')
    parser.add_argument('--max_token_len', type=int, default=300,
                        help='max_token_len passed to the dataset')
    parser.add_argument('--lr', type=float, default=5e-5,
                        help='optimizer learning rate')
    parser.add_argument('--distance_norm_factor', type=float, default=0.0001,
                        help='distance_norm_factor passed to the dataset')
    parser.add_argument('--spatial_dist_fill', type=float, default=20,
                        help='spatial_dist_fill passed to the dataset')
    parser.add_argument('--with_type', default=False, action='store_true',
                        help='with_type flag passed to the dataset')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true',
                        help='sep_between_neighbors flag passed to the dataset')
    parser.add_argument('--freeze_backbone', default=False, action='store_true',
                        help='bert-large only: train just the MLM head and '
                             'encoder layers 21-23')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true',
                        help='disable the spatial distance embedding')
    # Restrict to the variants training() supports so a typo fails at parse
    # time with a usage message rather than as an assertion later on.
    parser.add_argument('--bert_option', type=str, default='bert-base',
                        choices=['bert-base', 'bert-large'],
                        help='pretrained BERT variant to initialise from')
    parser.add_argument('--model_save_dir', type=str, default=None,
                        help='checkpoint directory (default: auto-generated path)')

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # Create the output directory when one was given; exist_ok makes this a
    # no-op if it is already there.
    if args.model_save_dir is not None:
        os.makedirs(args.model_save_dir, exist_ok=True)

    training(args)
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()