piliguori commited on
Commit
be20041
·
verified ·
1 Parent(s): 7fa85f5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +399 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import sys
6
+ import argparse
7
+ import autopep8
8
+ import glob
9
+ import re
10
+
11
+ def normalize_indentation(code):
12
+ """
13
+ Normalize indentation in example code by removing excessive tabs.
14
+ Also removes any backslash characters.
15
+ """
16
+ # Remove backslash characters
17
+ code = code.replace('\\', '')
18
+
19
+ lines = code.split('\n')
20
+ if not lines:
21
+ return ""
22
+
23
+ # Check if we have a function with excessive indentation
24
+ fixed_lines = []
25
+ indent_fix_mode = False
26
+
27
+ for i, line in enumerate(lines):
28
+ if line.strip().startswith('def '):
29
+ fixed_lines.append(line)
30
+ indent_fix_mode = True
31
+ elif indent_fix_mode and line.strip():
32
+ # For indented lines in a function
33
+ if line.startswith('\t\t'): # Two tabs
34
+ fixed_lines.append('\t' + line[2:]) # Replace with one tab
35
+ elif line.startswith(' '): # 8 spaces (2 levels)
36
+ fixed_lines.append(' ' + line[8:]) # Replace with 4 spaces
37
+ else:
38
+ fixed_lines.append(line)
39
+ else:
40
+ fixed_lines.append(line)
41
+
42
+ return '\n'.join(fixed_lines)
43
+
44
+ def clear_text(text):
45
+ """
46
+ Cleans text from escape sequences while preserving original formatting.
47
+ """
48
+ # First, replace \\n with a special placeholder
49
+ temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
50
+ temp_tab = "TEMP_TAB_PLACEHOLDER"
51
+
52
+ # Temporarily replace escape sequences
53
+ text = text.replace("\\n", temp_newline)
54
+ text = text.replace("\\t", temp_tab)
55
+
56
+ # Remove remaining continuation backslashes
57
+ text = text.replace("\\", "")
58
+
59
+ # Convert placeholders back to actual control characters
60
+ text = text.replace(temp_newline, "\n")
61
+ text = text.replace(temp_tab, "\t")
62
+
63
+ return text
64
+
65
+ def encode_text(text):
66
+ """
67
+ Encodes control characters into escape sequences.
68
+ """
69
+ # Replace control characters with escape sequences
70
+ text = text.replace("\n", "\\n")
71
+ text = text.replace("\t", "\\t")
72
+
73
+ return text
74
+
75
+ def format_code(code):
76
+ """
77
+ Format Python code using autopep8 with aggressive settings.
78
+ """
79
+ try:
80
+ # More aggressive formatting with specific options
81
+ formatted_code = autopep8.fix_code(
82
+ code,
83
+ options={
84
+ 'aggressive': 2,
85
+ 'max_line_length': 88,
86
+ 'indent_size': 4
87
+ }
88
+ )
89
+
90
+ # Additional formatting for consistent spacing around parentheses and operators
91
+ # Remove extra spaces inside parentheses
92
+ formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")
93
+
94
+ # Ensure consistent spacing around operators
95
+ for op in ["+", "-", "*", "/", "=", "==", "!=", ">=", "<=", ">", "<"]:
96
+ formatted_code = formatted_code.replace(f"{op} ", op + " ")
97
+ formatted_code = formatted_code.replace(f" {op}", " " + op)
98
+
99
+ # Remove extra spaces after function calls
100
+ formatted_code = re.sub(r'(\w+)\s+\(', r'\1(', formatted_code)
101
+
102
+ return formatted_code
103
+ except Exception as e:
104
+ print(f"Error formatting code: {str(e)}")
105
+ return code
106
+
107
+ def fix_common_syntax_issues(code):
108
+ """
109
+ Fix common syntax issues in generated code without modifying indentation.
110
+ """
111
+ # Only fix critical syntax issues while preserving indentation
112
+ lines = code.split('\n')
113
+ fixed_lines = []
114
+
115
+ for line in lines:
116
+ stripped = line.strip()
117
+ # Check if line needs a colon at the end
118
+ if (stripped.startswith('if ') or
119
+ stripped.startswith('elif ') or
120
+ stripped.startswith('else') or
121
+ stripped.startswith('for ') or
122
+ stripped.startswith('while ') or
123
+ stripped.startswith('def ') or
124
+ stripped.startswith('class ')):
125
+
126
+ # Only add colon if it doesn't already end with one and doesn't continue to next line
127
+ if not stripped.endswith(':') and not stripped.endswith('\\'):
128
+ line = line.rstrip() + ':'
129
+
130
+ fixed_lines.append(line)
131
+
132
+ code = '\n'.join(fixed_lines)
133
+
134
+ # Fix mismatched quotes
135
+ quote_chars = ['"', "'"]
136
+ for quote in quote_chars:
137
+ # Count quotes in the code
138
+ if code.count(quote) % 2 != 0:
139
+ # Find incomplete string literals
140
+ lines = code.split('\n')
141
+ for i, line in enumerate(lines):
142
+ # Count quotes in this line
143
+ if line.count(quote) % 2 != 0:
144
+ # Add missing quote at the end of the line
145
+ lines[i] = line.rstrip() + quote
146
+ break
147
+ code = '\n'.join(lines)
148
+
149
+ # Fix missing parentheses in function calls
150
+ pattern = r'(\w+)\s*\([^)]*$' # Function call without closing parenthesis
151
+ if re.search(pattern, code):
152
+ # Add closing parenthesis at the end of the line
153
+ lines = code.split('\n')
154
+ for i, line in enumerate(lines):
155
+ if re.search(pattern, line) and not any(lines[j].strip().startswith(')') for j in range(i+1, min(i+3, len(lines)))):
156
+ lines[i] = line.rstrip() + ')'
157
+ code = '\n'.join(lines)
158
+
159
+ return code
160
+
161
+ def load_example_from_file(example_path):
162
+ """
163
+ Load example from a file with format description_BREAK_code
164
+ """
165
+ try:
166
+ with open(example_path, 'r') as f:
167
+ content = f.read()
168
+
169
+ # Split by the separator
170
+ parts = content.split("_BREAK_")
171
+ if len(parts) == 2:
172
+ description = parts[0].strip()
173
+ code = parts[1].strip()
174
+
175
+ # Replace escape sequences with actual characters
176
+ code = code.replace('\\n', '\n').replace('\\t', '\t')
177
+
178
+ # Fix indentation
179
+ code = normalize_indentation(code)
180
+
181
+ return description, code
182
+ else:
183
+ print(f"Invalid format in example file: {example_path}")
184
+ return "", ""
185
+ except Exception as e:
186
+ print(f"Error loading example file {example_path}: {str(e)}")
187
+ return "", ""
188
+
189
+ def find_example_files():
190
+ """
191
+ Find all raw.in example files in the examples directory
192
+ """
193
+ example_files = glob.glob("examples/*/raw.in")
194
+ return example_files
195
+
196
+ def main():
197
+ # Configure argument parser
198
+ parser = argparse.ArgumentParser(description='Launch a Gradio interface for a fine-tuned CodeT5+ model')
199
+ parser.add_argument('model_bin_path', type=str, help='Path to the fine-tuned model .bin file')
200
+ parser.add_argument('--base_model', type=str, default="Salesforce/codet5p-770m",
201
+ help='Base model name (default: Salesforce/codet5p-770m)')
202
+
203
+ # Parse arguments
204
+ args = parser.parse_args()
205
+
206
+ # Specify the base model you fine-tuned from
207
+ base_model = args.base_model
208
+
209
+ print(f"Using base model: {base_model}")
210
+ print(f"Loading fine-tuned weights from: {args.model_bin_path}")
211
+
212
+ # Load tokenizer from the base model
213
+ print("Loading tokenizer...")
214
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
215
+
216
+ # Load the base model
217
+ print("Loading base model...")
218
+ model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
219
+
220
+ # Load weights from the .bin file
221
+ if os.path.exists(args.model_bin_path):
222
+ print("Loading fine-tuned weights...")
223
+ try:
224
+ # Load your fine-tuned model
225
+ state_dict = torch.load(args.model_bin_path, map_location=torch.device('cpu'))
226
+
227
+ # If the state_dict contains a 'model_state_dict' key (common in some checkpoints)
228
+ if 'model_state_dict' in state_dict:
229
+ state_dict = state_dict['model_state_dict']
230
+
231
+ # Load weights into the model
232
+ model.load_state_dict(state_dict, strict=False)
233
+ print("Fine-tuned model loaded successfully!")
234
+ except Exception as e:
235
+ print(f"Error loading model: {str(e)}")
236
+ print("Using base model without fine-tuning.")
237
+ else:
238
+ print(f"WARNING: Model file not found at {args.model_bin_path}. Using base model.")
239
+
240
+ # State variables to track conversation state
241
+ current_code = None
242
+ # Counter for tracking number of bug generation calls
243
+ bug_counter = 0
244
+
245
+ def generate_bugged_code(description, code, chat_history, is_first_time):
246
+ nonlocal current_code, bug_counter
247
+
248
+ # If we're starting over (is_first_time), reset the counter
249
+ if is_first_time:
250
+ bug_counter = 0
251
+ current_code = None
252
+
253
+ # Increment the counter for each call
254
+ bug_counter += 1
255
+
256
+ # Use counter to decide which code to use
257
+ if bug_counter == 1:
258
+ # First call or new code when restarting - use original code
259
+ input_for_model = code
260
+ input_type = "original"
261
+ else:
262
+ # Subsequent calls - use previously generated bugged code
263
+ if current_code is None:
264
+ return chat_history, gr.update(value=""), False
265
+
266
+ input_for_model = current_code
267
+ input_type = "previous bugged code"
268
+
269
+ # Log for debugging
270
+ print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")
271
+
272
+ # Encode the input code
273
+ encoded_code = encode_text(input_for_model)
274
+ combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
275
+
276
+ # Tokenize input
277
+ inputs = tokenizer.encode(combined_input, return_tensors='pt')
278
+
279
+ # Generate output
280
+ outputs = model.generate(
281
+ inputs,
282
+ max_length=1024,
283
+ num_beams=5,
284
+ early_stopping=True
285
+ )
286
+
287
+ # Decode output
288
+ bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
289
+
290
+ # Clean code (replacing escape sequences with actual control characters)
291
+ bugged_code = clear_text(bugged_code_escaped)
292
+
293
+ # First fix common syntax issues without changing indentation
294
+ bugged_code = fix_common_syntax_issues(bugged_code)
295
+
296
+ # Apply autopep8 formatting
297
+ bugged_code = format_code(bugged_code)
298
+
299
+ # Update current code with the newly generated code
300
+ current_code = bugged_code
301
+
302
+ # Update chat history
303
+ if is_first_time:
304
+ chat_history = []
305
+
306
+ user_message = f"**Description**: {description}"
307
+ if input_type == "original":
308
+ user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
309
+ else:
310
+ user_message += f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
311
+
312
+ ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"
313
+
314
+ chat_history.append((user_message, ai_message))
315
+
316
+ return chat_history, gr.update(value=""), False
317
+
318
+ # Function to handle interface reset
319
+ def reset_interface():
320
+ nonlocal current_code, bug_counter
321
+ current_code = None
322
+ bug_counter = 0
323
+ return [], gr.update(value=""), True
324
+
325
+ # Find example files
326
+ example_files = find_example_files()
327
+ example_names = [f"Example {i+1}: {os.path.basename(os.path.dirname(f))}" for i, f in enumerate(example_files)]
328
+
329
+ # Function to load examples
330
+ def load_example(example_index):
331
+ if example_index < len(example_files):
332
+ return load_example_from_file(example_files[example_index])
333
+ return "", ""
334
+
335
+ # Create Gradio interface using Blocks
336
+ with gr.Blocks(title="Software-Fault Injection from NL") as demo:
337
+ gr.Markdown("# Software-Fault Injection from Natural Language")
338
+ gr.Markdown("Generate Python code with specific bugs based on a description and original code.")
339
+
340
+ with gr.Row():
341
+ with gr.Column(scale=2):
342
+ # Input components
343
+ description_input = gr.Textbox(
344
+ label="Bug Description",
345
+ placeholder="Describe the type of bug to introduce...",
346
+ lines=3
347
+ )
348
+ code_input = gr.Code(
349
+ label="Original Code",
350
+ language="python",
351
+ lines=10
352
+ )
353
+
354
+ # Flag to track if it's the first submission
355
+ is_first = gr.State(True)
356
+
357
+ # Submit button
358
+ submit_btn = gr.Button("Generate Bugged Code")
359
+
360
+ # Reset button
361
+ reset_btn = gr.Button("Start Over")
362
+
363
+ # Examples
364
+ gr.Markdown("### Examples")
365
+ example_buttons = [gr.Button(name) for name in example_names]
366
+
367
+ with gr.Column(scale=3):
368
+ # Chat history
369
+ chat_output = gr.Chatbot(
370
+ label="Conversation",
371
+ height=500
372
+ )
373
+
374
+ # Connect example buttons
375
+ for i, btn in enumerate(example_buttons):
376
+ btn.click(
377
+ fn=lambda i=i: load_example(i),
378
+ outputs=[description_input, code_input]
379
+ )
380
+
381
+ # Connect submit button
382
+ submit_btn.click(
383
+ fn=generate_bugged_code,
384
+ inputs=[description_input, code_input, chat_output, is_first],
385
+ outputs=[chat_output, description_input, is_first]
386
+ )
387
+
388
+ # Connect reset button
389
+ reset_btn.click(
390
+ fn=reset_interface,
391
+ outputs=[chat_output, description_input, is_first]
392
+ )
393
+
394
+ # Launch interface
395
+ print("Launching Gradio interface...")
396
+ demo.launch()
397
+
398
+ if __name__ == "__main__":
399
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch