"""Gradio front end for a fine-tuned CodeT5+ model that injects bugs into code."""

import argparse
import ast
import glob
import keyword
import os
import re
import sys

import autopep8
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Indentation prefixes handled by normalize_indentation().  They are built
# with multiplication so the widths cannot be silently altered by whitespace
# mangling of the source file.
_TWO_LEVELS = ' ' * 8
_ONE_LEVEL = ' ' * 4

# A name, blanks on the SAME line, then an opening parenthesis.
_CALL_GAP_RE = re.compile(r'\b(\w+)[ \t]+\(')

# A "name(" whose parenthesis is not closed before the end of the line.
_UNCLOSED_CALL_RE = re.compile(r'(\w+)\s*\([^)]*$')

# Keywords that open a block and therefore need a colon on their header line.
_BLOCK_KEYWORD_RE = re.compile(r'(?:if|elif|else|for|while|def|class)\b')


def normalize_indentation(code):
    """
    Normalize indentation in example code by removing one excess level.

    All backslash characters are removed first. From the first line whose
    stripped text starts with ``def `` onwards, every non-blank line that
    begins with two tabs is rewritten to begin with one tab, and every
    non-blank line that begins with eight spaces is rewritten to begin with
    four. ``def`` lines themselves and blank lines are copied unchanged.
    """
    # Remove backslash characters
    code = code.replace('\\', '')

    fixed_lines = []
    indent_fix_mode = False
    for line in code.split('\n'):
        if line.strip().startswith('def '):
            fixed_lines.append(line)
            indent_fix_mode = True
        elif indent_fix_mode and line.strip():
            if line.startswith('\t\t'):
                # Two tabs -> one tab
                fixed_lines.append('\t' + line[2:])
            elif line.startswith(_TWO_LEVELS):
                # Eight spaces (two levels) -> four spaces
                fixed_lines.append(_ONE_LEVEL + line[len(_TWO_LEVELS):])
            else:
                fixed_lines.append(line)
        else:
            fixed_lines.append(line)
    return '\n'.join(fixed_lines)


def clear_text(text):
    """
    Turn escaped ``\\n`` / ``\\t`` sequences into real control characters.

    Every other backslash in *text* is dropped. Placeholders protect the two
    escape sequences while the remaining backslashes are removed.
    """
    temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
    temp_tab = "TEMP_TAB_PLACEHOLDER"

    # Temporarily replace escape sequences
    text = text.replace("\\n", temp_newline)
    text = text.replace("\\t", temp_tab)

    # Remove remaining continuation backslashes
    text = text.replace("\\", "")

    # Convert placeholders back to actual control characters
    text = text.replace(temp_newline, "\n")
    text = text.replace(temp_tab, "\t")
    return text


def encode_text(text):
    """
    Encode newline and tab characters as the two-character escapes
    ``\\n`` and ``\\t``.
    """
    text = text.replace("\n", "\\n")
    text = text.replace("\t", "\\t")
    return text


def _close_call_gap(match):
    """Drop the blanks between a name and ``(`` unless the name is a keyword."""
    name = match.group(1)
    if keyword.iskeyword(name):
        # "if (", "return (", "in (" ... are not calls; leave them alone.
        return match.group(0)
    return name + '('


