| |
| |
|
|
| |
|
|
|
|
| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| import torch |
| from torch import nn |
| from torch.nn import init, MarginRankingLoss |
| from torch.optim import Adam |
| from distutils.version import LooseVersion |
| from torch.utils.data import Dataset, DataLoader |
| from torch.autograd import Variable |
| import math |
| from transformers import AutoConfig, AutoModel, AutoTokenizer |
| import nltk |
| import re |
| import torch.optim as optim |
| from tqdm import tqdm |
| from transformers import AutoModelForMaskedLM |
| import torch.nn.functional as F |
| import random |
|
|
|
|
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
|
|
|
|
# Cumulative distribution used to pick a sub-token count when the caller asks
# for a random one. Entries are (sub_token_count, cumulative_probability) and
# are scanned in this order; the first entry whose threshold exceeds the draw
# wins.
_SUB_TOKEN_COUNT_CDF = (
    (2, 0.4363429005892416),
    (1, 0.6672580202327398),
    (4, 0.7476060740459144),
    (3, 0.9618703668504087),
    (6, 0.9701028532809564),
    (7, 0.9729244545819342),
    (8, 0.9739508754144756),
    (5, 0.9994508859743607),
    (9, 0.9997507867114407),
    (10, 0.9999112969650892),
    (11, 0.9999788802297832),
    (0, 0.9999831041838266),
    (12, 0.9999873281378701),
    (22, 0.9999957760459568),
    (14, 1.0000000000000002),
)

_CHECKPOINT_NAME = "microsoft/graphcodebert-base"
_WEIGHTS_PATH = "model_26_2"
_CHUNK_TOKENS = 510  # content tokens per chunk; two more slots hold the markers

# Holds the (tokenizer, model) pair so it is loaded once per process instead
# of once per request.
_MODEL_CACHE = {}


def _load_tokenizer_and_model():
    """Return the (tokenizer, model) pair, loading it on first use only."""
    if "pair" not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_NAME)
        model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT_NAME)
        # The model is never moved off the CPU, so map the checkpoint there;
        # this also lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(_WEIGHTS_PATH, map_location="cpu"))
        model.eval()
        _MODEL_CACHE["pair"] = (tokenizer, model)
    return _MODEL_CACHE["pair"]


def _sample_sub_token_count():
    """Draw a sub-token count from the cumulative table above."""
    draw = random.random()
    for count, cumulative in _SUB_TOKEN_COUNT_CDF:
        if draw < cumulative:
            return count
    return _SUB_TOKEN_COUNT_CDF[-1][0]


def _find_mask_starts(token_ids, mask_id, width):
    """Return the index of the first token of every run of mask tokens.

    Runs are located on the flat (un-chunked) sequence so that a run which
    crosses a chunk boundary is still seen as a single occurrence. A run is
    kept only if ``width`` tokens are available from its start.
    """
    starts = []
    previous = None
    for index, token_id in enumerate(token_ids):
        is_run_start = token_id == mask_id and previous != mask_id
        if is_run_start and index + width <= len(token_ids):
            starts.append(index)
        previous = token_id
    return starts


def _build_chunks(ids, attention):
    """Split the flat sequence into padded 512-wide rows for the model.

    Each row is ``[101] + up to 510 tokens + [102] + zero padding``; the
    attention mask is 1 over the markers and tokens and 0 over the padding.
    """
    # NOTE(review): 101/102/0 are hard-coded ids kept from the original code
    # rather than tokenizer.cls_token_id / sep_token_id / pad_token_id.
    # Confirm against the tokenizer before changing them; the fine-tuned
    # checkpoint may have been trained with exactly these values.
    start_marker = torch.full((1,), fill_value=101)
    end_marker = torch.full((1,), fill_value=102)
    attended = torch.full((1,), fill_value=1)

    id_rows = []
    mask_rows = []
    parts = zip(ids.split(_CHUNK_TOKENS), attention.split(_CHUNK_TOKENS))
    for id_part, mask_part in parts:
        # Pad in one step instead of concatenating one element at a time.
        padding = torch.zeros(_CHUNK_TOKENS - id_part.shape[0], dtype=torch.long)
        id_rows.append(torch.cat([start_marker, id_part, end_marker, padding], dim=-1))
        mask_rows.append(torch.cat([attended, mask_part, attended, padding], dim=-1))
    return torch.stack(id_rows), torch.stack(mask_rows)


