File size: 14,012 Bytes
be20041
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import torch
import os
import sys
import argparse
import autopep8
import glob
import re

def normalize_indentation(code):
    """
    Normalize indentation in example code by collapsing one level of
    excess indentation inside function bodies.

    Every backslash character is stripped from the snippet first, then any
    line after a ``def`` header that starts with two tabs (or eight spaces)
    is re-indented to one tab (or four spaces).
    """
    # Drop backslashes outright (stray escapes / line continuations).
    code = code.replace('\\', '')

    normalized = []
    inside_def = False

    for raw_line in code.split('\n'):
        content = raw_line.strip()

        if content.startswith('def '):
            # A function header switches body re-indentation on; it stays
            # on for the remainder of the snippet.
            inside_def = True
            normalized.append(raw_line)
        elif inside_def and content:
            # Collapse two leading tabs to one, or eight leading spaces to four.
            if raw_line.startswith('\t\t'):
                normalized.append('\t' + raw_line[2:])
            elif raw_line.startswith(' ' * 8):
                normalized.append('    ' + raw_line[8:])
            else:
                normalized.append(raw_line)
        else:
            normalized.append(raw_line)

    return '\n'.join(normalized)

def clear_text(text):
    """
    Decode escape sequences (``\\n``, ``\\t``) into real control characters
    while preserving the rest of the original formatting.

    Stray backslashes that are not part of a recognized escape are removed.
    """
    # (escape sequence, stash marker, real control character)
    markers = (
        ("\\n", "TEMP_NEWLINE_PLACEHOLDER", "\n"),
        ("\\t", "TEMP_TAB_PLACEHOLDER", "\t"),
    )

    # Stash recognized escapes so the blanket backslash removal below
    # cannot eat them.
    for escaped, marker, _ in markers:
        text = text.replace(escaped, marker)

    # Strip any remaining (continuation) backslashes.
    text = text.replace("\\", "")

    # Swap the stashed markers back in as actual control characters.
    for _, marker, real in markers:
        text = text.replace(marker, real)

    return text

def encode_text(text):
    """
    Encode control characters as escape sequences.

    The inverse of ``clear_text`` for newline and tab: each literal newline
    becomes ``\\n`` and each literal tab becomes ``\\t``.
    """
    for control_char, escaped in (("\n", "\\n"), ("\t", "\\t")):
        text = text.replace(control_char, escaped)
    return text

def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.
    """
    try:
        # More aggressive formatting with specific options
        formatted_code = autopep8.fix_code(
            code, 
            options={
                'aggressive': 2,
                'max_line_length': 88,
                'indent_size': 4
            }
        )
        
        # Additional formatting for consistent spacing around parentheses and operators
        # Remove extra spaces inside parentheses
        formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")
        
        # Ensure consistent spacing around operators
        for op in ["+", "-", "*", "/", "=", "==", "!=", ">=", "<=", ">", "<"]:
            formatted_code = formatted_code.replace(f"{op} ", op + " ")
            formatted_code = formatted_code.replace(f" {op}", " " + op)
        
        # Remove extra spaces after function calls
        formatted_code = re.sub(r'(\w+)\s+\(', r'\1(', formatted_code)
        
        return formatted_code
    except Exception as e:
        print(f"Error formatting code: {str(e)}")
        return code

def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.

    Heuristic passes, applied in order:
      1. Append a missing ':' to compound-statement headers
         (if/elif/else/for/while/def/class).
      2. Balance an odd number of single or double quotes by closing the
         first line carrying an unbalanced count.
      3. Append a ')' to a call left unclosed at end of line, unless a
         closing paren appears within the next two lines.

    Args:
        code: Generated (possibly malformed) Python source.

    Returns:
        The heuristically repaired source string.
    """
    # Pass 1: add missing colons on block headers, preserving indentation.
    # BUGFIX: match keywords on a word boundary so identifiers such as
    # "else_count" are no longer mistaken for an "else" header (the old
    # startswith('else') check appended a bogus colon to them).
    header_re = re.compile(r'(?:if|elif|else|for|while|def|class)\b')
    lines = code.split('\n')
    fixed_lines = []

    for line in lines:
        stripped = line.strip()
        if header_re.match(stripped):
            # BUGFIX: skip headers whose parentheses are unbalanced — a
            # multi-line signature like "def f(a," must not get a colon
            # mid-header. Also skip lines already ending in ':' or a
            # continuation backslash, as before.
            if (not stripped.endswith(':') and not stripped.endswith('\\')
                    and stripped.count('(') == stripped.count(')')):
                line = line.rstrip() + ':'
        fixed_lines.append(line)

    code = '\n'.join(fixed_lines)

    # Pass 2: fix mismatched quotes.
    quote_chars = ['"', "'"]
    for quote in quote_chars:
        # Count quotes in the whole snippet
        if code.count(quote) % 2 != 0:
            # Close the first line with an odd quote count
            lines = code.split('\n')
            for i, line in enumerate(lines):
                if line.count(quote) % 2 != 0:
                    lines[i] = line.rstrip() + quote
                    break
            code = '\n'.join(lines)

    # Pass 3: close an unterminated function call.
    pattern = r'(\w+)\s*\([^)]*$'  # call without closing parenthesis
    if re.search(pattern, code):
        lines = code.split('\n')
        for i, line in enumerate(lines):
            # BUGFIX: only close the call when no ')' appears within the
            # next two lines. The old check required a following line to
            # *start* with ')', so ordinary continuations ending in ')'
            # were missed and got a spurious ')' appended.
            if re.search(pattern, line) and not any(')' in lines[j] for j in range(i + 1, min(i + 3, len(lines)))):
                lines[i] = line.rstrip() + ')'
        code = '\n'.join(lines)

    return code

