ivmpfa commited on
Commit
303f892
·
verified ·
1 Parent(s): 28ba83b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -1,20 +1,22 @@
 
1
  from transformers import pipeline
2
  import torch
3
 
4
- # Load GPT-2 with float32 (remove torch_dtype=torch.float16)
5
  model = pipeline(
6
  "text-generation",
7
  model="gpt2",
8
  max_length=200,
9
  temperature=0.7,
10
  early_stopping=True,
11
- torch_dtype=torch.float32 # << Use float32 instead of float16
12
  )
13
 
14
  def generate_test_cases(requirement):
15
  prompt = f"""
16
  Generate test cases for '{requirement}' in JSON format. Output only the array.
17
- Example:
 
18
  [
19
  {{
20
  "id": 1,
@@ -30,6 +32,7 @@ def generate_test_cases(requirement):
30
  }}
31
  ]
32
  """
 
33
  try:
34
  with torch.no_grad():
35
  result = model(prompt, max_time=10)[0]["generated_text"]
@@ -37,6 +40,7 @@ def generate_test_cases(requirement):
37
  except Exception as e:
38
  return f"Error: {str(e)}"
39
 
 
40
  demo = gr.Interface(
41
  fn=generate_test_cases,
42
  inputs="text",
@@ -46,4 +50,11 @@ demo = gr.Interface(
46
  flagging_mode="never"
47
  )
48
 
49
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  from transformers import pipeline
3
  import torch
4
 
5
+ # Load GPT-2 with optimized parameters
6
  model = pipeline(
7
  "text-generation",
8
  model="gpt2",
9
  max_length=200,
10
  temperature=0.7,
11
  early_stopping=True,
12
+ torch_dtype=torch.float32 # Use float32 for CPU compatibility
13
  )
14
 
15
  def generate_test_cases(requirement):
16
  prompt = f"""
17
  Generate test cases for '{requirement}' in JSON format. Output only the array.
18
+
19
+ Example format:
20
  [
21
  {{
22
  "id": 1,
 
32
  }}
33
  ]
34
  """
35
+
36
  try:
37
  with torch.no_grad():
38
  result = model(prompt, max_time=10)[0]["generated_text"]
 
40
  except Exception as e:
41
  return f"Error: {str(e)}"
42
 
43
+ # Create Gradio interface
44
  demo = gr.Interface(
45
  fn=generate_test_cases,
46
  inputs="text",
 
50
  flagging_mode="never"
51
  )
52
 
53
+ # Launch the app
54
+ if __name__ == "__main__":
55
+ demo.launch(
56
+ server_name="0.0.0.0",
57
+ server_port=7860,
58
+ debug=True,
59
+ prevent_thread_lock=True
60
+ )