File size: 10,863 Bytes
be20041
 
 
 
 
 
ec5045e
d6272b6
 
ec5045e
 
 
 
be20041
 
 
 
 
 
ec5045e
 
 
be20041
 
ec5045e
be20041
 
ec5045e
be20041
ec5045e
be20041
 
 
 
ec5045e
 
 
 
be20041
 
 
 
ec5045e
 
 
be20041
 
 
 
 
 
 
ec5045e
be20041
 
ec5045e
be20041
ec5045e
be20041
 
ec5045e
be20041
 
ec5045e
be20041
 
 
 
 
 
 
 
ec5045e
be20041
 
 
 
 
 
ec5045e
be20041
ec5045e
 
 
 
be20041
ec5045e
be20041
 
ec5045e
be20041
 
 
ec5045e
 
 
be20041
 
 
 
 
ec5045e
be20041
 
 
 
ec5045e
be20041
ec5045e
be20041
 
ec5045e
 
 
 
 
 
 
 
 
 
 
 
be20041
ec5045e
 
 
be20041
 
 
 
ec5045e
be20041
 
 
 
ec5045e
 
be20041
ec5045e
be20041
ec5045e
be20041
ec5045e
 
 
 
 
 
 
be20041
 
ec5045e
be20041
 
ec5045e
 
 
be20041
 
ec5045e
be20041
ec5045e
be20041
 
 
 
ec5045e
 
be20041
ec5045e
be20041
 
 
 
 
 
 
 
ec5045e
be20041
 
ec5045e
be20041
 
 
 
ec5045e
 
 
 
 
d6272b6
fbf65ef
d6272b6
b761fa3
fbf65ef
 
b761fa3
 
 
d6272b6
 
fbf65ef
d6272b6
 
 
 
 
fbf65ef
d6272b6
 
 
 
 
fbf65ef
d6272b6
 
23479ca
ec5045e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8439b6b
 
 
ec5045e
 
 
8439b6b
ec5045e
 
 
 
 
 
be20041
ec5045e
 
 
 
 
 
 
 
 
 
9540433
 
 
 
 
 
ec5045e
9540433
 
 
 
 
 
 
 
 
 
 
 
 
 
ec5045e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8439b6b
 
 
 
ec5045e
 
be20041
ec5045e
9540433
8439b6b
ec5045e
 
be20041
 
ec5045e
be20041
 
ec5045e
 
 
 
 
 
 
 
 
 
 
be20041
 
ec5045e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be20041
 
ec5045e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be20041
 
ec5045e
 
 
 
 
 
 
 
 
 
be20041
ec5045e
2456544
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import torch
import autopep8
import glob
import re
import os
from huggingface_hub import hf_hub_download


# ==========================
#  Utility functions
# ==========================

def normalize_indentation(code):
    """
    Strip one surplus indentation level from example function bodies.

    Every backslash is removed first. Then, starting at the first line whose
    stripped text begins with ``def ``, each following non-blank line that
    starts with two tabs loses one tab, and each that starts with eight
    spaces loses four spaces. ``def`` lines and everything before the first
    ``def`` are kept as they are.

    NOTE(review): lines that genuinely need two indentation levels (nested
    blocks inside a function) are shifted left as well; confirm this matches
    the layout of the example files.
    """
    inside_function = False
    result = []
    for raw_line in code.replace("\\", "").split("\n"):
        content = raw_line.strip()
        if content.startswith("def "):
            inside_function = True
        elif inside_function and content:
            if raw_line.startswith("\t\t"):
                raw_line = raw_line[1:]
            elif raw_line.startswith(" " * 8):
                raw_line = raw_line[4:]
        result.append(raw_line)
    return "\n".join(result)


def clear_text(text):
    """
    Strip escape backslashes from ``text`` while keeping its layout.

    The two-character sequences ``\\n`` and ``\\t`` become real newline and
    tab characters; every other backslash is simply dropped.
    """
    escapes = (
        ("\\n", "TEMP_NEWLINE_PLACEHOLDER", "\n"),
        ("\\t", "TEMP_TAB_PLACEHOLDER", "\t"),
    )
    # Park the meaningful escapes behind placeholders so that the backslash
    # purge below cannot destroy them.
    for escaped, marker, _ in escapes:
        text = text.replace(escaped, marker)
    text = text.replace("\\", "")
    for _, marker, actual in escapes:
        text = text.replace(marker, actual)
    return text


