piliguori commited on
Commit
ec5045e
·
verified ·
1 Parent(s): be20041

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -276
app.py CHANGED
@@ -1,183 +1,172 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import gradio as gr
3
  import torch
4
- import os
5
- import sys
6
- import argparse
7
  import autopep8
8
  import glob
9
  import re
 
 
 
 
 
10
 
11
  def normalize_indentation(code):
12
  """
13
  Normalize indentation in example code by removing excessive tabs.
14
  Also removes any backslash characters.
15
  """
16
- # Remove backslash characters
17
- code = code.replace('\\', '')
18
-
19
- lines = code.split('\n')
20
  if not lines:
21
  return ""
22
-
23
- # Check if we have a function with excessive indentation
24
  fixed_lines = []
25
  indent_fix_mode = False
26
-
27
  for i, line in enumerate(lines):
28
- if line.strip().startswith('def '):
29
  fixed_lines.append(line)
30
  indent_fix_mode = True
31
  elif indent_fix_mode and line.strip():
32
  # For indented lines in a function
33
- if line.startswith('\t\t'): # Two tabs
34
- fixed_lines.append('\t' + line[2:]) # Replace with one tab
35
- elif line.startswith(' '): # 8 spaces (2 levels)
36
- fixed_lines.append(' ' + line[8:]) # Replace with 4 spaces
37
  else:
38
  fixed_lines.append(line)
39
  else:
40
  fixed_lines.append(line)
41
-
42
- return '\n'.join(fixed_lines)
 
43
 
44
  def clear_text(text):
45
  """
46
  Cleans text from escape sequences while preserving original formatting.
47
  """
48
- # First, replace \\n with a special placeholder
49
  temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
50
  temp_tab = "TEMP_TAB_PLACEHOLDER"
51
-
52
- # Temporarily replace escape sequences
53
  text = text.replace("\\n", temp_newline)
54
  text = text.replace("\\t", temp_tab)
55
-
56
- # Remove remaining continuation backslashes
57
  text = text.replace("\\", "")
58
-
59
- # Convert placeholders back to actual control characters
60
  text = text.replace(temp_newline, "\n")
61
  text = text.replace(temp_tab, "\t")
62
-
63
  return text
64
 
 
65
  def encode_text(text):
66
  """
67
  Encodes control characters into escape sequences.
68
  """
69
- # Replace control characters with escape sequences
70
  text = text.replace("\n", "\\n")
71
  text = text.replace("\t", "\\t")
72
-
73
  return text
74
 
 
75
  def format_code(code):
76
  """
77
  Format Python code using autopep8 with aggressive settings.
78
  """
79
  try:
80
- # More aggressive formatting with specific options
81
  formatted_code = autopep8.fix_code(
82
- code,
83
  options={
84
- 'aggressive': 2,
85
- 'max_line_length': 88,
86
- 'indent_size': 4
87
- }
88
  )
89
-
90
  # Additional formatting for consistent spacing around parentheses and operators
91
- # Remove extra spaces inside parentheses
92
  formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")
93
-
94
- # Ensure consistent spacing around operators
95
  for op in ["+", "-", "*", "/", "=", "==", "!=", ">=", "<=", ">", "<"]:
96
  formatted_code = formatted_code.replace(f"{op} ", op + " ")
97
  formatted_code = formatted_code.replace(f" {op}", " " + op)
98
-
99
- # Remove extra spaces after function calls
100
- formatted_code = re.sub(r'(\w+)\s+\(', r'\1(', formatted_code)
101
-
102
  return formatted_code
103
  except Exception as e:
104
  print(f"Error formatting code: {str(e)}")
105
  return code
106
 
 
107
  def fix_common_syntax_issues(code):
108
  """
109
  Fix common syntax issues in generated code without modifying indentation.
110
  """
111
- # Only fix critical syntax issues while preserving indentation
112
- lines = code.split('\n')
113
  fixed_lines = []
114
-
115
  for line in lines:
116
  stripped = line.strip()
