unknown
commited on
Commit
·
ac3312e
1
Parent(s):
3e0611d
Initial
Browse files- Scripts/UnixCoder/model_gen.py +0 -31
- Scripts/UnixCoder/run_one_model.py +1 -101
- run_fine_tuning.sh +1 -1
Scripts/UnixCoder/model_gen.py
CHANGED
|
@@ -56,9 +56,6 @@ class Seq2Seq(nn.Module):
|
|
| 56 |
mask = source_ids.ne(1)[:, None, :]*source_ids.ne(1)[:, :, None]
|
| 57 |
encoder_output = self.encoder(
|
| 58 |
source_ids, attention_mask=mask, use_cache=True)
|
| 59 |
-
# print("source_ids:", source_ids.size()) # torch.Size([56, 510])
|
| 60 |
-
# print("exist:", exist.size()) # torch.Size([56, 1])
|
| 61 |
-
# print("target_ids:", target_ids.size()) # torch.Size([56, 240])
|
| 62 |
ids = torch.cat((source_ids, target_ids), -1)
|
| 63 |
|
| 64 |
mask = self.bias[:,
|
|
@@ -68,33 +65,15 @@ class Seq2Seq(nn.Module):
|
|
| 68 |
out = self.decoder(target_ids, attention_mask=mask,
|
| 69 |
past_key_values=encoder_output.past_key_values).last_hidden_state
|
| 70 |
|
| 71 |
-
# 先concat 再池化
|
| 72 |
-
# print("out:", out.size()) # torch.Size([56, 240, 768])
|
| 73 |
-
|
| 74 |
lm_logits = self.lm_head(out[..., 1:, :])
|
| 75 |
-
# print("lm_logits:", lm_logits.size()) # torch.Size([56, 239, 51416])
|
| 76 |
-
|
| 77 |
# Shift so that tokens < n predict n
|
| 78 |
active_loss = target_ids[..., 2:].ne(1).view(-1)
|
| 79 |
-
# print("active_loss:", active_loss.size()) # torch.Size([13328])
|
| 80 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 81 |
-
# print("shift_logits:", shift_logits.size()) # torch.Size([56, 238, 51416])
|
| 82 |
-
|
| 83 |
shift_labels = target_ids[..., 2:].contiguous()
|
| 84 |
-
# print("shift_labels:", shift_labels.size()) # torch.Size([56, 238])
|
| 85 |
|
| 86 |
exist_labels = exist.contiguous()
|
| 87 |
-
# print("exist_labels:", exist_labels.size()) # torch.Size([56, 1])
|
| 88 |
-
|
| 89 |
-
# print("shift_logits.size:", shift_logits.size(-1)) # 51416
|
| 90 |
-
# print("shift_logits.view(-1, shift_logits.size(-1)):", shift_logits.view(-1, shift_logits.size(-1))[active_loss].size()) # torch.Size([614, 51416])
|
| 91 |
-
# print("shift_labels.view(-1):", shift_labels.view(-1)[active_loss].size()) # torch.Size([614])
|
| 92 |
-
|
| 93 |
pred_out = out[..., 0, :]
|
| 94 |
-
# print("pred_out:", pred_out.size()) # torch.Size([56, 768])
|
| 95 |
pred_sigmoid = self.sigmoid(self.pred_dense(pred_out))
|
| 96 |
-
# print("pred_sigmoid:", pred_sigmoid.size()) # torch.Size([56, 1])
|
| 97 |
-
|
| 98 |
# Flatten the tokens
|
| 99 |
loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
|
| 100 |
loss_fct_pred = nn.MSELoss(reduction="mean")
|
|
@@ -103,8 +82,6 @@ class Seq2Seq(nn.Module):
|
|
| 103 |
|
| 104 |
loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
|
| 105 |
loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight
|
| 106 |
-
# loss = loss.to(torch.float32)
|
| 107 |
-
# loss = loss_pred
|
| 108 |
|
| 109 |
outputs = loss, loss*active_loss.sum(), active_loss.sum(), loss_pred, loss_code
|
| 110 |
return outputs
|
|
@@ -135,10 +112,7 @@ class Seq2Seq(nn.Module):
|
|
| 135 |
mask = mask & ids[:, None, :].ne(1)
|
| 136 |
out = self.decoder(input_ids, attention_mask=mask,
|
| 137 |
past_key_values=context).last_hidden_state
|
| 138 |
-
# print("out:", out.size())
|
| 139 |
-
# concat 池化 out
|
| 140 |
hidden_states = out[:, -1, :]
|
| 141 |
-
# print("hidden_states:", hidden_states.size())
|
| 142 |
if out.size(1) == 1:
|
| 143 |
pred_sigmoid = self.sigmoid(self.pred_dense(
|
| 144 |
hidden_states.view(-1, 1, hidden_states.size(-1))))
|
|
@@ -155,14 +129,9 @@ class Seq2Seq(nn.Module):
|
|
| 155 |
pred = [torch.cat([x.view(-1) for x in p] + [zero] *
|
| 156 |
(self.max_length-len(p))).view(1, -1) for p in pred]
|
| 157 |
predicates.append(predicate[0][0])# ZM modified
|
| 158 |
-
#print("ZM-Model_Debug_P_Each_Itr: %d, %d, %d" % (len(predicate), len(predicate[0]), len(predicate[0][0])))
|
| 159 |
preds.append(torch.cat(pred, 0).unsqueeze(0))
|
| 160 |
-
#print("ZM-Model_Debug_Predicate_Shape: %d" % (len(predicates)))
|
| 161 |
-
#print("ZM-Model_Debug_Codes_BeforeCat: %d, %d, %d, %d" % (len(preds), len(preds[0]), len(preds[0][0]), len(preds[0][0][0])))
|
| 162 |
preds = torch.cat(preds, 0)
|
| 163 |
predicates = torch.tensor(predicates, device="cuda")# ZM modified
|
| 164 |
-
# predicates = torch.cat(predicates, 0).unsqueeze(0)
|
| 165 |
-
#print("ZM-Model_Debug_Codes_AfterCat: %d, %d, %d" % (len(preds), len(preds[0]), len(preds[0][0])))
|
| 166 |
return preds, predicates
|
| 167 |
|
| 168 |
|
|
|
|
| 56 |
mask = source_ids.ne(1)[:, None, :]*source_ids.ne(1)[:, :, None]
|
| 57 |
encoder_output = self.encoder(
|
| 58 |
source_ids, attention_mask=mask, use_cache=True)
|
|
|
|
|
|
|
|
|
|
| 59 |
ids = torch.cat((source_ids, target_ids), -1)
|
| 60 |
|
| 61 |
mask = self.bias[:,
|
|
|
|
| 65 |
out = self.decoder(target_ids, attention_mask=mask,
|
| 66 |
past_key_values=encoder_output.past_key_values).last_hidden_state
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
lm_logits = self.lm_head(out[..., 1:, :])
|
|
|
|
|
|
|
| 69 |
# Shift so that tokens < n predict n
|
| 70 |
active_loss = target_ids[..., 2:].ne(1).view(-1)
|
|
|
|
| 71 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
|
|
|
|
|
|
| 72 |
shift_labels = target_ids[..., 2:].contiguous()
|
|
|
|
| 73 |
|
| 74 |
exist_labels = exist.contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
pred_out = out[..., 0, :]
|
|
|
|
| 76 |
pred_sigmoid = self.sigmoid(self.pred_dense(pred_out))
|
|
|
|
|
|
|
| 77 |
# Flatten the tokens
|
| 78 |
loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
|
| 79 |
loss_fct_pred = nn.MSELoss(reduction="mean")
|
|
|
|
| 82 |
|
| 83 |
loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
|
| 84 |
loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight
|
|
|
|
|
|
|
| 85 |
|
| 86 |
outputs = loss, loss*active_loss.sum(), active_loss.sum(), loss_pred, loss_code
|
| 87 |
return outputs
|
|
|
|
| 112 |
mask = mask & ids[:, None, :].ne(1)
|
| 113 |
out = self.decoder(input_ids, attention_mask=mask,
|
| 114 |
past_key_values=context).last_hidden_state
|
|
|
|
|
|
|
| 115 |
hidden_states = out[:, -1, :]
|
|
|
|
| 116 |
if out.size(1) == 1:
|
| 117 |
pred_sigmoid = self.sigmoid(self.pred_dense(
|
| 118 |
hidden_states.view(-1, 1, hidden_states.size(-1))))
|
|
|
|
| 129 |
pred = [torch.cat([x.view(-1) for x in p] + [zero] *
|
| 130 |
(self.max_length-len(p))).view(1, -1) for p in pred]
|
| 131 |
predicates.append(predicate[0][0])# ZM modified
|
|
|
|
| 132 |
preds.append(torch.cat(pred, 0).unsqueeze(0))
|
|
|
|
|
|
|
| 133 |
preds = torch.cat(preds, 0)
|
| 134 |
predicates = torch.tensor(predicates, device="cuda")# ZM modified
|
|
|
|
|
|
|
| 135 |
return preds, predicates
|
| 136 |
|
| 137 |
|
Scripts/UnixCoder/run_one_model.py
CHANGED
|
@@ -53,7 +53,6 @@ class Example(object):
|
|
| 53 |
vec,
|
| 54 |
exist,
|
| 55 |
module
|
| 56 |
-
# propertyposition,
|
| 57 |
):
|
| 58 |
self.idx = idx
|
| 59 |
self.source = source
|
|
@@ -77,8 +76,6 @@ def read_examples_no_bracket(filename, is_function_test):
|
|
| 77 |
break
|
| 78 |
line = line.strip()
|
| 79 |
js = json.loads(line)
|
| 80 |
-
if idx > 1000:
|
| 81 |
-
break
|
| 82 |
if js["Stmt"].strip()[0] == "}":
|
| 83 |
continue
|
| 84 |
if js["Value"].strip().lower() == "nothing" and '#' in js['FIR']:
|
|
@@ -119,11 +116,6 @@ def read_examples_no_bracket(filename, is_function_test):
|
|
| 119 |
mod = ""
|
| 120 |
if "Module" in js.keys():
|
| 121 |
mod = js["Module"]
|
| 122 |
-
# propos = ' '.join(js['pp'])
|
| 123 |
-
# propos = ' '.join(propos.strip().split(','))
|
| 124 |
-
# print(code)
|
| 125 |
-
# print(nl)
|
| 126 |
-
# print(pro)
|
| 127 |
examples.append(
|
| 128 |
Example(
|
| 129 |
idx=idx,
|
|
@@ -152,8 +144,6 @@ def read_examples(filename, is_function_test):
|
|
| 152 |
break
|
| 153 |
line = line.strip()
|
| 154 |
js = json.loads(line)
|
| 155 |
-
if idx > 3000:
|
| 156 |
-
break
|
| 157 |
if 'idx' not in js:
|
| 158 |
js['idx'] = idx
|
| 159 |
code = ' '.join(js['FIR_token']).replace('\n', ' ')
|
|
@@ -188,11 +178,6 @@ def read_examples(filename, is_function_test):
|
|
| 188 |
mod = ""
|
| 189 |
if "Module" in js.keys():
|
| 190 |
mod = js["Module"]
|
| 191 |
-
# propos = ' '.join(js['pp'])
|
| 192 |
-
# propos = ' '.join(propos.strip().split(','))
|
| 193 |
-
# print(code)
|
| 194 |
-
# print(nl)
|
| 195 |
-
# print(pro)
|
| 196 |
examples.append(
|
| 197 |
Example(
|
| 198 |
idx=idx,
|
|
@@ -233,7 +218,7 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
|
|
| 233 |
# source
|
| 234 |
func_tokens = tokenizer.tokenize(example.funcname)
|
| 235 |
source_tokens = tokenizer.tokenize(
|
| 236 |
-
example.source)
|
| 237 |
pro_tokens = tokenizer.tokenize(example.property)
|
| 238 |
vec_tokens = example.vec
|
| 239 |
source_tokens = [tokenizer.cls_token, "<encoder-decoder>", tokenizer.sep_token, "<mask0>"] + func_tokens + [tokenizer.cls_token] + \
|
|
@@ -243,8 +228,6 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
|
|
| 243 |
padding_length = args.max_source_length - len(source_ids)
|
| 244 |
source_ids += [tokenizer.pad_token_id] * padding_length
|
| 245 |
|
| 246 |
-
# target
|
| 247 |
-
# if stage=="test":
|
| 248 |
target_tokens = tokenizer.tokenize(example.target)
|
| 249 |
exist = [example.exist]
|
| 250 |
target_tokens = [tokenizer.cls_token, "<mask0>"] + \
|
|
@@ -252,13 +235,6 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
|
|
| 252 |
target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
|
| 253 |
padding_length = args.max_target_length - len(target_ids)
|
| 254 |
target_ids += [tokenizer.pad_token_id] * padding_length
|
| 255 |
-
# else:
|
| 256 |
-
# target_tokens = tokenizer.tokenize(example.target)
|
| 257 |
-
# exist_tokens = tokenizer.tokenize(example.exist)
|
| 258 |
-
# target_tokens = ["<mask0>"] + exist_tokens + [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
|
| 259 |
-
# target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
|
| 260 |
-
# padding_length = args.max_target_length - len(target_ids)
|
| 261 |
-
# target_ids += [tokenizer.pad_token_id] * padding_length
|
| 262 |
|
| 263 |
features.append(
|
| 264 |
InputFeatures(
|
|
@@ -470,14 +446,7 @@ def vega_train_main():
|
|
| 470 |
total_eval_all = len(eval_examples_all)
|
| 471 |
patience, best_acc, losses, dev_dataset = 0, 0, [], {}
|
| 472 |
for epoch in tqdm(range(args.num_train_epochs)):
|
| 473 |
-
# print(args.num_train_epochs)
|
| 474 |
-
|
| 475 |
for idx, batch in enumerate(train_dataloader):
|
| 476 |
-
# print("##########Debug################")
|
| 477 |
-
# print(idx)
|
| 478 |
-
# print("###############Debug###########")
|
| 479 |
-
# if idx > 100:
|
| 480 |
-
# break
|
| 481 |
batch = tuple(t.to(device) for t in batch)
|
| 482 |
source_ids, exist, target_ids = batch
|
| 483 |
loss, _, _, mse_loss, ce_loss = model(
|
|
@@ -572,9 +541,7 @@ def vega_train_main():
|
|
| 572 |
# convert ids to text
|
| 573 |
for pred, predicate in zip(preds, predicates):
|
| 574 |
t = pred[0].cpu().numpy()
|
| 575 |
-
#p = predicate[0].cpu().numpy()
|
| 576 |
p = predicate.float().item()
|
| 577 |
-
#print("ZM_Debug -- ppp: " + str(p))
|
| 578 |
t = list(t)
|
| 579 |
#p = list(p)
|
| 580 |
tem_i = 0
|
|
@@ -608,7 +575,6 @@ def vega_train_main():
|
|
| 608 |
cnt_iteration += 1
|
| 609 |
pred = ref[0].strip()
|
| 610 |
predicate = ref[1]
|
| 611 |
-
#print("ZM_Debug -- predicate: " + str(predicate))
|
| 612 |
if gold.property.strip().lower() != "nothing":
|
| 613 |
predicate = 1.0
|
| 614 |
else:
|
|
@@ -626,7 +592,6 @@ def vega_train_main():
|
|
| 626 |
|
| 627 |
|
| 628 |
if pred == gt_pred and int(round(predicate)) == int(round(gt_predicate)):
|
| 629 |
-
#print("Total correct, Inside this place")
|
| 630 |
EM = EM + 1.0
|
| 631 |
EM_V = EM_V + 1.0
|
| 632 |
EM_P = EM_P + 1.0
|
|
@@ -646,43 +611,16 @@ def vega_train_main():
|
|
| 646 |
|
| 647 |
model_predicate.append(predicate)
|
| 648 |
groundtruth_predicate.append(gt_predicate)
|
| 649 |
-
# if len(pred.split(tokenizer.cls_token)) >= 2:
|
| 650 |
-
# if pred.split(tokenizer.cls_token)[0].strip() == gt_pred.split(tokenizer.cls_token)[0].strip():
|
| 651 |
-
# EM_P += 1
|
| 652 |
-
# if pred.split(tokenizer.cls_token)[1].strip() == gt_pred.split(tokenizer.cls_token)[1].strip():
|
| 653 |
-
# EM_V += 1
|
| 654 |
-
# MAE_P = mean_absolute_error(
|
| 655 |
-
# np.array(model_predicate), np.array(groundtruth_predicate))
|
| 656 |
-
# MSE_P = mean_squared_error(
|
| 657 |
-
# np.array(model_predicate), np.array(groundtruth_predicate))
|
| 658 |
-
# RMSE_P = np.sqrt(MSE_P)
|
| 659 |
dev_acc = round((100*EM/total), 2)
|
| 660 |
dev_acc_v = round((100*EM_V/total), 2)
|
| 661 |
dev_acc_p = round((100*EM_P/total), 2)
|
| 662 |
logger.info(" %s = %s " % ("Current Acc", str(dev_acc)))
|
| 663 |
-
#logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
|
| 664 |
logger.info(" "+"*"*20)
|
| 665 |
logger.info(" %s = %s " % ("Current Acc V", str(dev_acc_v)))
|
| 666 |
-
#logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
|
| 667 |
logger.info(" "+"*"*20)
|
| 668 |
logger.info(" %s = %s " % ("Current Acc P", str(dev_acc_p)))
|
| 669 |
-
#logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
|
| 670 |
logger.info(" "+"*"*20)
|
| 671 |
-
# logger.info(" %s = %s " %
|
| 672 |
-
# ("Current MAE P", str(round(MAE_P, 2))))
|
| 673 |
-
# #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
|
| 674 |
-
# logger.info(" "+"*"*20)
|
| 675 |
-
# logger.info(" %s = %s " %
|
| 676 |
-
# ("Current MSE P", str(round(MSE_P, 2))))
|
| 677 |
-
# #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
|
| 678 |
-
# logger.info(" "+"*"*20)
|
| 679 |
-
# logger.info(" %s = %s " %
|
| 680 |
-
# ("Current RMSE P", str(round(RMSE_P, 2))))
|
| 681 |
-
# #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
|
| 682 |
-
# logger.info(" "+"*"*20)
|
| 683 |
if dev_acc > best_acc:
|
| 684 |
-
#logger.info(" Best acc:%s",dev_acc)
|
| 685 |
-
#logger.info(" "+"*"*20)
|
| 686 |
best_acc = dev_acc
|
| 687 |
# Save best checkpoint for best bleu
|
| 688 |
output_dir = os.path.join(
|
|
@@ -694,15 +632,6 @@ def vega_train_main():
|
|
| 694 |
output_model_file = os.path.join(
|
| 695 |
output_dir, "pytorch_model.bin")
|
| 696 |
torch.save(model_to_save.state_dict(), output_model_file)
|
| 697 |
-
# with open(args.output_dir+"/p_valid_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
|
| 698 |
-
# writer = csv.writer(fcsv2)
|
| 699 |
-
# for wl in p_wrong_list:
|
| 700 |
-
# writer.writerow(wl)
|
| 701 |
-
# with open(args.output_dir+"/v_valid_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
|
| 702 |
-
# writer = csv.writer(fcsv2)
|
| 703 |
-
# for wl in v_wrong_list:
|
| 704 |
-
# writer.writerow(wl)
|
| 705 |
-
#print("ZM Debug--cnt_err_v: " + str(cnt_v))
|
| 706 |
logger.info(" Best acc:%s", best_acc)
|
| 707 |
logger.info(" " + "*" * 20)
|
| 708 |
|
|
@@ -753,9 +682,7 @@ def vega_train_main():
|
|
| 753 |
# convert ids to text
|
| 754 |
for pred, predicate in zip(preds, predicates):
|
| 755 |
t = pred[0].cpu().numpy()
|
| 756 |
-
#p = predicate[0].cpu().numpy()
|
| 757 |
p = predicate.float().item()
|
| 758 |
-
#print("ZM_Debug -- ppp: " + str(p))
|
| 759 |
t = list(t)
|
| 760 |
tem_i = 0
|
| 761 |
if 0 in t:
|
|
@@ -802,7 +729,6 @@ def vega_train_main():
|
|
| 802 |
predicate = 0.0
|
| 803 |
if 1 in gold.vec[-97:]:
|
| 804 |
predicate = 1.0
|
| 805 |
-
#my_cls = tokenizer.decode([tokenizer.cls_token_id],clean_up_tokenization_spaces=False)
|
| 806 |
gt_pred = gold.target.strip()
|
| 807 |
gt_predicate = gold.exist
|
| 808 |
is_re = False
|
|
@@ -840,30 +766,13 @@ def vega_train_main():
|
|
| 840 |
|
| 841 |
if pred == gt_pred:
|
| 842 |
EM_V += 1
|
| 843 |
-
# else:
|
| 844 |
-
# print("TEST Wrong pred:", pred, " gt_pred:", gt_pred)
|
| 845 |
if round(predicate) == gt_predicate:
|
| 846 |
EM_P += 1
|
| 847 |
model_predicate.append(predicate)
|
| 848 |
groundtruth_predicate.append(gt_predicate)
|
| 849 |
-
|
| 850 |
-
# MAE_P = mean_absolute_error(
|
| 851 |
-
# np.array(model_predicate), np.array(groundtruth_predicate))
|
| 852 |
-
# MSE_P = mean_squared_error(
|
| 853 |
-
# np.array(model_predicate), np.array(groundtruth_predicate))
|
| 854 |
-
# RMSE_P = np.sqrt(MSE_P)
|
| 855 |
-
|
| 856 |
dev_acc = round((100 * EM / total), 2)
|
| 857 |
dev_acc_v = round((100 * EM_V / total), 2)
|
| 858 |
dev_acc_p = round((100 * EM_P / total), 2)
|
| 859 |
-
# logger.info(" %s = %s " % ("Test Acc", str(dev_acc)))
|
| 860 |
-
# logger.info(" %s = %s " % ("Test Acc V", str(dev_acc_v)))
|
| 861 |
-
# logger.info(" %s = %s " % ("Test Acc P", str(dev_acc_p)))
|
| 862 |
-
# logger.info(" %s = %s "%("Test Edit sim",str(round(edit_sim/total, 2))))
|
| 863 |
-
# logger.info(" %s = %s " % ("Test MAE P", str(round(MAE_P, 2))))
|
| 864 |
-
# logger.info(" %s = %s " % ("Test MSE P", str(round(MSE_P, 2))))
|
| 865 |
-
# logger.info(" %s = %s " % ("Test RMSE P", str(round(RMSE_P, 2))))
|
| 866 |
-
# logger.info(" " + "*" * 20)
|
| 867 |
predictions = []
|
| 868 |
|
| 869 |
|
|
@@ -897,15 +806,6 @@ def vega_train_main():
|
|
| 897 |
json.dump(dic, f2)
|
| 898 |
f2.write('\n')
|
| 899 |
|
| 900 |
-
|
| 901 |
-
# with open(args.output_dir+"/p_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
|
| 902 |
-
# writer = csv.writer(fcsv2)
|
| 903 |
-
# for wl in p_wrong_list:
|
| 904 |
-
# writer.writerow(wl)
|
| 905 |
-
# with open(args.output_dir+"/v_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
|
| 906 |
-
# writer = csv.writer(fcsv2)
|
| 907 |
-
# for wl in v_wrong_list:
|
| 908 |
-
# writer.writerow(wl)
|
| 909 |
|
| 910 |
|
| 911 |
if __name__ == "__main__":
|
|
|
|
| 53 |
vec,
|
| 54 |
exist,
|
| 55 |
module
|
|
|
|
| 56 |
):
|
| 57 |
self.idx = idx
|
| 58 |
self.source = source
|
|
|
|
| 76 |
break
|
| 77 |
line = line.strip()
|
| 78 |
js = json.loads(line)
|
|
|
|
|
|
|
| 79 |
if js["Stmt"].strip()[0] == "}":
|
| 80 |
continue
|
| 81 |
if js["Value"].strip().lower() == "nothing" and '#' in js['FIR']:
|
|
|
|
| 116 |
mod = ""
|
| 117 |
if "Module" in js.keys():
|
| 118 |
mod = js["Module"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
examples.append(
|
| 120 |
Example(
|
| 121 |
idx=idx,
|
|
|
|
| 144 |
break
|
| 145 |
line = line.strip()
|
| 146 |
js = json.loads(line)
|
|
|
|
|
|
|
| 147 |
if 'idx' not in js:
|
| 148 |
js['idx'] = idx
|
| 149 |
code = ' '.join(js['FIR_token']).replace('\n', ' ')
|
|
|
|
| 178 |
mod = ""
|
| 179 |
if "Module" in js.keys():
|
| 180 |
mod = js["Module"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
examples.append(
|
| 182 |
Example(
|
| 183 |
idx=idx,
|
|
|
|
| 218 |
# source
|
| 219 |
func_tokens = tokenizer.tokenize(example.funcname)
|
| 220 |
source_tokens = tokenizer.tokenize(
|
| 221 |
+
example.source)
|
| 222 |
pro_tokens = tokenizer.tokenize(example.property)
|
| 223 |
vec_tokens = example.vec
|
| 224 |
source_tokens = [tokenizer.cls_token, "<encoder-decoder>", tokenizer.sep_token, "<mask0>"] + func_tokens + [tokenizer.cls_token] + \
|
|
|
|
| 228 |
padding_length = args.max_source_length - len(source_ids)
|
| 229 |
source_ids += [tokenizer.pad_token_id] * padding_length
|
| 230 |
|
|
|
|
|
|
|
| 231 |
target_tokens = tokenizer.tokenize(example.target)
|
| 232 |
exist = [example.exist]
|
| 233 |
target_tokens = [tokenizer.cls_token, "<mask0>"] + \
|
|
|
|
| 235 |
target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
|
| 236 |
padding_length = args.max_target_length - len(target_ids)
|
| 237 |
target_ids += [tokenizer.pad_token_id] * padding_length
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
features.append(
|
| 240 |
InputFeatures(
|
|
|
|
| 446 |
total_eval_all = len(eval_examples_all)
|
| 447 |
patience, best_acc, losses, dev_dataset = 0, 0, [], {}
|
| 448 |
for epoch in tqdm(range(args.num_train_epochs)):
|
|
|
|
|
|
|
| 449 |
for idx, batch in enumerate(train_dataloader):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
batch = tuple(t.to(device) for t in batch)
|
| 451 |
source_ids, exist, target_ids = batch
|
| 452 |
loss, _, _, mse_loss, ce_loss = model(
|
|
|
|
| 541 |
# convert ids to text
|
| 542 |
for pred, predicate in zip(preds, predicates):
|
| 543 |
t = pred[0].cpu().numpy()
|
|
|
|
| 544 |
p = predicate.float().item()
|
|
|
|
| 545 |
t = list(t)
|
| 546 |
#p = list(p)
|
| 547 |
tem_i = 0
|
|
|
|
| 575 |
cnt_iteration += 1
|
| 576 |
pred = ref[0].strip()
|
| 577 |
predicate = ref[1]
|
|
|
|
| 578 |
if gold.property.strip().lower() != "nothing":
|
| 579 |
predicate = 1.0
|
| 580 |
else:
|
|
|
|
| 592 |
|
| 593 |
|
| 594 |
if pred == gt_pred and int(round(predicate)) == int(round(gt_predicate)):
|
|
|
|
| 595 |
EM = EM + 1.0
|
| 596 |
EM_V = EM_V + 1.0
|
| 597 |
EM_P = EM_P + 1.0
|
|
|
|
| 611 |
|
| 612 |
model_predicate.append(predicate)
|
| 613 |
groundtruth_predicate.append(gt_predicate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
dev_acc = round((100*EM/total), 2)
|
| 615 |
dev_acc_v = round((100*EM_V/total), 2)
|
| 616 |
dev_acc_p = round((100*EM_P/total), 2)
|
| 617 |
logger.info(" %s = %s " % ("Current Acc", str(dev_acc)))
|
|
|
|
| 618 |
logger.info(" "+"*"*20)
|
| 619 |
logger.info(" %s = %s " % ("Current Acc V", str(dev_acc_v)))
|
|
|
|
| 620 |
logger.info(" "+"*"*20)
|
| 621 |
logger.info(" %s = %s " % ("Current Acc P", str(dev_acc_p)))
|
|
|
|
| 622 |
logger.info(" "+"*"*20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
if dev_acc > best_acc:
|
|
|
|
|
|
|
| 624 |
best_acc = dev_acc
|
| 625 |
# Save best checkpoint for best bleu
|
| 626 |
output_dir = os.path.join(
|
|
|
|
| 632 |
output_model_file = os.path.join(
|
| 633 |
output_dir, "pytorch_model.bin")
|
| 634 |
torch.save(model_to_save.state_dict(), output_model_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
logger.info(" Best acc:%s", best_acc)
|
| 636 |
logger.info(" " + "*" * 20)
|
| 637 |
|
|
|
|
| 682 |
# convert ids to text
|
| 683 |
for pred, predicate in zip(preds, predicates):
|
| 684 |
t = pred[0].cpu().numpy()
|
|
|
|
| 685 |
p = predicate.float().item()
|
|
|
|
| 686 |
t = list(t)
|
| 687 |
tem_i = 0
|
| 688 |
if 0 in t:
|
|
|
|
| 729 |
predicate = 0.0
|
| 730 |
if 1 in gold.vec[-97:]:
|
| 731 |
predicate = 1.0
|
|
|
|
| 732 |
gt_pred = gold.target.strip()
|
| 733 |
gt_predicate = gold.exist
|
| 734 |
is_re = False
|
|
|
|
| 766 |
|
| 767 |
if pred == gt_pred:
|
| 768 |
EM_V += 1
|
|
|
|
|
|
|
| 769 |
if round(predicate) == gt_predicate:
|
| 770 |
EM_P += 1
|
| 771 |
model_predicate.append(predicate)
|
| 772 |
groundtruth_predicate.append(gt_predicate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
dev_acc = round((100 * EM / total), 2)
|
| 774 |
dev_acc_v = round((100 * EM_V / total), 2)
|
| 775 |
dev_acc_p = round((100 * EM_P / total), 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
predictions = []
|
| 777 |
|
| 778 |
|
|
|
|
| 806 |
json.dump(dic, f2)
|
| 807 |
f2.write('\n')
|
| 808 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
|
| 810 |
|
| 811 |
if __name__ == "__main__":
|
run_fine_tuning.sh
CHANGED
|
@@ -10,6 +10,6 @@ python ./Scripts/UnixCoder/run_one_model.py \
|
|
| 10 |
--train_batch_size 64 \
|
| 11 |
--eval_batch_size 48 \
|
| 12 |
--learning_rate 6e-5 \
|
| 13 |
-
--num_train_epochs
|
| 14 |
--mse_loss_weight 0.9 \
|
| 15 |
--ce_loss_weight 0.1
|
|
|
|
| 10 |
--train_batch_size 64 \
|
| 11 |
--eval_batch_size 48 \
|
| 12 |
--learning_rate 6e-5 \
|
| 13 |
+
--num_train_epochs 50 \
|
| 14 |
--mse_loss_weight 0.9 \
|
| 15 |
--ce_loss_weight 0.1
|