def greet(X, ny):
    """Predict a replacement identifier for the "[MASK]" placeholders in X.

    Every "[MASK]" in ``X`` is expanded into a run of mask tokens, the text is
    fed to the masked language model in 512-token chunks, and the scores of
    all occurrences are averaged per sub-token position. For each position the
    first purely alphabetic candidate among the five best-scoring tokens is
    taken; the pieces are concatenated into the returned name.

    Args:
        X: Source text containing one or more "[MASK]" placeholders.
        ny: Number of sub-tokens to predict, as an int or a numeric string.
            0 (or a blank value) samples the count at random.

    Returns:
        The predicted name, or "" when nothing can be predicted: the
        sub-token count is not positive or the text holds no mask.

    Raises:
        ValueError: If ``ny`` is a non-blank value that is not an integer.
    """
    requested = 0 if ny is None or not str(ny).strip() else int(ny)
    num_sub_tokens_label = _sample_sub_token_count() if requested == 0 else requested
    if num_sub_tokens_label <= 0:
        # Zero or negative width would erase the placeholders entirely.
        return ""

    tokenizer, model = _load_tokenizer_and_model()

    # Pad the placeholder with spaces first so it never fuses with the
    # neighbouring characters, then expand it to the requested width.
    text = X.replace("[MASK]", " [MASK] ")
    text = text.replace("[MASK]", " ".join([tokenizer.mask_token] * num_sub_tokens_label))
    encoded = tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
    ids = encoded['input_ids'][0]
    attention = encoded['attention_mask'][0]

    starts = _find_mask_starts(ids.tolist(), tokenizer.mask_token_id, num_sub_tokens_label)
    if not starts:
        return ""

    input_ids, att_mask = _build_chunks(ids, attention)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(input_ids, attention_mask=att_mask)[0]

    pieces = []
    for offset in range(num_sub_tokens_label):
        rows = []
        for start in starts:
            # Flat index -> (chunk row, column); +1 skips the leading marker.
            chunk, column = divmod(start + offset, _CHUNK_TOKENS)
            rows.append(logits[chunk, column + 1])
        averaged = torch.stack(rows).mean(dim=0)
        best_ids = torch.topk(averaged, k=5, dim=0)[1]
        for token_id in best_ids:
            candidate = tokenizer.decode(token_id.item()).strip()
            if all(char.isalpha() for char in candidate):
                pieces.append(candidate)
                break
    return "".join(pieces)
# Page heading and introduction shown above the demo.
title = "Rename a variable in a Java class"
description = """This model is a fine-tuned GraphCodeBERT model fin-tuned to output higher-quality variable names for Java classes. Long classes are handled by the
model. Replace any variable name with a "[MASK]" to get an identifier renaming.
"""
# Sample Java classes in which one variable has been replaced by "[MASK]".
ex = ["""import java.io.*;
public class x {
public static void main(String[] args) {
String f = "file.txt";
BufferedReader [MASK] = null;
String l;
try {
[MASK] = new BufferedReader(new FileReader(f));
while ((l = [MASK].readLine()) != null) {
System.out.println(l);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
if ([MASK] != null) [MASK].close();
} catch (IOException ex) {
ex.printStackTrace();
}
}
}
}""", """import java.net.*;
import java.io.*;

public class s {
public static void main(String[] args) throws IOException {
ServerSocket [MASK] = new ServerSocket(8000);
try {
Socket s = [MASK].accept();
PrintWriter pw = new PrintWriter(s.getOutputStream(), true);
BufferedReader br = new BufferedReader(new InputStreamReader(s.getInputStream()));
String i;
while ((i = br.readLine()) != null) {
pw.println(i);
}
} finally {
if ([MASK] != null) [MASK].close();
}
}
}""", """import java.io.*;
import java.util.*;

public class y {
public static void main(String[] args) {
String [MASK] = "data.csv";
String l = "";
String cvsSplitBy = ",";
try (BufferedReader br = new BufferedReader(new FileReader([MASK]))) {
while ((l = br.readLine()) != null) {
String[] z = l.split(cvsSplitBy);
System.out.println("Values [field-1= " + z[0] + " , field-2=" + z[1] + "]");
}
} catch (IOException e) {
e.printStackTrace();
}
}
}"""]

# Input box for the Java snippet. The title, description and examples are
# options of gr.Interface, not of gr.Textbox, so they are passed there.
textbox = gr.Textbox(
    label="Type Java code snippet:",
    placeholder="replace variable with [MASK]",
    lines=10,
)

gr.Interface(
    fn=greet,
    inputs=[
        textbox,
        gr.Textbox(
            type="text",
            label="Number of tokens in name:",
            placeholder="0 for randomly sampled number of tokens",
        ),
    ],
    outputs="text",
    title=title,
    description=description,
    # One row per sample, one value per input component; "0" asks greet to
    # sample the number of tokens at random.
    examples=[[snippet, "0"] for snippet in ex],
    # Sampled outputs are not reproducible, so do not pre-compute them.
    cache_examples=False,
).launch()
|
|
|
|
| |
|
|
|
|
// Template text, not compilable Java: "[MASK]" stands in for the name of the
// BufferedReader variable. The text matches the first sample in the `ex` list
// of the Python script above.
import java.io.*;
public class x {
    // Opens "file.txt", prints it line by line, and closes the reader in the
    // finally block.
    public static void main(String[] args) {
        String f = "file.txt";
        BufferedReader [MASK] = null;
        String l;
        try {
            [MASK] = new BufferedReader(new FileReader(f));
            while ((l = [MASK].readLine()) != null) {
                System.out.println(l);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                // The reader is still null if opening the file failed.
                if ([MASK] != null) [MASK].close();
            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }
    }
}
|
|
|
|