117
- # Check if line needs a colon at the end
118
- if (stripped.startswith('if ') or
119
- stripped.startswith('elif ') or
120
- stripped.startswith('else') or
121
- stripped.startswith('for ') or
122
- stripped.startswith('while ') or
123
- stripped.startswith('def ') or
124
- stripped.startswith('class ')):
125
-
126
- # Only add colon if it doesn't already end with one and doesn't continue to next line
127
- if not stripped.endswith(':') and not stripped.endswith('\\'):
128
- line = line.rstrip() + ':'
129
-
130
  fixed_lines.append(line)
131
-
132
- code = '\n'.join(fixed_lines)
133
-
134
  # Fix mismatched quotes
135
  quote_chars = ['"', "'"]
136
  for quote in quote_chars:
137
- # Count quotes in the code
138
  if code.count(quote) % 2 != 0:
139
- # Find incomplete string literals
140
- lines = code.split('\n')
141
  for i, line in enumerate(lines):
142
- # Count quotes in this line
143
  if line.count(quote) % 2 != 0:
144
- # Add missing quote at the end of the line
145
  lines[i] = line.rstrip() + quote
146
  break
147
- code = '\n'.join(lines)
148
-
149
  # Fix missing parentheses in function calls
150
- pattern = r'(\w+)\s*\([^)]*$' # Function call without closing parenthesis
151
  if re.search(pattern, code):
152
- # Add closing parenthesis at the end of the line
153
- lines = code.split('\n')
154
  for i, line in enumerate(lines):
155
- if re.search(pattern, line) and not any(lines[j].strip().startswith(')') for j in range(i+1, min(i+3, len(lines)))):
156
- lines[i] = line.rstrip() + ')'
157
- code = '\n'.join(lines)
158
-
 
 
 
159
  return code
160
 
 
161
  def load_example_from_file(example_path):
162
  """
163
- Load example from a file with format description_BREAK_code
 
 
164
  """
165
  try:
166
- with open(example_path, 'r') as f:
167
  content = f.read()
168
-
169
- # Split by the separator
170
  parts = content.split("_BREAK_")
171
  if len(parts) == 2:
172
  description = parts[0].strip()
173
  code = parts[1].strip()
174
-
175
- # Replace escape sequences with actual characters
176
- code = code.replace('\\n', '\n').replace('\\t', '\t')
177
-
178
- # Fix indentation
179
  code = normalize_indentation(code)
180
-
181
  return description, code
182
  else:
183
  print(f"Invalid format in example file: {example_path}")
@@ -186,214 +175,163 @@ def load_example_from_file(example_path):
186
  print(f"Error loading example file {example_path}: {str(e)}")
187
  return "", ""
188
 
 
189
  def find_example_files():
190
  """
191
- Find all raw.in example files in the examples directory
192
  """
193
  example_files = glob.glob("examples/*/raw.in")
194
  return example_files
195
 
196
- def main():
197
- # Configure argument parser
198
- parser = argparse.ArgumentParser(description='Launch a Gradio interface for a fine-tuned CodeT5+ model')
199
- parser.add_argument('model_bin_path', type=str, help='Path to the fine-tuned model .bin file')
200
- parser.add_argument('--base_model', type=str, default="Salesforce/codet5p-770m",
201
- help='Base model name (default: Salesforce/codet5p-770m)')
202
-
203
- # Parse arguments
204
- args = parser.parse_args()
205
-
206
- # Specify the base model you fine-tuned from
207
- base_model = args.base_model
208
-
209
- print(f"Using base model: {base_model}")
210
- print(f"Loading fine-tuned weights from: {args.model_bin_path}")
211
-
212
- # Load tokenizer from the base model
213
- print("Loading tokenizer...")
214
- tokenizer = AutoTokenizer.from_pretrained(base_model)
215
-
216
- # Load the base model
217
- print("Loading base model...")
218
- model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
219
-
220
- # Load weights from the .bin file
221
- if os.path.exists(args.model_bin_path):
222
- print("Loading fine-tuned weights...")
223
- try:
224
- # Load your fine-tuned model
225
- state_dict = torch.load(args.model_bin_path, map_location=torch.device('cpu'))
226
-
227
- # If the state_dict contains a 'model_state_dict' key (common in some checkpoints)
228
- if 'model_state_dict' in state_dict:
229
- state_dict = state_dict['model_state_dict']
230
-
231
- # Load weights into the model
232
- model.load_state_dict(state_dict, strict=False)
233
- print("Fine-tuned model loaded successfully!")
234
- except Exception as e:
235
- print(f"Error loading model: {str(e)}")
236
- print("Using base model without fine-tuning.")
237
  else:
