ivmpfa committed on
Commit
28ba83b
·
verified ·
1 Parent(s): 0c64484

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -18
app.py CHANGED
@@ -1,23 +1,20 @@
1
- import gradio as gr
2
  from transformers import pipeline
3
  import torch
4
 
5
- # Load GPT-2 with optimized parameters
6
  model = pipeline(
7
  "text-generation",
8
  model="gpt2",
9
  max_length=200,
10
  temperature=0.7,
11
  early_stopping=True,
12
- torch_dtype=torch.float16 # Use half-precision for speed
13
  )
14
 
15
  def generate_test_cases(requirement):
16
- # Simplified prompt with clear instructions and example
17
  prompt = f"""
18
  Generate test cases for '{requirement}' in JSON format. Output only the array.
19
-
20
- Example format:
21
  [
22
  {{
23
  "id": 1,
@@ -33,30 +30,20 @@ def generate_test_cases(requirement):
33
  }}
34
  ]
35
  """
36
-
37
  try:
38
- # Generate text with a timeout (avoids long inference)
39
  with torch.no_grad():
40
  result = model(prompt, max_time=10)[0]["generated_text"]
41
  return result.strip()
42
  except Exception as e:
43
  return f"Error: {str(e)}"
44
 
45
- # Create Gradio interface with proper flagging_mode
46
  demo = gr.Interface(
47
  fn=generate_test_cases,
48
  inputs="text",
49
  outputs="text",
50
  title="Test Case Generator",
51
  description="Enter a requirement to generate test cases.",
52
- flagging_mode="never" # Valid value
53
  )
54
 
55
- # Launch with proper settings
56
- if __name__ == "__main__":
57
- demo.launch(
58
- server_name="0.0.0.0",
59
- server_port=7860,
60
- debug=True,
61
- prevent_thread_lock=True
62
- )
 
 
1
  from transformers import pipeline
2
  import torch
3
 
4
+ # Load GPT-2 with float32 (remove torch_dtype=torch.float16)
5
  model = pipeline(
6
  "text-generation",
7
  model="gpt2",
8
  max_length=200,
9
  temperature=0.7,
10
  early_stopping=True,
11
+ torch_dtype=torch.float32 # << Use float32 instead of float16
12
  )
13
 
14
  def generate_test_cases(requirement):
 
15
  prompt = f"""
16
  Generate test cases for '{requirement}' in JSON format. Output only the array.
17
+ Example:
 
18
  [
19
  {{
20
  "id": 1,
 
30
  }}
31
  ]
32
  """
 
33
  try:
 
34
  with torch.no_grad():
35
  result = model(prompt, max_time=10)[0]["generated_text"]
36
  return result.strip()
37
  except Exception as e:
38
  return f"Error: {str(e)}"
39
 
 
40
  demo = gr.Interface(
41
  fn=generate_test_cases,
42
  inputs="text",
43
  outputs="text",
44
  title="Test Case Generator",
45
  description="Enter a requirement to generate test cases.",
46
+ flagging_mode="never"
47
  )
48
 
49
+ demo.launch(server_name="0.0.0.0", server_port=7860)