Spaces:
Runtime error
Runtime error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| # -*- coding:utf-8 -*- | |
| import os | |
| import logging | |
| from . import bleu | |
| from . import weighted_ngram_match | |
| from . import syntax_match | |
| from . import dataflow_match | |
def calc_codebleu(predictions, references, lang, tokenizer=None, params='0.25,0.25,0.25,0.25', kw_dir=".", langso_dir="."):
    """Compute the CodeBLEU score and its four component scores.

    The final score is a weighted sum of the n-gram match (BLEU), the
    keyword-weighted n-gram match, the syntax match and the dataflow match.

    Args:
        predictions (list[str]): one predicted code string per sample.
        references (list[str] | list[list[str]]): for each sample, either a
            single reference string or a collection of reference strings.
            Must have the same length as ``predictions``.
        lang (str): ['java','js','c_sharp','php','go','python','ruby'].
            Also selects the keyword file ``<kw_dir>/keywords/<lang>.txt``.
        tokenizer (callable, optional): maps a code string to a list of
            tokens. Defaults to whitespace splitting (``str.split``).
        params (str, optional): four comma-separated weights, in the order
            ngram, weighted ngram, syntax, dataflow.
            Defaults to '0.25,0.25,0.25,0.25'.
        kw_dir (str, optional): directory containing the ``keywords``
            folder. Defaults to ".".
        langso_dir (str, optional): forwarded unchanged to
            ``syntax_match.corpus_syntax_match`` and
            ``dataflow_match.corpus_dataflow_match`` (presumably the location
            of the compiled parser library -- confirm against those modules).
            Defaults to ".".

    Returns:
        dict: keys 'CodeBLEU', 'ngram_match_score',
        'weighted_ngram_match_score', 'syntax_match_score' and
        'dataflow_match_score'.

    Raises:
        ValueError: if ``params`` does not hold exactly four numbers, or if
            ``predictions`` and ``references`` differ in length.
    """
    weights = [float(x) for x in params.split(',')]
    if len(weights) != 4:
        raise ValueError(
            "params must contain exactly 4 comma-separated weights, got "
            "{0}: {1!r}".format(len(weights), params))
    alpha, beta, gamma, theta = weights

    # Preprocess inputs: every sample ends up with a list of stripped
    # reference strings, even when a bare string was supplied.
    references = [[ref.strip()] if isinstance(ref, str)
                  else [x.strip() for x in ref]
                  for ref in references]
    hypothesis = [x.strip() for x in predictions]

    if len(references) != len(hypothesis):
        raise ValueError(
            "references and predictions must have the same length, got "
            "{0} and {1}".format(len(references), len(hypothesis)))

    # Calculate ngram match (BLEU).
    if tokenizer is None:
        tokenizer = str.split
    tokenized_hyps = [tokenizer(x) for x in hypothesis]
    tokenized_refs = [[tokenizer(x) for x in reference]
                      for reference in references]

    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # Calculate weighted ngram match. The keyword file is read inside a
    # context manager so the handle is always closed; a set gives O(1)
    # membership tests while producing the same weights as a list.
    with open(kw_dir + '/keywords/' + lang + '.txt', 'r', encoding='utf-8') as kw_file:
        keywords = {line.strip() for line in kw_file}

    def make_weights(reference_tokens, key_word_list):
        # Language keywords get full weight, every other token 0.2.
        return {token: 1 if token in key_word_list else 0.2
                for token in reference_tokens}

    tokenized_refs_with_weights = [
        [[reference_tokens, make_weights(reference_tokens, keywords)]
         for reference_tokens in reference]
        for reference in tokenized_refs]

    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(
        tokenized_refs_with_weights, tokenized_hyps)

    # Calculate syntax match.
    syntax_match_score = syntax_match.corpus_syntax_match(
        references, hypothesis, lang, langso_dir)

    # Calculate dataflow match.
    dataflow_match_score = dataflow_match.corpus_dataflow_match(
        references, hypothesis, lang, langso_dir)

    # A falsy dataflow score (e.g. 0) is replaced by 1 in the weighted sum
    # only; the raw value is still reported in the returned dict.
    code_bleu_score = (alpha * ngram_match_score
                       + beta * weighted_ngram_match_score
                       + gamma * syntax_match_score
                       + theta * (dataflow_match_score or 1))

    return {
        'CodeBLEU': code_bleu_score,
        'ngram_match_score': ngram_match_score,
        'weighted_ngram_match_score': weighted_ngram_match_score,
        'syntax_match_score': syntax_match_score,
        'dataflow_match_score': dataflow_match_score
    }