238
- print(f"WARNING: Model file not found at {args.model_bin_path}. Using base model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
- # State variables to track conversation state
 
 
241
  current_code = None
242
- # Counter for tracking number of bug generation calls
243
  bug_counter = 0
 
244
 
245
- def generate_bugged_code(description, code, chat_history, is_first_time):
246
- nonlocal current_code, bug_counter
247
-
248
- # If we're starting over (is_first_time), reset the counter
249
- if is_first_time:
250
- bug_counter = 0
251
- current_code = None
252
-
253
- # Increment the counter for each call
254
- bug_counter += 1
255
-
256
- # Use counter to decide which code to use
257
- if bug_counter == 1:
258
- # First call or new code when restarting - use original code
259
- input_for_model = code
260
- input_type = "original"
261
- else:
262
- # Subsequent calls - use previously generated bugged code
263
- if current_code is None:
264
- return chat_history, gr.update(value=""), False
265
-
266
- input_for_model = current_code
267
- input_type = "previous bugged code"
268
-
269
- # Log for debugging
270
- print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")
271
-
272
- # Encode the input code
273
- encoded_code = encode_text(input_for_model)
274
- combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
275
-
276
- # Tokenize input
277
- inputs = tokenizer.encode(combined_input, return_tensors='pt')
278
-
279
- # Generate output
280
- outputs = model.generate(
281
- inputs,
282
- max_length=1024,
283
- num_beams=5,
284
- early_stopping=True
285
- )
286
-
287
- # Decode output
288
- bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
289
-
290
- # Clean code (replacing escape sequences with actual control characters)
291
- bugged_code = clear_text(bugged_code_escaped)
292
-
293
- # First fix common syntax issues without changing indentation
294
- bugged_code = fix_common_syntax_issues(bugged_code)
295
-
296
- # Apply autopep8 formatting
297
- bugged_code = format_code(bugged_code)
298
-
299
- # Update current code with the newly generated code
300
- current_code = bugged_code
301
-
302
- # Update chat history
303
- if is_first_time:
304
- chat_history = []
305
-
306
- user_message = f"**Description**: {description}"
307
- if input_type == "original":
308
- user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
309
- else:
310
- user_message += f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
311
-
312
- ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"
313
-
314
- chat_history.append((user_message, ai_message))
315
-
316
- return chat_history, gr.update(value=""), False
317
-
318
- # Function to handle interface reset
319
- def reset_interface():
320
- nonlocal current_code, bug_counter
321
- current_code = None
322
- bug_counter = 0
323
- return [], gr.update(value=""), True
324
 
325
- # Find example files
326
- example_files = find_example_files()
327
- example_names = [f"Example {i+1}: {os.path.basename(os.path.dirname(f))}" for i, f in enumerate(example_files)]
 
 
 
 
 
 
 
 
328
 
329
- # Function to load examples
330
- def load_example(example_index):
331
- if example_index < len(example_files):
332
- return load_example_from_file(example_files[example_index])
333
- return "", ""
334
 
335
- # Create Gradio interface using Blocks
336
- with gr.Blocks(title="Software-Fault Injection from NL") as demo:
337
- gr.Markdown("# Software-Fault Injection from Natural Language")
338
- gr.Markdown("Generate Python code with specific bugs based on a description and original code.")
339
-
340
- with gr.Row():
341
- with gr.Column(scale=2):
342
- # Input components
343
- description_input = gr.Textbox(
344
- label="Bug Description",
345
- placeholder="Describe the type of bug to introduce...",
346
- lines=3
347
- )
348
- code_input = gr.Code(
349
- label="Original Code",
350
- language="python",
351
- lines=10
352
- )
353
-
354
- # Flag to track if it's the first submission
355
- is_first = gr.State(True)
356
-
357
- # Submit button
358
- submit_btn = gr.Button("Generate Bugged Code")
359
-
360
- # Reset button
361
- reset_btn = gr.Button("Start Over")
362
-
363
- # Examples
364
- gr.Markdown("### Examples")
365
- example_buttons = [gr.Button(name) for name in example_names]
366
-
367
- with gr.Column(scale=3):
368
- # Chat history
369
- chat_output = gr.Chatbot(
370
- label="Conversation",
371
- height=500
372
- )
373
-
374
- # Connect example buttons
375
- for i, btn in enumerate(example_buttons):
376
- btn.click(
377
- fn=lambda i=i: load_example(i),
378
- outputs=[description_input, code_input]
379
  )
380
 
381
- # Connect submit button
382
- submit_btn.click(
383
- fn=generate_bugged_code,
384
- inputs=[description_input, code_input, chat_output, is_first],
385
- outputs=[chat_output, description_input, is_first]
386
- )
387
-
388
- # Connect reset button
389
- reset_btn.click(
390
- fn=reset_interface,
391
- outputs=[chat_output, description_input, is_first]
 
 
 
 
 
 
 
392
  )
393
 
394
- # Launch interface
395
- print("Launching Gradio interface...")
396
- demo.launch()
 
 
 
 
 
 
 
397
 
398
- if __name__ == "__main__":
399
- main()
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import gradio as gr
3
  import torch
 
 
 
4
  import autopep8
5
  import glob
6
  import re
7
+ import os
8
+
9
+ # ==========================
10
+ # Utility functions
11
+ # ==========================
12
 
13
  def normalize_indentation(code):
14
  """
15
  Normalize indentation in example code by removing excessive tabs.
16
  Also removes any backslash characters.
17
  """
18
+ code = code.replace("\\", "")
19
+
20
+ lines = code.split("\n")
 
21
  if not lines:
22
  return ""
23
+
 
24
  fixed_lines = []
25
  indent_fix_mode = False
26
+
27
  for i, line in enumerate(lines):
28
+ if line.strip().startswith("def "):
29
  fixed_lines.append(line)
30
  indent_fix_mode = True
31
  elif indent_fix_mode and line.strip():
32
  # For indented lines in a function
33
+ if line.startswith("\t\t"): # Two tabs
34
+ fixed_lines.append("\t" + line[2:]) # Replace with one tab
35
+ elif line.startswith(" "): # 8 spaces (2 levels)
36
+ fixed_lines.append(" " + line[8:]) # Replace with 4 spaces
37
  else:
38
  fixed_lines.append(line)
39
  else:
40
  fixed_lines.append(line)
41
+
42
+ return "\n".join(fixed_lines)
43
+
44
 
45
  def clear_text(text):
46
  """
47
  Cleans text from escape sequences while preserving original formatting.
48
  """
 
49
  temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
50
  temp_tab = "TEMP_TAB_PLACEHOLDER"
51
+
 
52
  text = text.replace("\\n", temp_newline)
53
  text = text.replace("\\t", temp_tab)
54
+
 
55
  text = text.replace("\\", "")
56
+
 
57
  text = text.replace(temp_newline, "\n")
58
  text = text.replace(temp_tab, "\t")
59
+
60
  return text
61
 
62
+
63
  def encode_text(text):
64
  """
65
  Encodes control characters into escape sequences.
66
  """
 
67
  text = text.replace("\n", "\\n")
68
  text = text.replace("\t", "\\t")
 
69
  return text
70
 
71
+
72
  def format_code(code):
73
  """
74
  Format Python code using autopep8 with aggressive settings.
75
  """
76
  try:
 
77
  formatted_code = autopep8.fix_code(
78
+ code,
79
  options={
80
+ "aggressive": 2,
81
+ "max_line_length": 88,
82
+ "indent_size": 4,
83
+ },
84
  )
85
+
86
  # Additional formatting for consistent spacing around parentheses and operators
 
87
  formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")
88
+
 
89
  for op in ["+", "-", "*", "/", "=", "==", "!=", ">=", "<=", ">", "<"]:
90
  formatted_code = formatted_code.replace(f"{op} ", op + " ")
91
  formatted_code = formatted_code.replace(f" {op}", " " + op)
92
+
93
+ formatted_code = re.sub(r"(\w+)\s+\(", r"\1(", formatted_code)
94
+
 
95
  return formatted_code
96
  except Exception as e:
97
  print(f"Error formatting code: {str(e)}")
98
  return code
99
 
100
+
101
  def fix_common_syntax_issues(code):
102
  """
103
  Fix common syntax issues in generated code without modifying indentation.
104
  """
105
+ lines = code.split("\n")
 
106
  fixed_lines = []
107
+
108
  for line in lines:
109
  stripped = line.strip()
110
+ if (
111
+ stripped.startswith("if ")
112
+ or stripped.startswith("elif ")
113
+ or stripped.startswith("else")
114
+ or stripped.startswith("for ")
115
+ or stripped.startswith("while ")
116
+ or stripped.startswith("def ")
117
+ or stripped.startswith("class ")
118
+ ):
119
+ if not stripped.endswith(":") and not stripped.endswith("\\"):
120
+ line = line.rstrip() + ":"
121
+
 
122
  fixed_lines.append(line)
123
+
124
+ code = "\n".join(fixed_lines)
125
+
126
  # Fix mismatched quotes
127
  quote_chars = ['"', "'"]
128
  for quote in quote_chars:
 
129
  if code.count(quote) % 2 != 0:
130
+ lines = code.split("\n")
 
131
  for i, line in enumerate(lines):
 
132
  if line.count(quote) % 2 != 0:
 
133
  lines[i] = line.rstrip() + quote
134
  break
135
+ code = "\n".join(lines)
136
+
137
  # Fix missing parentheses in function calls
138
+ pattern = r"(\w+)\s*\([^)]*$"
139
  if re.search(pattern, code):
140
+ lines = code.split("\n")
 
141
  for i, line in enumerate(lines):
142
+ if re.search(pattern, line) and not any(
143
+ lines[j].strip().startswith(")")
144
+ for j in range(i + 1, min(i + 3, len(lines)))
145
+ ):
146
+ lines[i] = line.rstrip() + ")"
147
+ code = "\n".join(lines)
148
+
149
  return code
150
 
151
+
152
  def load_example_from_file(example_path):
153
  """
154
+ Load example from a file with format:
155
+ description_BREAK_code
156
+ where 'code' uses \\n and \\t for formatting.
157
  """
158
  try:
159
+ with open(example_path, "r") as f:
160
  content = f.read()
161
+
 
162
  parts = content.split("_BREAK_")
163
  if len(parts) == 2:
164
  description = parts[0].strip()
165
  code = parts[1].strip()
166
+
167
+ code = code.replace("\\n", "\n").replace("\\t", "\t")
 
 
 
168
  code = normalize_indentation(code)
169
+
170
  return description, code
171
  else:
172
  print(f"Invalid format in example file: {example_path}")
 
175
  print(f"Error loading example file {example_path}: {str(e)}")
176
  return "", ""
177
 
178
+
179
  def find_example_files():
180
  """
181
+ Find all raw.in example files in the examples directory.
182
  """
183
  example_files = glob.glob("examples/*/raw.in")
184
  return example_files
185
 
186
+
187
+ # ==========================
188
+ # Load model from HF Hub
189
+ # ==========================
190
+
191
+ MODEL_ID = "OSS-Forge/buggen-codet5p-770m-pyresbugs" # <-- cambia se usi un nome diverso
192
+
193
+ print(f"Loading tokenizer and model from: {MODEL_ID}")
194
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
195
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
196
+
197
+
198
+ # ==========================
199
+ # Gradio logic
200
+ # ==========================
201
+
202
+ # State variables
203
+ current_code = None
204
+ bug_counter = 0
205
+
206
+
207
+ def generate_bugged_code(description, code, chat_history, is_first_time):
208
+ global current_code, bug_counter
209
+
210
+ if is_first_time:
211
+ bug_counter = 0
212
+ current_code = None
213
+
214
+ bug_counter += 1
215
+
216
+ if bug_counter == 1:
217
+ input_for_model = code
218
+ input_type = "original"
 
 
 
 
 
 
 
 
219
  else:
220
+ if current_code is None:
221
+ return chat_history, gr.update(value=""), False
222
+ input_for_model = current_code
223
+ input_type = "previous bugged code"
224
+
225
+ print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")
226
+
227
+ encoded_code = encode_text(input_for_model)
228
+ combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
229
+
230
+ inputs = tokenizer.encode(combined_input, return_tensors="pt")
231
+
232
+ outputs = model.generate(
233
+ inputs,
234
+ max_length=1024,
235
+ num_beams=5,
236
+ early_stopping=True,
237
+ )
238
+
239
+ bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
240
+
241
+ bugged_code = clear_text(bugged_code_escaped)
242
+ bugged_code = fix_common_syntax_issues(bugged_code)
243
+ bugged_code = format_code(bugged_code)
244
+
245
+ current_code = bugged_code
246
+
247
+ if is_first_time:
248
+ chat_history = []
249
+
250
+ user_message = f"**Description**: {description}"
251
+ if input_type == "original":
252
+ user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
253
+ else:
254
+ user_message += (
255
+ f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
256
+ )
257
+
258
+ ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"
259
+
260
+ chat_history.append((user_message, ai_message))
261
+
262
+ return chat_history, gr.update(value=""), False
263
 
264
+
265
+ def reset_interface():
266
+ global current_code, bug_counter
267
  current_code = None
 
268
  bug_counter = 0
269
+ return [], gr.update(value=""), True
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
+ example_files = find_example_files()
273
+ example_names = [
274
+ f"Example {i+1}: {os.path.basename(os.path.dirname(f))}"
275
+ for i, f in enumerate(example_files)
276
+ ]
277
+
278
+
279
+ def load_example(example_index):
280
+ if example_index < len(example_files):
281
+ return load_example_from_file(example_files[example_index])
282
+ return "", ""
283
 
 
 
 
 
 
284
 
285
+ with gr.Blocks(title="Software-Fault Injection from NL") as demo:
286
+ gr.Markdown("# 🐞 Software-Fault Injection from Natural Language")
287
+ gr.Markdown(
288
+ "Generate Python code with specific bugs based on a description and original code. "
289
+ "The model used is **BugGen (CodeT5+ 770M, PyResBugs)**."
290
+ )
291
+
292
+ with gr.Row():
293
+ with gr.Column(scale=2):
294
+ description_input = gr.Textbox(
295
+ label="Bug Description",
296
+ placeholder="Describe the type of bug to introduce...",
297
+ lines=3,
298
+ )
299
+ code_input = gr.Code(
300
+ label="Original Code",
301
+ language="python",
302
+ lines=12,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  )
304
 
305
+ is_first = gr.State(True)
306
+
307
+ submit_btn = gr.Button("Generate Bugged Code")
308
+ reset_btn = gr.Button("Start Over")
309
+
310
+ gr.Markdown("### Examples")
311
+ example_buttons = [gr.Button(name) for name in example_names]
312
+
313
+ with gr.Column(scale=3):
314
+ chat_output = gr.Chatbot(
315
+ label="Conversation",
316
+ height=500,
317
+ )
318
+
319
+ for i, btn in enumerate(example_buttons):
320
+ btn.click(
321
+ fn=lambda i=i: load_example(i),
322
+ outputs=[description_input, code_input],
323
  )
324
 
325
+ submit_btn.click(
326
+ fn=generate_bugged_code,
327
+ inputs=[description_input, code_input, chat_output, is_first],
328
+ outputs=[chat_output, description_input, is_first],
329
+ )
330
+
331
+ reset_btn.click(
332
+ fn=reset_interface,
333
+ outputs=[chat_output, description_input, is_first],
334
+ )
335
 
336
+ print("Launching Gradio interface...")
337
+ demo.launch()