piliguori commited on
Commit
9540433
·
verified ·
1 Parent(s): fbf65ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -252,14 +252,27 @@ def generate_bugged_code(description, code, chat_history, is_first_time):
252
  encoded_code = encode_text(input_for_model)
253
  combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
254
 
255
- inputs = tokenizer.encode(combined_input, return_tensors="pt")
 
 
 
 
 
256
 
257
- outputs = model.generate(
258
- inputs,
259
- max_length=1024,
260
- num_beams=5,
261
- early_stopping=True,
262
- )
 
 
 
 
 
 
 
 
263
 
264
  bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
265
 
@@ -287,6 +300,7 @@ def generate_bugged_code(description, code, chat_history, is_first_time):
287
  return chat_history, gr.update(value=""), False
288
 
289
 
 
290
  def reset_interface():
291
  global current_code, bug_counter
292
  current_code = None
@@ -359,4 +373,4 @@ with gr.Blocks(title="Software-Fault Injection from NL") as demo:
359
  )
360
 
361
  print("Launching Gradio interface...")
362
- demo.launch()
 
252
  encoded_code = encode_text(input_for_model)
253
  combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
254
 
255
+ inputs = tokenizer(
256
+ combined_input,
257
+ return_tensors="pt",
258
+ truncation=True,
259
+ max_length=512,
260
+ ).input_ids.to(device)
261
 
262
+ try:
263
+ print("Starting generation...")
264
+ with torch.no_grad():
265
+ outputs = model.generate(
266
+ inputs,
267
+ max_new_tokens=256,
268
+ num_beams=1,
269
+ do_sample=False,
270
+ early_stopping=True,
271
+ )
272
+ print("Generation done.")
273
+ except Exception as e:
274
+ print("Generation error:", repr(e))
275
+ raise e
276
 
277
  bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
278
 
 
300
  return chat_history, gr.update(value=""), False
301
 
302
 
303
+
304
  def reset_interface():
305
  global current_code, bug_counter
306
  current_code = None
 
373
  )
374
 
375
  print("Launching Gradio interface...")
376
+ demo.queue(max_size=10, concurrency_count=1).launch()