def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.

    After autopep8 has run, blanks just inside parentheses and blanks between
    a callable name and its opening parenthesis are removed. If anything
    raises, the error is printed and *code* is returned unchanged.

    NOTE: the two textual clean-ups are not token-aware, so matching text
    inside string literals or comments is rewritten as well.
    """
    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                'aggressive': 2,
                'max_line_length': 88,
                'indent_size': 4,
            },
        )

        # Remove extra spaces inside parentheses
        formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")

        # Remove extra spaces after function calls.  Only blanks on the same
        # line are removed: matching "\s+" would also swallow a newline and
        # glue a line ending in a name onto a following line that starts
        # with "(".
        formatted_code = _CALL_GAP_RE.sub(_close_call_gap, formatted_code)
        return formatted_code
    except Exception as e:
        # Best effort: formatting must never prevent the result being shown.
        print(f"Error formatting code: {str(e)}")
        return code


def _is_valid_python(code):
    """Return True when *code* parses as Python source."""
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        return False
    return True


def _close_unbalanced_quotes(code):
    """Append a quote to the first line holding an odd number of that quote."""
    for quote in ('"', "'"):
        if code.count(quote) % 2 == 0:
            continue
        lines = code.split('\n')
        for index, line in enumerate(lines):
            if line.count(quote) % 2 != 0:
                lines[index] = line.rstrip() + quote
                break
        code = '\n'.join(lines)
    return code


def _close_unclosed_calls(code):
    """Append ``)`` to lines that open a call and never close it."""
    lines = code.split('\n')
    for index, line in enumerate(lines):
        if not _UNCLOSED_CALL_RE.search(line):
            continue
        # A trailing comma or opening bracket means the call deliberately
        # continues on the next line.
        if line.rstrip().endswith((',', '(', '[', '{')):
            continue
        # The closing parenthesis may sit alone on one of the next two lines.
        if any(following.strip().startswith(')')
               for following in lines[index + 1:index + 3]):
            continue
        lines[index] = line.rstrip() + ')'
    return '\n'.join(lines)


def _needs_colon(stripped):
    """Return True when *stripped* is a block header lacking its colon."""
    if not _BLOCK_KEYWORD_RE.match(stripped):
        return False
    if stripped.endswith((':', '\\')):
        return False

    depth = 0
    quote = None
    escaped = False
    for char in stripped:
        if quote is not None:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
        elif char in '"\'':
            quote = char
        elif char == '#':
            # A colon appended after a comment would become part of it.
            return False
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            depth -= 1
        elif char == ':' and depth == 0:
            # The header already has its colon (e.g. "if x: return 1").
            return False
    # Inside an open bracket or string an appended colon cannot be the
    # header colon.
    return depth == 0 and quote is None


def _add_missing_colons(code):
    """Append ``:`` to block header lines that lack one."""
    fixed_lines = []
    for line in code.split('\n'):
        if _needs_colon(line.strip()):
            line = line.rstrip() + ':'
        fixed_lines.append(line)
    return '\n'.join(fixed_lines)


def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.

    Code that already parses is returned untouched, because the repairs below
    are textual heuristics that can damage valid code. Otherwise three repairs
    are applied in order, so that each one sees the result of the previous:

    1. close an unbalanced ``"`` or ``'`` on the first line with an odd count,
    2. append ``)`` to lines that open a call without closing it,
    3. append ``:`` to ``if``/``elif``/``else``/``for``/``while``/``def``/
       ``class`` header lines that have no colon.

    The result is not guaranteed to parse.
    """
    if _is_valid_python(code):
        return code

    code = _close_unbalanced_quotes(code)
    code = _close_unclosed_calls(code)
    code = _add_missing_colons(code)
    return code


def load_example_from_file(example_path):
    """
    Load example from a file with format description_BREAK_code

    Returns a ``(description, code)`` tuple. The code part has its ``\\n`` and
    ``\\t`` escapes expanded and its indentation normalized. ``("", "")`` is
    returned, after printing a message, when the file cannot be read or does
    not contain exactly one ``_BREAK_`` separator.
    """
    try:
        with open(example_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeError) as e:
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""

    # Split by the separator
    parts = content.split("_BREAK_")
    if len(parts) != 2:
        print(f"Invalid format in example file: {example_path}")
        return "", ""

    description = parts[0].strip()
    code = parts[1].strip()

    # Replace escape sequences with actual characters
    code = code.replace('\\n', '\n').replace('\\t', '\t')

    # Fix indentation
    code = normalize_indentation(code)
    return description, code


def find_example_files():
    """
    Find all raw.in example files in the examples directory

    The paths are sorted because glob returns them in arbitrary order and the
    position in this list is used to number the example buttons.
    """
    return sorted(glob.glob("examples/*/raw.in"))


