|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import autopep8 |
|
|
import glob |
|
|
import re |
|
|
|
|
|
def normalize_indentation(code):
    """
    Normalize indentation in example code by collapsing doubled indents
    inside function bodies. Also removes any backslash characters.

    Once a ``def`` line is seen, subsequent non-blank lines have a
    two-tab prefix reduced to one tab, and an eight-space prefix reduced
    to four spaces; everything else is kept as-is.
    """
    # Strip stray backslashes first (example files may contain them).
    code = code.replace('\\', '')

    lines = code.split('\n')
    if not lines:
        # str.split always yields at least one element; kept as a guard.
        return ""

    fixed_lines = []
    indent_fix_mode = False  # becomes True after the first 'def ' line

    for line in lines:
        if line.strip().startswith('def '):
            fixed_lines.append(line)
            indent_fix_mode = True
        elif indent_fix_mode and line.strip():
            if line.startswith('\t\t'):
                # Two tabs -> one tab.
                fixed_lines.append('\t' + line[2:])
            elif line.startswith(' ' * 8):
                # Eight spaces -> four spaces. The prefix check must
                # match the 8 characters sliced off by line[8:];
                # checking a shorter prefix would silently drop code.
                fixed_lines.append(' ' * 4 + line[8:])
            else:
                fixed_lines.append(line)
        else:
            fixed_lines.append(line)

    return '\n'.join(fixed_lines)
|
|
|
|
|
def clear_text(text):
    """
    Cleans text from escape sequences while preserving original formatting.

    Literal ``\\n`` / ``\\t`` sequences become real newlines and tabs;
    any other backslash character is dropped.
    """
    # (escape sequence, temporary sentinel, real control character)
    replacements = [
        ("\\n", "TEMP_NEWLINE_PLACEHOLDER", "\n"),
        ("\\t", "TEMP_TAB_PLACEHOLDER", "\t"),
    ]

    # Park the recognised escapes behind sentinels so the backslash
    # sweep below cannot eat them.
    for escape, sentinel, _ in replacements:
        text = text.replace(escape, sentinel)

    # Drop every remaining (stray) backslash.
    text = text.replace("\\", "")

    # Swap the sentinels back for the actual control characters.
    for _, sentinel, real in replacements:
        text = text.replace(sentinel, real)

    return text
|
|
|
|
|
def encode_text(text):
    """
    Encodes control characters into escape sequences.

    The inverse of clear_text for newlines and tabs: ``\n`` -> ``\\n``
    and ``\t`` -> ``\\t``; all other characters pass through unchanged.
    """
    escape_map = {"\n": "\\n", "\t": "\\t"}
    return "".join(escape_map.get(ch, ch) for ch in text)