def encode_text(text):
    """
    Replace real newlines and tabs with the two-character escapes
    ``\\n`` and ``\\t``; all other characters are left unchanged.
    """
    return text.translate({ord("\n"): "\\n", ord("\t"): "\\t"})


def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.

    After autopep8 runs, a light clean-up pass tightens spacing around
    parentheses: ``( x )`` -> ``(x)`` and ``name (`` -> ``name(``. Keywords
    such as ``if``, ``return``, ``in`` and ``not`` keep their space.

    Args:
        code: Python source text, possibly imperfect model output.

    Returns:
        The formatted source, or ``code`` unchanged if formatting raises.
    """
    import keyword

    def _attach_paren(match):
        # "foo (" -> "foo(", but keep "if (", "return (", "not (", ...
        word = match.group(1)
        return match.group(0) if keyword.iskeyword(word) else word + "("

    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                "aggressive": 2,
                "max_line_length": 88,
                "indent_size": 4,
            },
        )

        # Drop spaces just inside parentheses. Requiring a non-space character
        # on the far side leaves the indentation of a line that starts with
        # ")" untouched (a plain " )" replace shifted such lines one column).
        formatted_code = re.sub(r"\( +(?=\S)", "(", formatted_code)
        formatted_code = re.sub(r"(?<=\S) +\)", ")", formatted_code)

        # Only spaces/tabs may separate the name from "(": a newline there
        # means two separate statements, which must never be merged.
        formatted_code = re.sub(r"(\w+)[ \t]+\(", _attach_paren, formatted_code)

        # NOTE(review): these regexes do not skip string literals or comments,
        # so e.g. "see (below)" inside a string is rewritten as well.
        return formatted_code
    except Exception as e:
        print(f"Error formatting code: {str(e)}")
        return code


def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.

    Three heuristic repairs run in order:

    1. Missing colons: a block header (``if``/``elif``/``else``/``for``/
       ``while``/``def``/``class``) gets a ``:`` when its brackets balance on
       that line and the next code line is indented under it. One-line
       statements (``if x: y``), comprehension clauses and continuation lines
       are therefore left alone.
    2. Unterminated strings: a single- or double-quoted string still open at
       the end of a line is closed with its own quote character.
    3. Unclosed calls: only when some ``name(`` has no ``)`` anywhere after
       it, a line that leaves a bracket open and is not continued by the
       following lines gets a ``)``.

    Quote characters, brackets and ``#`` inside string literals or comments
    are ignored, so text such as ``"don't"`` or ``"foo(bar"`` is never
    "repaired". Added characters are placed before a trailing ``# comment``.

    Args:
        code: Python source text, typically decoded model output.

    Returns:
        The repaired source; lines that need no repair are unchanged.
    """
    lines = _add_missing_colons(code.split("\n"))
    lines = _close_open_strings(lines)
    return _close_unclosed_calls("\n".join(lines))


# Prefixes of block-opening statements that must end with a colon.
_HEADER_PREFIXES = ("if ", "elif ", "for ", "while ", "def ", "class ")

# ``name(`` with no ``)`` between it and the end of the text / line.
_UNCLOSED_CALL = re.compile(r"(\w+)\s*\([^)]*$")


def _is_block_header(stripped):
    """Return True if the stripped line starts with a block-opening keyword."""
    if stripped.startswith(_HEADER_PREFIXES):
        return True
    # ``else`` takes no condition, so match it as a whole word only; a plain
    # prefix test would also match identifiers such as ``elsewhere``.
    return stripped == "else" or stripped.startswith(("else ", "else:"))


def _scan_lines(lines):
    """Lex ``lines`` just enough to separate code from strings and comments.

    Returns one ``(in_string_at_start, depth, comment_at, open_quote)`` tuple
    per line: whether the line begins inside a triple-quoted string, the net
    ``()[]{}`` balance of its code, the index of the ``#`` starting a comment
    (or None), and the quote char of a single-line string still open at the
    end of the line (or None). Triple-quoted strings carry over to following
    lines; single-quoted ones do not.
    """
    results = []
    triple = None  # delimiter of a triple-quoted string that is still open
    for line in lines:
        in_string_at_start = triple is not None
        depth = 0
        comment_at = None
        quote = None
        i = 0
        while i < len(line):
            ch = line[i]
            if triple is not None:
                if line.startswith(triple, i):
                    triple = None
                    i += 3
                else:
                    i += 2 if ch == "\\" else 1
                continue
            if quote is not None:
                if ch == "\\":
                    i += 2
                    continue
                if ch == quote:
                    quote = None
                i += 1
                continue
            if ch == "#":
                comment_at = i
                break
            if ch in "'\"":
                if line.startswith(ch * 3, i):
                    triple = ch * 3
                    i += 3
                    continue
                quote = ch
            elif ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth -= 1
            i += 1
        results.append((in_string_at_start, depth, comment_at, quote))
    return results


def _indent_width(line):
    """Width of the leading whitespace of ``line``, with tabs expanded."""
    expanded = line.expandtabs()
    return len(expanded) - len(expanded.lstrip())


def _next_code_line_is_indented(lines, index):
    """True if the next non-blank, non-comment line is deeper than lines[index]."""
    base = _indent_width(lines[index])
    for following in lines[index + 1:]:
        stripped = following.strip()
        if stripped and not stripped.startswith("#"):
            return _indent_width(following) > base
    return False


def _append_to_code(line, text, comment_at):
    """Insert ``text`` right after the code on ``line``, before any comment."""
    if comment_at is None:
        return line.rstrip() + text
    code_part = line[:comment_at].rstrip()
    return code_part + text + line[len(code_part):]


def _add_missing_colons(lines):
    """Step 1: append ':' to complete block headers that lack one."""
    fixed = []
    for index, (line, info) in enumerate(zip(lines, _scan_lines(lines))):
        in_string, depth, comment_at, _ = info
        code_part = (line if comment_at is None else line[:comment_at]).rstrip()
        needs_colon = (
            not in_string
            and _is_block_header(line.strip())
            and not code_part.endswith((":", "\\"))
            # Unbalanced brackets mean the statement continues elsewhere or
            # this line closes an earlier one (e.g. ``else b)`` in a ternary).
            and depth == 0
            # A real header is always followed by a more-indented body.
            and _next_code_line_is_indented(lines, index)
        )
        fixed.append(_append_to_code(line, ":", comment_at) if needs_colon else line)
    return fixed


def _close_open_strings(lines):
    """Step 2: close single-line strings left open at the end of a line."""
    fixed = []
    for line, (_, _, _, open_quote) in zip(lines, _scan_lines(lines)):
        # A trailing backslash continues the string on the next line.
        if open_quote is not None and not line.endswith("\\"):
            line = line.rstrip() + open_quote
        fixed.append(line)
    return fixed


def _close_unclosed_calls(code):
    """Step 3: close a ``name(`` left open on a line that is not continued."""
    if not _UNCLOSED_CALL.search(code):
        return code
    lines = code.split("\n")
    for index, (line, info) in enumerate(zip(lines, _scan_lines(lines))):
        _, depth, comment_at, _ = info
        # The open bracket must be real code, not text in a string/comment.
        if depth <= 0 or not _UNCLOSED_CALL.search(line):
            continue
        closer_follows = any(
            lines[j].strip().startswith(")")
            for j in range(index + 1, min(index + 3, len(lines)))
        )
        # Arguments continued on deeper-indented lines are already fine.
        if closer_follows or _next_code_line_is_indented(lines, index):
            continue
        lines[index] = _append_to_code(line, ")", comment_at)
    return "\n".join(lines)


def load_example_from_file(example_path):
    """
    Load example from a file with format:
    description_BREAK_code
    where 'code' uses \\n and \\t for formatting.

    Args:
        example_path: Path to an example ``raw.in`` file.

    Returns:
        ``(description, code)``: the stripped description, and the code with
        its ``\\n``/``\\t`` escapes turned into real newlines/tabs and then
        passed through ``normalize_indentation``. Returns ``("", "")`` when
        the file cannot be read or does not split into exactly two parts on
        ``_BREAK_``.
    """
    try:
        # Explicit encoding: the default one depends on the platform locale,
        # so non-ASCII example text could otherwise fail to decode.
        with open(example_path, "r", encoding="utf-8") as f:
            content = f.read()

        parts = content.split("_BREAK_")
        if len(parts) == 2:
            description = parts[0].strip()
            code = parts[1].strip()

            code = code.replace("\\n", "\n").replace("\\t", "\t")
            code = normalize_indentation(code)

            return description, code
        else:
            print(f"Invalid format in example file: {example_path}")
            return "", ""
    except Exception as e:
        # Best effort: a broken example must not take the UI down.
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""


def find_example_files(pattern="examples/*/raw.in"):
    """
    Find all raw.in example files in the examples directory.

    Args:
        pattern: Glob pattern, resolved relative to the current working
            directory. Defaults to ``examples/*/raw.in``.

    Returns:
        The matching paths, sorted. ``glob`` returns files in arbitrary,
        filesystem-dependent order; sorting keeps the numbering of the
        example buttons stable across runs and machines.
    """
    return sorted(glob.glob(pattern))


# ==========================
#  Load model from HF Hub
# ==========================

# Base checkpoint: provides the tokenizer and the model architecture.
BASE_MODEL_ID = "Salesforce/codet5p-770m"
# Hub repo and file holding the fine-tuned weights loaded on top of the base model.
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"  
FINETUNED_FILENAME = "pytorch_model.bin"

# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

print(f"Loading base model: {BASE_MODEL_ID}")
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
model.to(device)

print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)

print(f"Loading state_dict from: {ckpt_path}")
# Tensors are materialized on the CPU first; load_state_dict below copies them
# into the parameters of ``model``, which was already moved to ``device``.
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint is
# a plain state_dict, passing weights_only=True restricts what can be loaded;
# confirm the checkpoint format and the installed torch default first.
state_dict = torch.load(ckpt_path, map_location="cpu")

# Some checkpoints wrap the weights under a "model_state_dict" key.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

# strict=False: mismatched keys do not raise; only their counts are printed.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded fine-tuned weights. Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

# Switch to evaluation mode for inference.
model.eval()




# ==========================
#  Gradio logic
# ==========================

# State variables
# Round state read and written by generate_bugged_code and reset_interface:
# the most recent bugged code (None until a generation has completed) and the
# number of generations made in the current round.
current_code, bug_counter = None, 0


def generate_bugged_code(description, code, chat_history, is_first_time):
    """
    Ask the model to inject the described bug and append the exchange to the chat.

    The first generation of a round mutates the user's original ``code``;
    each later generation mutates the model's previous output
    (``current_code``), so bugs accumulate until the round is reset.

    Args:
        description: Natural-language description of the bug to introduce.
        code: Original Python source from the code editor; only used for the
            first generation of a round.
        chat_history: Current chatbot messages (role/content dicts) or None.
        is_first_time: True when a new round starts; resets the round state
            and the chat.

    Returns:
        ``(chat_history, description_update, is_first_time)``: the updated
        messages, an update for the description box (cleared after a
        generation, left untouched when there was no code to process), and
        the new value of the per-session ``is_first`` flag.

    Raises:
        Exception: whatever ``model.generate`` raises is logged and re-raised.
    """
    # NOTE(review): current_code and bug_counter are module globals, so every
    # session served by this process shares one round; concurrent users can
    # overwrite each other's state.
    global current_code, bug_counter

    if chat_history is None:
        chat_history = []

    if is_first_time:
        bug_counter = 0
        current_code = None
        chat_history = []

    # The first generation of a round needs source code to mutate. Without it,
    # keep the user's description and the round state instead of prompting the
    # model with empty code (or crashing in encode_text on None).
    if bug_counter == 0 and not (code or "").strip():
        return chat_history, gr.update(), is_first_time

    bug_counter += 1

    if bug_counter == 1:
        input_for_model = code
        input_type = "original"
    else:
        if current_code is None:
            return chat_history, gr.update(value=""), False
        input_for_model = current_code
        input_type = "previous bugged code"

    print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")

    # Newlines/tabs are sent to the model as literal "\n"/"\t" escapes.
    encoded_code = encode_text(input_for_model)
    combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).input_ids.to(device)

    try:
        print("Starting generation...")
        with torch.no_grad():
            # Greedy decoding (one beam, no sampling). ``early_stopping`` is not
            # passed: it only applies to beam search.
            outputs = model.generate(
                inputs,
                max_new_tokens=256,
                num_beams=1,
                do_sample=False,
            )
        print("Generation done.")
    except Exception as e:
        print("Generation error:", repr(e))
        raise

    bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Undo the escaping, patch obvious syntax slips, then pretty-print.
    bugged_code = clear_text(bugged_code_escaped)
    bugged_code = fix_common_syntax_issues(bugged_code)
    bugged_code = format_code(bugged_code)

    current_code = bugged_code

    user_message = f"**Description**: {description}"
    if input_type == "original":
        user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
    else:
        user_message += (
            f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
        )

    ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"

    chat_history = chat_history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ai_message},
    ]

    return chat_history, gr.update(value=""), False




