Not-Grim-Refer committed on
Commit 54f7dc9 · 1 Parent(s): 78cc03b

Update app.py

Files changed (1)
  1. app.py +40 -74
app.py CHANGED
@@ -1,81 +1,47 @@
- # Import necessary modules
  import gradio as gr
  import requests
- from transformers import AutoTokenizer, T5ForConditionalGeneration, T5Config
  import torch

- # Define maximum sequence length
  MAX_SOURCE_LENGTH = 512

- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")
- tokenizer.add_special_tokens({'additional_special_tokens': ['<e99>', '<e98>', ..., '<e0>', '<msg>', '<add>', '<del>', '<keep>']})
- config = T5Config.from_pretrained("microsoft/codereviewer")
- model = T5ForConditionalGeneration.from_pretrained("microsoft/codereviewer", config=config)
- model.eval()
-
- def pad_to_max_length(source_ids):
-     source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
-     source_ids = [tokenizer.bos_token_id] + source_ids + [tokenizer.eos_token_id]
-     pad_len = MAX_SOURCE_LENGTH - len(source_ids)
-     source_ids += [tokenizer.pad_token_id] * pad_len
-     assert len(source_ids) == MAX_SOURCE_LENGTH
-     return source_ids
-
- def encode_diff(diff, msg, source):
-     lines = diff.split('\n')[1:]
-     lines = [line for line in lines if line.strip()]
-
-     labels = [0 if line[0] == '-' else 1 if line[0] == '+' else 2 for line in lines]
-     lines = [line[1:].strip() for line in lines]
-
-     tokens = [tokenizer.bos_token] + tokenizer.tokenize(source) + [tokenizer.eos_token]
-     tokens += tokenizer.tokenize(msg)
-     for label, line in zip(labels, lines):
-         if label == 1:
-             tokens += ['<add>'] + tokenizer.tokenize(line)
-         elif label == 0:
-             tokens += ['<del>'] + tokenizer.tokenize(line)
-         else:
-             tokens += ['<keep>'] + tokenizer.tokenize(line)
-
-     return pad_to_max_length(tokenizer.convert_tokens_to_ids(tokens))
-
- def get_diffs_and_msg(user, repo, commit):
-     commit_data = requests.get(f'https://api.github.com/repos/{user}/{repo}/commits/{commit}').json()
-     msg = commit_data['commit']['message']
-     diff_response = requests.get(f'https://api.github.com/repos/{user}/{repo}/commits/{commit}',
-                                  headers={'Accept': 'application/vnd.github.diff'})
-     diffs = diff_response.text
-     return diffs, msg
-
- def generate_comments(user, repo, commit):
-
-     diffs, msg = get_diffs_and_msg(user, repo, commit)
-
-     file_diffs = []
-     for diff in diffs.split('diff --git')[1:]:
-         lines = diff.split('\n')
-         file_name = lines[0].split(' a/')[1].split(' b/')[0]
-         file_diffs.append({'name': file_name, 'diff': diff})
-
-     output = ''
-     for fd in file_diffs:
-         source = requests.get(f'https://raw.githubusercontent.com/{user}/{repo}/{commit}/{fd["name"]}').text
-         encoded = encode_diff(fd['diff'], msg, source)
-         input_ids = torch.tensor([encoded]).to(model.device)
-         attention_mask = input_ids.ne(tokenizer.pad_token_id).to(model.device)
-
-         output_sequences = model.generate(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             max_length=100,
-             num_beams=5,
-             num_return_sequences=2,
-             early_stopping=True
-         )
-
-         comments = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_sequences]
-         output += f'File: {fd["name"]}\n{fd["diff"]}\n\nComments:\n{comments[0]}\n\n'
-
-     return output
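
For reference, the removed encode_diff flattens a file's diff into a single token sequence: the current file source between BOS/EOS, then the commit message, then every diff line prefixed with <add>, <del>, or <keep>. A minimal worked example of that layout, using the functions and tokenizer defined in the old app.py above; the toy diff, message, and source strings are made up for illustration and are not part of this commit:

# Illustration only: relies on encode_diff and tokenizer from the old app.py.
toy_diff = "@@ -1,2 +1,2 @@\n-x = 1\n+x = 2\n y = x"
toy_msg = "fix constant"
toy_source = "x = 2\ny = x"

ids = encode_diff(toy_diff, toy_msg, toy_source)   # padded/truncated to MAX_SOURCE_LENGTH

# Conceptual token layout before padding:
#   <s> ...tokens of toy_source... </s> ...tokens of toy_msg...
#   <del> x = 1   <add> x = 2   <keep> y = x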
 
 
  import gradio as gr
  import requests
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, T5Config
  import torch

  MAX_SOURCE_LENGTH = 512

+
+ class ReviewerModel(T5ForConditionalGeneration):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.cls_head = nn.Linear(config.d_model, 2, bias=True)
+         # Fixed typo: config not self.config
+         self.init()
+
+     def init(self):
+         nn.init.xavier_uniform_(self.lm_head.weight)
+         factor = self.config.initializer_factor
+         self.cls_head.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
+         # Fixed exponentiation operator
+         self.cls_head.bias.data.zero_()
+
+     def forward(
+             self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels=None):
+         # Simplified method signature to only include necessary arguments
+
+         if labels is not None:
+             # Added validation check for seq2seq case
+             outputs = super().forward(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 decoder_input_ids=decoder_input_ids,
+                 decoder_attention_mask=decoder_attention_mask,
+                 labels=labels
+             )
+             # Call super forward method with correct arguments
+             return outputs
+
+         # Removed unnecessary conditional logic
+         # Return super() forward directly for generation case
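
The new file ends here in this view. The trailing comments describe a generation path (labels=None) whose return statement is not visible above, and the commit also drops the tokenizer/model loading that the old app.py contained. A minimal sketch of both follows, under two assumptions: that the parent forward is returned directly for generation, as the comments suggest, and that the class is meant to be loaded from the public microsoft/codereviewer checkpoint. This is illustrative only and is not taken from the commit:

# Sketch: the generation-case return the trailing comments suggest would
# end ReviewerModel.forward (assumption, not committed code):
#     return super().forward(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         decoder_input_ids=decoder_input_ids,
#         decoder_attention_mask=decoder_attention_mask,
#     )

# Sketch: one possible way to instantiate the class (assumed usage):
from transformers import AutoTokenizer, T5Config

tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")
config = T5Config.from_pretrained("microsoft/codereviewer")
model = ReviewerModel.from_pretrained("microsoft/codereviewer", config=config)
model.eval()

from_pretrained works on the subclass because it is inherited from PreTrainedModel; any cls_head weights absent from the checkpoint would simply keep their freshly initialized values, with transformers emitting a warning.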