atiwari751 committed on
Commit
c4bbf27
·
1 Parent(s): e773303

fixed CUDA issue for HF spaces

Browse files
Files changed (2) hide show
  1. app.py +24 -20
  2. generate.py +1 -1
app.py CHANGED
@@ -1,26 +1,30 @@
1
  import gradio as gr
2
  from generate import generate_text
 
3
 
4
  def generate(prompt):
5
- # Redirect print output to capture the generated text
6
- import io
7
- import sys
8
- old_stdout = sys.stdout
9
- new_stdout = io.StringIO()
10
- sys.stdout = new_stdout
11
-
12
- # Generate text with default parameters
13
- generate_text(
14
- prompt=prompt,
15
- max_length=100, # default value
16
- num_sequences=5 # default value
17
- )
18
-
19
- # Get the output and restore stdout
20
- output = new_stdout.getvalue()
21
- sys.stdout = old_stdout
22
-
23
- return output
 
 
 
24
 
25
  # Create the Gradio interface
26
  demo = gr.Interface(
@@ -32,4 +36,4 @@ demo = gr.Interface(
32
  )
33
 
34
  if __name__ == "__main__":
35
- demo.launch(share=False) # set share=True if you want to generate a public URL
 
1
  import gradio as gr
2
  from generate import generate_text
3
+ import torch
4
 
5
  def generate(prompt):
6
+ try:
7
+ # Redirect print output to capture the generated text
8
+ import io
9
+ import sys
10
+ old_stdout = sys.stdout
11
+ new_stdout = io.StringIO()
12
+ sys.stdout = new_stdout
13
+
14
+ # Generate text with default parameters
15
+ generate_text(
16
+ prompt=prompt,
17
+ max_length=100, # default value
18
+ num_sequences=5 # default value
19
+ )
20
+
21
+ # Get the output and restore stdout
22
+ output = new_stdout.getvalue()
23
+ sys.stdout = old_stdout
24
+
25
+ return output
26
+ except Exception as e:
27
+ return f"An error occurred: {str(e)}"
28
 
29
  # Create the Gradio interface
30
  demo = gr.Interface(
 
36
  )
37
 
38
  if __name__ == "__main__":
39
+ demo.launch()
generate.py CHANGED
@@ -9,7 +9,7 @@ print(f"Using device: {device}")
9
 
10
  # Initialize model and load trained weights
11
  model = GPT(GPTConfig())
12
- model.load_state_dict(torch.load('best_model.pt'))
13
  model.to(device)
14
  model.eval() # Set to evaluation mode
15
 
 
9
 
10
  # Initialize model and load trained weights
11
  model = GPT(GPTConfig())
12
+ model.load_state_dict(torch.load('best_model.pt', map_location=torch.device(device)))
13
  model.to(device)
14
  model.eval() # Set to evaluation mode
15