def reset_interface():
    """
    Start a new round: zero the generation counter and forget the last output.

    Returns:
        ``([], description_update, True)``: an empty chat, an update that
        clears the description box, and ``is_first`` set back to True.
    """
    global current_code, bug_counter
    bug_counter = 0
    current_code = None
    cleared_description = gr.update(value="")
    return [], cleared_description, True


example_files = find_example_files()
# One button label per example: 1-based position plus the example's folder name.
example_names = [
    f"Example {number}: {os.path.basename(os.path.dirname(path))}"
    for number, path in enumerate(example_files, start=1)
]


def load_example(example_index):
    """
    Return ``(description, code)`` for the example at ``example_index``.

    An index at or past the end of ``example_files`` yields ``("", "")``.
    """
    if example_index >= len(example_files):
        return "", ""
    return load_example_from_file(example_files[example_index])


# Build the UI: inputs and example buttons on the left, the chat transcript on
# the right; then wire the buttons to the handlers defined above.
with gr.Blocks(title="Software-Fault Injection from NL") as demo:
    gr.Markdown("# 🐞 Software-Fault Injection from Natural Language")
    gr.Markdown(
        "Generate Python code with specific bugs based on a description and original code. "
        "The model used is **BugGen (CodeT5+ 770M, PyResBugs)**."
    )

    with gr.Row():
        with gr.Column(scale=2):
            description_input = gr.Textbox(
                label="Bug Description",
                placeholder="Describe the type of bug to introduce...",
                lines=3,
            )
            code_input = gr.Code(
                label="Original Code",
                language="python",
                lines=12,
            )

            # Per-session flag: generate_bugged_code sets it to False after a
            # generation, reset_interface sets it back to True.
            is_first = gr.State(True)

            submit_btn = gr.Button("Generate Bugged Code")
            reset_btn = gr.Button("Start Over")

            gr.Markdown("### Examples")
            example_buttons = [gr.Button(name) for name in example_names]

        with gr.Column(scale=3):
            # NOTE(review): generate_bugged_code returns {"role", "content"}
            # dicts; depending on the installed Gradio version this Chatbot may
            # need type="messages" to accept them — confirm.
            chat_output = gr.Chatbot(
                label="Conversation",
                height=500,
            )

    for i, btn in enumerate(example_buttons):
        # ``i=i`` binds the current index; a plain closure would make every
        # button load the last example.
        btn.click(
            fn=lambda i=i: load_example(i),
            outputs=[description_input, code_input],
        )

    # The chat itself is passed back in so new messages are appended to it.
    submit_btn.click(
        fn=generate_bugged_code,
        inputs=[description_input, code_input, chat_output, is_first],
        outputs=[chat_output, description_input, is_first],
    )

    reset_btn.click(
        fn=reset_interface,
        outputs=[chat_output, description_input, is_first],
    )

print("Launching Gradio interface...")
# Enable request queueing with at most 10 pending events, then start the server.
demo.queue(max_size=10).launch()