ejschwartz committed on
Commit
dcd8edc
·
1 Parent(s): 5ac585f

try normalizing

Browse files
Files changed (2) hide show
  1. app.py +8 -2
  2. normalize.py +164 -0
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  import json
8
  import spaces
9
  import torch
 
10
 
11
  from transformers import AutoTokenizer
12
  from modeling_nova import NovaTokenizer, NovaForCausalLM
@@ -28,7 +29,12 @@ examples = json.load(open("humaneval_decompile_nova_6.7b.json", "r"))
28
 
29
 
30
  @spaces.GPU
31
- def predict(type, normalized_asm, _c_source):
 
 
 
 
 
32
 
33
  prompt_before = f"# This is the assembly code with {type} optimization:\n<func0>:"
34
  asm = normalized_asm.strip()
@@ -77,7 +83,7 @@ demo = gr.Interface(
77
  fn=predict,
78
  inputs=[
79
  gr.Text(label="Optimization Type", value="O0"),
80
- gr.Text(label="Normalized Assembly Code"),
81
  gr.Text(label="Original C Code"),
82
  ],
83
  outputs=gr.Text(label="Raw Nova Output"),
 
7
  import json
8
  import spaces
9
  import torch
10
+ from normalize import normalize
11
 
12
  from transformers import AutoTokenizer
13
  from modeling_nova import NovaTokenizer, NovaForCausalLM
 
29
 
30
 
31
  @spaces.GPU
32
+ def predict(type, input_asm, _c_source):
33
+
34
+ if "<func0>:" not in input_asm:
35
+ normalized_asm = normalize(input_asm)
36
+ else:
37
+ normalized_asm = input_asm
38
 
39
  prompt_before = f"# This is the assembly code with {type} optimization:\n<func0>:"
40
  asm = normalized_asm.strip()
 
83
  fn=predict,
84
  inputs=[
85
  gr.Text(label="Optimization Type", value="O0"),
86
+ gr.Text(label="Assembly Code (Normalized or not)"),
87
  gr.Text(label="Original C Code"),
88
  ],
89
  outputs=gr.Text(label="Raw Nova Output"),
normalize.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import os
4
+
5
+
6
def hex_to_decimal(matched):
    """Return the decimal string for a regex match that holds a hex literal.

    Used as the replacement callback of ``re.sub``; the matched text
    (e.g. ``0x1f``) is parsed with ``int(..., 16)``, which accepts the
    ``0x`` prefix.
    """
    return str(int(matched.group(), 16))


def normalize(asm, max_lines=257):
    """Normalize an objdump-style disassembly listing.

    The listing is processed in two passes:

    1. Every function header line (no TAB, ends with ``:`` and contains
       ``<name>``) becomes ``<funcN>:``; every other non-blank line must be
       ``<address>:<TAB><instruction>`` and is assigned a running
       ``<label-K>``.
    2. Each instruction is rewritten: trailing ``<symbol>`` annotations are
       dropped, direct jump/loop/call targets are replaced by the label of
       the addressed instruction (``<unk>`` if the target is not part of the
       listing), hex literals become decimal, ``%`` is removed, and ``,``,
       ``(``, ``)`` are surrounded by single spaces.

    Args:
        asm: Raw disassembly text.
        max_lines: Only the first ``max_lines`` lines of the stripped input
            are considered (default 257, the original hard-coded cap).
            ``None`` disables the cap.

    Returns:
        The normalized listing. Every output line is preceded by ``'\\n'``
        (so a non-empty result starts with a newline) and instruction lines
        end with ``'\\t<label-K>'``. Empty input yields ``''``.

    Raises:
        ValueError: If a header line lacks ``<name>`` or an instruction line
            has no TAB separating address and instruction.
    """
    lines = asm.strip().split('\n')[:max_lines]

    # Pass 1: classify lines and build the address -> label table.
    entries = []  # (text, label); label is None for function headers
    addr2label = {}
    func_cnt, label_cnt = 0, 0
    for line in lines:
        if line.strip() == '' or 'file format elf64-x86-64' in line:
            continue

        if '\t' not in line and line.endswith(':'):
            if '<' not in line or '>' not in line:
                raise ValueError(f'function header without <name>: {line!r}')
            entries.append((f'<func{func_cnt}>:', None))
            func_cnt += 1
            continue

        if '\t' not in line:
            raise ValueError(
                f'expected "<address>:<TAB><instruction>", got: {line!r}'
            )
        label_cnt += 1
        addr, content = line.split('\t', 1)
        label = f'<label-{label_cnt}>'
        # objdump right-aligns addresses with spaces; strip the padding
        # before dropping the trailing ':' so that jump operands (which never
        # carry padding) can actually be found in the table.
        addr2label[addr.strip()[:-1]] = label
        entries.append((content.strip(), label))

    # Pass 2: rewrite every instruction.
    out = []
    for content, label in entries:
        if label is None:
            out.append(content)
            continue

        # Drop the "<symbol+off>" annotation objdump appends to operands.
        if '<' in content and '>' in content:
            content = content[:content.index('<')].strip()

        if content.startswith(('j', 'loop', 'call')):
            parts = content.split()
            if len(parts) == 2:
                inst, target = parts
                if target.startswith('0x'):
                    target = target[2:]
                content = inst + '\t' + addr2label.get(target, '<unk>')

        content = re.sub(r"0x([0-9A-Fa-f]+)", hex_to_decimal, content)
        content = content.replace('%', '')
        content = re.sub(r'([,()])', r' \1 ', content)
        content = re.sub(r' +', ' ', content).strip()

        out.append(content + '\t' + label)

    return ''.join('\n' + row for row in out)
62
+
63
+
64
+ def normalize_anghabench():
65
+ wp = open(f'anghabench/anghabench-normalize.jsonl', 'w')
66
+ fail = 0
67
+ with open(f'anghabench/anghabench.jsonl', 'r') as fp:
68
+ L = fp.readlines()
69
+ for i, line in enumerate(L):
70
+ try:
71
+ item = json.loads(line)
72
+ for opt in item['output']:
73
+ item['output'][opt] = normalize(item['output'][opt])
74
+ except Exception as e:
75
+ fail += 1
76
+ continue
77
+ wp.write(json.dumps(item) + '\n')
78
+
79
+ if i % 1000 == 0:
80
+ print(f"{i}/{len(L)}, fail: {fail}")
81
+
82
+
83
+ def normalize_the_stack():
84
+ wp = open('the-stack/the-stack-normalize.jsonl', 'w')
85
+ fail = 0
86
+ with open('the-stack/the-stack.jsonl', 'r') as fp:
87
+ L = fp.readlines()
88
+ for i, line in enumerate(L):
89
+ if i % 1000 == 0:
90
+ print(f"{i}/{len(L)}, fail: {fail}")
91
+ try:
92
+ item = json.loads(line)
93
+ for opt in item['output']:
94
+ item['output'][opt] = normalize(item['output'][opt]).strip()
95
+ except Exception as e:
96
+ fail += 1
97
+ print(e)
98
+ continue
99
+ wp.write(json.dumps(item) + '\n')
100
+
101
+
102
+ def normalize_codeart():
103
+ for file in os.listdir('codeart/'):
104
+ L = []
105
+ with open(f'codeart/{file}', 'r') as fp:
106
+ for l in fp.readlines():
107
+ item = json.loads(l.strip())
108
+ item['normalized_asm'] = normalize(item['asm'])
109
+ L.append(item)
110
+
111
+ with open(f'codeart/{file}', 'w') as wp:
112
+ for l in L:
113
+ wp.write(json.dumps(l) + '\n')
114
+
115
+
116
+ def normalize_binarycorp(binary_corp_folder):
117
+ data = {}
118
+ for file in os.listdir(binary_corp_folder):
119
+ if '-O0-' in file:
120
+ proj = file[: file.index('-O0-')]
121
+ opt = 'O0'
122
+ elif '-O1-' in file:
123
+ proj = file[: file.index('-O1-')]
124
+ opt = 'O1'
125
+ elif '-O3-' in file:
126
+ proj = file[: file.index('-O3-')]
127
+ opt = 'O3'
128
+ else:
129
+ continue
130
+ if proj not in data:
131
+ data[proj] = {}
132
+ content = json.load(open(f'{binary_corp_folder}/{file}', 'r'))
133
+ for k, v in content.items():
134
+ func = v['name']
135
+ asm = v['assembly']
136
+ if func not in data[proj]:
137
+ data[proj][func] = {}
138
+ data[proj][func][opt] = normalize(asm)
139
+ print(len(data))
140
+
141
+ data_filter = {}
142
+ for proj in data:
143
+ data_filter[proj] = {}
144
+ for func in data[proj]:
145
+ if len(data[proj][func]) < 2 or 'O3' not in data[proj][func]:
146
+ continue
147
+ data_filter[proj][func] = data[proj][func]
148
+ if len(data_filter[proj]) == 0:
149
+ data_filter.pop(proj)
150
+ json.dump(data_filter, open('binarycorp/binarycorp.json', 'w'), indent=2)
151
+
152
+
153
if __name__ == '__main__':
    # Script entry point: only the training-data jobs run by default; the
    # fine-tuning and evaluation jobs below are left commented out.

    # training data
    normalize_the_stack()
    normalize_anghabench()

    # fine-tuning data
    # download BinaryCorp small_train.tar from https://cloud.vul337.team:8443/s/cxnH8DfZTADLKCs
    # binary_corp_folder = ''
    # normalize_binarycorp(binary_corp_folder)

    # evaluation data
    # normalize_codeart()