def main():
    """Parse the command line, load the model and launch the Gradio interface."""
    # Configure argument parser
    parser = argparse.ArgumentParser(
        description='Launch a Gradio interface for a fine-tuned CodeT5+ model')
    parser.add_argument(
        'model_bin_path', type=str,
        help='Path to the fine-tuned model .bin file')
    parser.add_argument(
        '--base_model', type=str, default="Salesforce/codet5p-770m",
        help='Base model name (default: Salesforce/codet5p-770m)')

    # Parse arguments
    args = parser.parse_args()

    # Specify the base model you fine-tuned from
    base_model = args.base_model
    print(f"Using base model: {base_model}")
    print(f"Loading fine-tuned weights from: {args.model_bin_path}")

    # Load tokenizer from the base model
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Load the base model
    print("Loading base model...")
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

    # Load weights from the .bin file
    if os.path.exists(args.model_bin_path):
        print("Loading fine-tuned weights...")
        try:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only pass checkpoints from a trusted source;
            # consider weights_only=True where the installed torch has it.
            state_dict = torch.load(
                args.model_bin_path, map_location=torch.device('cpu'))

            # Some checkpoints wrap the weights in a 'model_state_dict' key
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']

            # strict=False tolerates key mismatches, so report them instead
            # of letting a checkpoint that matches nothing pass silently.
            load_result = model.load_state_dict(state_dict, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"WARNING: {len(load_result.missing_keys)} missing and "
                    f"{len(load_result.unexpected_keys)} unexpected keys "
                    "while loading the checkpoint.")
            print("Fine-tuned model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print("Using base model without fine-tuning.")
    else:
        print(f"WARNING: Model file not found at {args.model_bin_path}. Using base model.")

    def new_session():
        """Return the conversation state for a fresh session."""
        return {"current_code": None, "bug_counter": 0}

    def generate_bugged_code(description, code, chat_history, is_first_time,
                             session=None):
        """
        Run the model once and append the exchange to the chat history.

        The first call of a conversation feeds *code* to the model; later
        calls feed the previously generated bugged code, which is kept in
        *session* (a per-browser-session ``gr.State`` value) so concurrent
        users cannot see each other's code.

        Returns the new chat history, an update clearing the description box,
        the new ``is_first`` flag and the new session state.
        """
        # Starting over (or no state yet) resets the conversation state
        if is_first_time or not session:
            session = new_session()
        call_number = session["bug_counter"] + 1

        # Use counter to decide which code to use
        if call_number == 1:
            # First call or new code when restarting - use original code
            input_for_model = code
            input_type = "original"
        else:
            # Subsequent calls - use previously generated bugged code
            input_for_model = session["current_code"]
            input_type = "previous bugged code"
            if input_for_model is None:
                return chat_history, gr.update(value=""), False, session

        # Nothing to transform: leave every component as it is.
        if not input_for_model or not input_for_model.strip():
            print("No code to process; skipping generation.")
            return chat_history, gr.update(), is_first_time, session

        # Log for debugging
        print(f"Using {input_type} - counter: {call_number}\n{input_for_model}")

        # Encode the input code
        encoded_code = encode_text(input_for_model)
        combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

        # Tokenize input
        inputs = tokenizer.encode(combined_input, return_tensors='pt')

        # Generate output
        outputs = model.generate(
            inputs,
            max_length=1024,
            num_beams=5,
            early_stopping=True
        )

        # Decode output
        bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean code (replacing escape sequences with actual control characters)
        bugged_code = clear_text(bugged_code_escaped)

        # First fix common syntax issues without changing indentation
        bugged_code = fix_common_syntax_issues(bugged_code)

        # Apply autopep8 formatting
        bugged_code = format_code(bugged_code)

        # Build a new state object rather than mutating the stored one, so a
        # failure above leaves the session untouched.
        session = {"current_code": bugged_code, "bug_counter": call_number}

        # Update chat history
        chat_history = [] if is_first_time else list(chat_history or [])

        user_message = f"**Description**: {description}"
        if input_type == "original":
            user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
        else:
            user_message += f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"

        ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"
        chat_history.append((user_message, ai_message))

        return chat_history, gr.update(value=""), False, session

    # Function to handle interface reset
    def reset_interface():
        """Clear the chat, the description box and the session state."""
        return [], gr.update(value=""), True, new_session()

    # Find example files
    example_files = find_example_files()
    example_names = [
        f"Example {i+1}: {os.path.basename(os.path.dirname(f))}"
        for i, f in enumerate(example_files)
    ]

    # Function to load examples
    def load_example(example_index):
        """Return (description, code) for the example at *example_index*."""
        if example_index < len(example_files):
            return load_example_from_file(example_files[example_index])
        return "", ""

    # Create Gradio interface using Blocks
    with gr.Blocks(title="Software-Fault Injection from NL") as demo:
        gr.Markdown("# Software-Fault Injection from Natural Language")
        gr.Markdown("Generate Python code with specific bugs based on a description and original code.")

        with gr.Row():
            with gr.Column(scale=2):
                # Input components
                description_input = gr.Textbox(
                    label="Bug Description",
                    placeholder="Describe the type of bug to introduce...",
                    lines=3
                )
                code_input = gr.Code(
                    label="Original Code",
                    language="python",
                    lines=10
                )

                # Flag to track if it's the first submission
                is_first = gr.State(True)

                # Per-session conversation state (last bugged code, counter)
                session_state = gr.State(None)

                # Submit button
                submit_btn = gr.Button("Generate Bugged Code")

                # Reset button
                reset_btn = gr.Button("Start Over")

                # Examples
                gr.Markdown("### Examples")
                example_buttons = [gr.Button(name) for name in example_names]

            with gr.Column(scale=3):
                # Chat history
                chat_output = gr.Chatbot(
                    label="Conversation",
                    height=500
                )

        # Connect example buttons (i=i binds the index at definition time)
        for i, btn in enumerate(example_buttons):
            btn.click(
                fn=lambda i=i: load_example(i),
                outputs=[description_input, code_input]
            )

        # Connect submit button
        submit_btn.click(
            fn=generate_bugged_code,
            inputs=[description_input, code_input, chat_output, is_first,
                    session_state],
            outputs=[chat_output, description_input, is_first, session_state]
        )

        # Connect reset button
        reset_btn.click(
            fn=reset_interface,
            outputs=[chat_output, description_input, is_first, session_state]
        )

    # Launch interface
    print("Launching Gradio interface...")
    demo.launch()


if __name__ == "__main__":
    main()