Spaces:
Running
Running
import gradio as gr
import requests
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import torch
# Load the CodeReviewer tokenizer and attach the vocabulary ids of its
# special tokens as attributes, as the encoding helpers below expect.
tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")

_vocab = tokenizer.get_vocab()  # fetched once instead of per lookup

# <e99> .. <e0>, highest index first.
tokenizer.special_dict = {f"<e{i}>": _vocab[f"<e{i}>"] for i in reversed(range(100))}

# attribute name -> special token whose id it stores
_SPECIAL_IDS = {
    "mask_id": "<mask>",
    "bos_id": "<s>",
    "pad_id": "<pad>",
    "eos_id": "</s>",
    "msg_id": "<msg>",
    "keep_id": "<keep>",
    "add_id": "<add>",
    "del_id": "<del>",
    "start_id": "<start>",
    "end_id": "<end>",
}
for _attr, _token in _SPECIAL_IDS.items():
    setattr(tokenizer, _attr, _vocab[_token])

# Inference-only model.
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/codereviewer")
model.eval()
MAX_SOURCE_LENGTH = 512  # fixed encoder input length expected by CodeReviewer


def pad_assert(tokenizer, source_ids):
    """Truncate, wrap with <s>/<\u2215s> markers, and right-pad ids to MAX_SOURCE_LENGTH.

    Args:
        tokenizer: object exposing integer ``bos_id``, ``eos_id`` and
            ``pad_id`` attributes (set up at module load).
        source_ids: list of token ids without special tokens.

    Returns:
        A list of exactly MAX_SOURCE_LENGTH token ids:
        ``[bos] + ids + [eos] + padding``.

    Raises:
        ValueError: if the result is not MAX_SOURCE_LENGTH long (should be
            impossible; guards against future logic changes).
    """
    # Reserve two slots for the bos/eos wrappers.
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    if len(source_ids) != MAX_SOURCE_LENGTH:
        # `assert` would be stripped under `python -O`; validate explicitly.
        raise ValueError("Not equal length.")
    return source_ids
def encode_diff(tokenizer, diff, msg, source):
    """Encode one diff hunk plus commit message and file source as model input ids.

    Args:
        tokenizer: tokenizer with the special-id attributes set at module load.
        diff: one hunk, starting with its "@@ ... @@" header line.
        msg: the commit message.
        source: full text of the file being reviewed.

    Returns:
        A list of exactly MAX_SOURCE_LENGTH token ids (see pad_assert).
    """
    # Drop the leading "@@ ... @@" header, keep only non-blank lines.
    hunk_lines = [ln for ln in diff.split("\n")[1:] if len(ln.strip()) > 0]

    # The first character of a diff line selects its marker token;
    # anything that is not "+" or "-" counts as unchanged context.
    marker_for = {"+": "<add>", "-": "<del>"}

    pieces = ["<s>", source, "</s>", "<msg>", msg]
    for ln in hunk_lines:
        pieces.append(marker_for.get(ln[0], "<keep>"))
        pieces.append(ln[1:].strip())
    inputstr = "".join(pieces)

    # [1:-1] drops the bos/eos the tokenizer adds itself, so pad_assert can
    # re-attach them at fixed positions after truncation.
    source_ids = tokenizer.encode(inputstr, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    return pad_assert(tokenizer, source_ids)
class FileDiffs(object):
    """Parsed representation of one file's section of a unified git diff."""

    def __init__(self, diff_string):
        """Parse one file section (the text following a "diff --git" marker).

        Args:
            diff_string: header lines (" a/<path> b/<path>", index, ---, +++)
                followed by one or more "@@"-headed hunks.
        """
        diff_array = diff_string.split("\n")
        # First line looks like " a/<path> b/<path>".
        self.file_name = diff_array[0]
        # Path between the "a/" and "b/" prefixes. strip() removes the
        # separating space that rsplit leaves attached — without it the
        # trailing space corrupts URLs built from this path.
        self.file_path = self.file_name.split("a/", 1)[1].rsplit("b/", 1)[0].strip()
        # Collect hunks: each "@@" header starts a new entry, subsequent
        # lines are appended to the current one.
        # NOTE(review): assumes exactly 4 header lines before the first hunk;
        # new/deleted/renamed files have different header counts — verify.
        self.diffs = list()
        for line in diff_array[4:]:
            if line.startswith("@@"):
                self.diffs.append(str())
            self.diffs[-1] += "\n" + line
def review_commit(user, repository, commit):
    """Generate model review comments for every hunk of a GitHub commit.

    Args:
        user: GitHub user or organization name.
        repository: repository name.
        commit: commit SHA (or any ref the GitHub API accepts).

    Returns:
        A string listing each changed file's path, its hunks, and one
        generated comment per hunk.
    """
    # Commit message from the GitHub REST API (default JSON media type).
    commit_metadata = requests.get(F"https://api.github.com/repos/{user}/{repository}/commits/{commit}").json()
    msg = commit_metadata["commit"]["message"]

    # Same endpoint with the diff media type returns the raw unified diff.
    diff_data = requests.get(
        F"https://api.github.com/repos/{user}/{repository}/commits/{commit}",
        headers={"Accept": "application/vnd.github.diff"},
    )
    code_diff = diff_data.text

    # Split the diff into per-file sections; the first split element is
    # empty (text before the first "diff --git") and is skipped.
    files_diffs = [FileDiffs(part) for part in code_diff.split("diff --git") if len(part) > 0]

    output = ""
    for fd in files_diffs:
        output += F"File:{fd.file_path}\n"
        # NOTE(review): "^" is sent literally in this URL — raw.githubusercontent
        # does not understand git revision syntax, so presumably the intent was
        # the pre-change file contents; verify this URL actually resolves.
        source = requests.get(F"https://raw.githubusercontent.com/{user}/{repository}/^{commit}/{fd.file_path}").text
        for diff in fd.diffs:
            inputs = torch.tensor([encode_diff(tokenizer, diff, msg, source)], dtype=torch.long).to("cpu")
            inputs_mask = inputs.ne(tokenizer.pad_id)
            # Inference only: no_grad avoids building autograd state.
            with torch.no_grad():
                preds = model.generate(
                    inputs,
                    attention_mask=inputs_mask,
                    use_cache=True,
                    num_beams=5,
                    early_stopping=True,
                    max_length=100,
                    num_return_sequences=2,
                )
            preds = list(preds.cpu().numpy())
            # "ids" (not "id") avoids shadowing the builtin. The leading two
            # ids are skipped before decoding — presumably decoder pad/bos
            # tokens; verify against the upstream CodeReviewer decoding code.
            pred_nls = [
                tokenizer.decode(ids[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in preds
            ]
            output += diff + "\n#######\nComment:\n#######\n" + pred_nls[0] + "\n#######\n"
    return output
# Gradio UI: three text fields (user, repository, commit SHA) -> one text output.
_text_inputs = ["text"] * 3
iface = gr.Interface(fn=review_commit, inputs=_text_inputs, outputs="text")
iface.launch()