|
|
|
|
|
def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.

    Applies autopep8 (aggressive level 2, 88-column lines, 4-space
    indents) and a few cosmetic clean-ups. Returns the input unchanged
    if formatting fails for any reason.
    """
    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                'aggressive': 2,
                'max_line_length': 88,
                'indent_size': 4
            }
        )

        # Tighten spacing just inside parentheses.
        # NOTE(review): plain str.replace also touches string literals
        # containing "( " / " )" — tolerated for generated snippets.
        formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")

        # (A previous loop "re-spaced" binary operators here, but each
        # replacement substituted a pattern with itself — a no-op — so
        # the dead code was removed.)

        # Remove space between a name and its call parens: "f (x)" -> "f(x)".
        formatted_code = re.sub(r'(\w+)\s+\(', r'\1(', formatted_code)

        return formatted_code
    except Exception as e:
        # Best-effort formatter: never let a formatting failure crash
        # the generation pipeline.
        print(f"Error formatting code: {str(e)}")
        return code
|
|
|
|
|
def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.

    Three heuristic repairs:
      1. append a missing ':' to block-opening statements,
      2. close a line with an odd number of quotes,
      3. close an unterminated function call with ')'.
    """
    # Word-boundary match so identifiers such as "elsewhere" are not
    # mistaken for an "else" clause (a plain startswith('else') check
    # would wrongly add a colon to them).
    block_start = re.compile(r'(?:if|elif|for|while|def|class)\s|else\b')

    fixed_lines = []
    for line in code.split('\n'):
        stripped = line.strip()
        if block_start.match(stripped):
            # Skip lines already ending in ':' or continued with '\'.
            if not stripped.endswith(':') and not stripped.endswith('\\'):
                line = line.rstrip() + ':'
        fixed_lines.append(line)
    code = '\n'.join(fixed_lines)

    # Balance quotes: if the whole text has an odd count of a quote
    # character, close the first line with an odd count of it.
    quote_chars = ['"', "'"]
    for quote in quote_chars:
        if code.count(quote) % 2 != 0:
            lines = code.split('\n')
            for i, line in enumerate(lines):
                if line.count(quote) % 2 != 0:
                    lines[i] = line.rstrip() + quote
                    break
            code = '\n'.join(lines)

    # Close an unterminated call: "name(" with no ')' before end-of-line
    # and no closing ')' at the start of the next couple of lines.
    pattern = r'(\w+)\s*\([^)]*$'
    if re.search(pattern, code):
        lines = code.split('\n')
        for i, line in enumerate(lines):
            if re.search(pattern, line) and not any(
                    lines[j].strip().startswith(')')
                    for j in range(i + 1, min(i + 3, len(lines)))):
                lines[i] = line.rstrip() + ')'
        code = '\n'.join(lines)

    return code
|
|
|
|
|
def load_example_from_file(example_path):
    """
    Load example from a file with format description_BREAK_code.

    Returns a (description, code) tuple, or ("", "") if the file cannot
    be read or does not contain exactly one _BREAK_ separator.
    """
    try:
        # Explicit encoding: the locale-dependent default could reject
        # or mangle non-ASCII example files on some platforms.
        with open(example_path, 'r', encoding='utf-8') as f:
            content = f.read()

        parts = content.split("_BREAK_")
        if len(parts) == 2:
            description = parts[0].strip()
            code = parts[1].strip()

            # Materialise newlines/tabs stored as escape sequences.
            code = code.replace('\\n', '\n').replace('\\t', '\t')

            # Collapse doubled indentation and stray backslashes.
            code = normalize_indentation(code)

            return description, code
        else:
            print(f"Invalid format in example file: {example_path}")
            return "", ""
    except Exception as e:
        # Best-effort loader: log and fall back to empty fields.
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""
|
|
|
|
|
def find_example_files():
    """Locate every raw.in example file under the examples/ directory."""
    pattern = "examples/*/raw.in"
    return glob.glob(pattern)