def load_example_from_file(example_path):
    """
    Load one example from a file with the format ``description_BREAK_code``.

    Args:
        example_path: Path to a raw.in example file.

    Returns:
        A ``(description, code)`` tuple. Returns ``("", "")`` when the file
        is missing/unreadable or does not contain exactly one ``_BREAK_``
        separator (a message is printed in either case).
    """
    try:
        # BUGFIX: read explicitly as UTF-8 — the platform default encoding
        # is not guaranteed to match how the examples were written (avoids
        # mojibake / UnicodeDecodeError on non-UTF-8 locales such as
        # Windows cp1252).
        with open(example_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Split by the separator; exactly two parts are required.
        parts = content.split("_BREAK_")
        if len(parts) == 2:
            description = parts[0].strip()
            code = parts[1].strip()

            # Replace escape sequences with actual characters
            code = code.replace('\\n', '\n').replace('\\t', '\t')

            # Fix indentation (helper defined earlier in this module)
            code = normalize_indentation(code)

            return description, code
        else:
            print(f"Invalid format in example file: {example_path}")
            return "", ""
    except Exception as e:
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""

def find_example_files():
    """
    Locate every example input file (``raw.in``) one directory level
    below ``examples/``.

    Returns:
        A list of matching paths (order is whatever ``glob`` yields).
    """
    pattern = "examples/*/raw.in"
    return glob.glob(pattern)

def main():
    """
    Command-line entry point.

    Parses CLI arguments, loads the base CodeT5+ tokenizer/model, overlays
    fine-tuned weights from the given .bin file when it exists (falling back
    to the base model on any load error), then builds and launches a Gradio
    chat interface that injects bugs into Python code from a natural
    language description.
    """
    # Configure argument parser
    parser = argparse.ArgumentParser(description='Launch a Gradio interface for a fine-tuned CodeT5+ model')
    parser.add_argument('model_bin_path', type=str, help='Path to the fine-tuned model .bin file')
    parser.add_argument('--base_model', type=str, default="Salesforce/codet5p-770m",
                       help='Base model name (default: Salesforce/codet5p-770m)')
    
    # Parse arguments
    args = parser.parse_args()
    
    # Specify the base model you fine-tuned from
    base_model = args.base_model
    
    print(f"Using base model: {base_model}")
    print(f"Loading fine-tuned weights from: {args.model_bin_path}")

    # Load tokenizer from the base model
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Load the base model
    print("Loading base model...")
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

    # Load weights from the .bin file
    if os.path.exists(args.model_bin_path):
        print("Loading fine-tuned weights...")
        try:
            # Load your fine-tuned model
            state_dict = torch.load(args.model_bin_path, map_location=torch.device('cpu'))
            
            # If the state_dict contains a 'model_state_dict' key (common in some checkpoints)
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']
            
            # Load weights into the model (strict=False tolerates missing /
            # unexpected keys rather than raising)
            model.load_state_dict(state_dict, strict=False)
            print("Fine-tuned model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print("Using base model without fine-tuning.")
    else:
        print(f"WARNING: Model file not found at {args.model_bin_path}. Using base model.")

    # State variables to track conversation state
    # NOTE(review): these closure variables are shared by ALL callbacks —
    # concurrent Gradio sessions would share/clobber this state. Confirm
    # the app is intended for single-user use.
    current_code = None
    # Counter for tracking number of bug generation calls
    bug_counter = 0

    def generate_bugged_code(description, code, chat_history, is_first_time):
        """
        Run one bug-injection step and append it to the chat history.

        On the first call (or after a reset) the user-supplied code is fed
        to the model; on subsequent calls the previously generated bugged
        code is fed back in, compounding bugs.

        Returns:
            (updated chat_history, cleared description textbox update,
             False for the is_first state flag).
        """
        nonlocal current_code, bug_counter
        
        # If we're starting over (is_first_time), reset the counter
        if is_first_time:
            bug_counter = 0
            current_code = None
            
        # Increment the counter for each call
        bug_counter += 1
        
        # Use counter to decide which code to use
        if bug_counter == 1:
            # First call or new code when restarting - use original code
            input_for_model = code
            input_type = "original"
        else:
            # Subsequent calls - use previously generated bugged code
            if current_code is None:
                # Nothing to iterate on (shouldn't normally happen); bail out
                # leaving the chat untouched.
                return chat_history, gr.update(value=""), False
            
            input_for_model = current_code
            input_type = "previous bugged code"
        
        # Log for debugging
        print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")
        
        # Encode the input code (newlines/tabs -> escape sequences, matching
        # the _BREAK_-separated training format)
        encoded_code = encode_text(input_for_model)
        combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
        
        # Tokenize input
        inputs = tokenizer.encode(combined_input, return_tensors='pt')
        
        # Generate output
        outputs = model.generate(
            inputs, 
            max_length=1024,
            num_beams=5,
            early_stopping=True
        )
        
        # Decode output
        bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clean code (replacing escape sequences with actual control characters)
        bugged_code = clear_text(bugged_code_escaped)
        
        # First fix common syntax issues without changing indentation
        bugged_code = fix_common_syntax_issues(bugged_code)
        
        # Apply autopep8 formatting
        bugged_code = format_code(bugged_code)
        
        # Update current code with the newly generated code
        current_code = bugged_code
        
        # Update chat history
        if is_first_time:
            chat_history = []
            
        user_message = f"**Description**: {description}"
        if input_type == "original":
            user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
        else:
            user_message += f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
            
        ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"
        
        chat_history.append((user_message, ai_message))
        
        return chat_history, gr.update(value=""), False

    # Function to handle interface reset
    def reset_interface():
        """Clear all conversation state and return the UI to its initial
        (empty chat, blank description, is_first=True) configuration."""
        nonlocal current_code, bug_counter
        current_code = None
        bug_counter = 0
        return [], gr.update(value=""), True

    # Find example files
    example_files = find_example_files()
    example_names = [f"Example {i+1}: {os.path.basename(os.path.dirname(f))}" for i, f in enumerate(example_files)]

    # Function to load examples
    def load_example(example_index):
        """Return (description, code) for the example at example_index,
        or ("", "") when the index is out of range."""
        if example_index < len(example_files):
            return load_example_from_file(example_files[example_index])
        return "", ""

    # Create Gradio interface using Blocks
    with gr.Blocks(title="Software-Fault Injection from NL") as demo:
        gr.Markdown("# Software-Fault Injection from Natural Language")
        gr.Markdown("Generate Python code with specific bugs based on a description and original code.")
        
        with gr.Row():
            with gr.Column(scale=2):
                # Input components
                description_input = gr.Textbox(
                    label="Bug Description", 
                    placeholder="Describe the type of bug to introduce...",
                    lines=3
                )
                code_input = gr.Code(
                    label="Original Code", 
                    language="python",
                    lines=10
                )
                
                # Flag to track if it's the first submission
                is_first = gr.State(True)
                
                # Submit button
                submit_btn = gr.Button("Generate Bugged Code")
                
                # Reset button
                reset_btn = gr.Button("Start Over")
                
                # Examples
                gr.Markdown("### Examples")
                example_buttons = [gr.Button(name) for name in example_names]
                
            with gr.Column(scale=3):
                # Chat history
                chat_output = gr.Chatbot(
                    label="Conversation",
                    height=500
                )

        # Connect example buttons
        # (i=i default binds the loop variable per-button, avoiding the
        # late-binding closure pitfall)
        for i, btn in enumerate(example_buttons):
            btn.click(
                fn=lambda i=i: load_example(i),
                outputs=[description_input, code_input]
            )

        # Connect submit button
        submit_btn.click(
            fn=generate_bugged_code,
            inputs=[description_input, code_input, chat_output, is_first],
            outputs=[chat_output, description_input, is_first]
        )
        
        # Connect reset button
        reset_btn.click(
            fn=reset_interface,
            outputs=[chat_output, description_input, is_first]
        )

    # Launch interface
    print("Launching Gradio interface...")
    demo.launch()

# Script entry point: parse args, load the model, and launch the Gradio UI.
if __name__ == "__main__":
    main()