|
|
|
|
|
def main():
    """
    Parse CLI arguments, load the (optionally fine-tuned) CodeT5+ model,
    and launch the Gradio bug-injection interface.
    """
    parser = argparse.ArgumentParser(description='Launch a Gradio interface for a fine-tuned CodeT5+ model')
    parser.add_argument('model_bin_path', type=str, help='Path to the fine-tuned model .bin file')
    parser.add_argument('--base_model', type=str, default="Salesforce/codet5p-770m",
                        help='Base model name (default: Salesforce/codet5p-770m)')
    args = parser.parse_args()

    base_model = args.base_model
    print(f"Using base model: {base_model}")
    print(f"Loading fine-tuned weights from: {args.model_bin_path}")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    print("Loading base model...")
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

    if os.path.exists(args.model_bin_path):
        print("Loading fine-tuned weights...")
        try:
            # map_location='cpu' so the checkpoint loads on GPU-less machines.
            state_dict = torch.load(args.model_bin_path, map_location=torch.device('cpu'))
            # Some training checkpoints nest the weights under this key.
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']
            # strict=False tolerates key mismatches between the checkpoint
            # and the base architecture.
            model.load_state_dict(state_dict, strict=False)
            print("Fine-tuned model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print("Using base model without fine-tuning.")
    else:
        print(f"WARNING: Model file not found at {args.model_bin_path}. Using base model.")

    # Conversation state shared by the nested Gradio callbacks.
    current_code = None  # last generated bugged snippet (input for round 2+)
    bug_counter = 0      # number of generations in the current session

    def generate_bugged_code(description, code, chat_history, is_first_time):
        """Run one generation round and append the exchange to the chat."""
        nonlocal current_code, bug_counter

        if is_first_time:
            bug_counter = 0
            current_code = None

        bug_counter += 1

        if bug_counter == 1:
            # First round: start from the user-provided code.
            input_for_model = code
            input_type = "original"
        else:
            if current_code is None:
                # No previous output to iterate on; leave the chat untouched.
                return chat_history, gr.update(value=""), False
            # Later rounds: keep injecting bugs into the previous output.
            input_for_model = current_code
            input_type = "previous bugged code"

        print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")

        # Build the model prompt in the training format.
        encoded_code = encode_text(input_for_model)
        combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

        inputs = tokenizer.encode(combined_input, return_tensors='pt')

        # Inference only — no_grad skips building the autograd graph
        # (the previous version needlessly tracked gradients here).
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=1024,
                num_beams=5,
                early_stopping=True
            )

        bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Decode escapes, patch obvious syntax slips, then pep8-format.
        bugged_code = clear_text(bugged_code_escaped)
        bugged_code = fix_common_syntax_issues(bugged_code)
        bugged_code = format_code(bugged_code)

        current_code = bugged_code

        if is_first_time:
            chat_history = []

        user_message = f"**Description**: {description}"
        if input_type == "original":
            user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
        else:
            user_message += f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"

        ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"

        chat_history.append((user_message, ai_message))

        return chat_history, gr.update(value=""), False

    def reset_interface():
        """Clear conversation state and start a fresh session."""
        nonlocal current_code, bug_counter
        current_code = None
        bug_counter = 0
        return [], gr.update(value=""), True

    example_files = find_example_files()
    example_names = [f"Example {i+1}: {os.path.basename(os.path.dirname(f))}" for i, f in enumerate(example_files)]

    def load_example(example_index):
        """Return (description, code) for the chosen example button."""
        if example_index < len(example_files):
            return load_example_from_file(example_files[example_index])
        return "", ""

    with gr.Blocks(title="Software-Fault Injection from NL") as demo:
        gr.Markdown("# Software-Fault Injection from Natural Language")
        gr.Markdown("Generate Python code with specific bugs based on a description and original code.")

        with gr.Row():
            with gr.Column(scale=2):
                description_input = gr.Textbox(
                    label="Bug Description",
                    placeholder="Describe the type of bug to introduce...",
                    lines=3
                )
                code_input = gr.Code(
                    label="Original Code",
                    language="python",
                    lines=10
                )

                # Tracks whether the next generation starts a new session.
                is_first = gr.State(True)

                submit_btn = gr.Button("Generate Bugged Code")
                reset_btn = gr.Button("Start Over")

                gr.Markdown("### Examples")
                example_buttons = [gr.Button(name) for name in example_names]

            with gr.Column(scale=3):
                chat_output = gr.Chatbot(
                    label="Conversation",
                    height=500
                )

        for i, btn in enumerate(example_buttons):
            # Default-arg binding (i=i) avoids the late-binding closure bug.
            btn.click(
                fn=lambda i=i: load_example(i),
                outputs=[description_input, code_input]
            )

        submit_btn.click(
            fn=generate_bugged_code,
            inputs=[description_input, code_input, chat_output, is_first],
            outputs=[chat_output, description_input, is_first]
        )

        reset_btn.click(
            fn=reset_interface,
            outputs=[chat_output, description_input, is_first]
        )

    print("Launching Gradio interface...")
    demo.launch()
|
|
|
|
|
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":


    main()