abstractmachine commited on
Commit
3ce3629
·
1 Parent(s): ae640a5

Changed dtype in torch to float32

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -16,16 +16,15 @@ print("tokenizer: " + tokenizer.name_or_path)
16
  pipeline = transformers.pipeline(
17
  "text-generation",
18
  model=model,
19
- torch_dtype=torch.float16,
 
20
  device_map="auto",
21
  )
22
 
23
  print("pipeline: " + pipeline.model.name_or_path)
24
 
25
  def generate(prompt):
26
-
27
  output = ""
28
-
29
  sequences = pipeline(
30
  prompt,
31
  do_sample=True,
@@ -35,11 +34,7 @@ def generate(prompt):
35
  eos_token_id=tokenizer.eos_token_id,
36
  max_length=1000,
37
  )
38
-
39
- # for seq in sequences:
40
- # output += seq['generated_text']
41
-
42
- return output
43
 
44
  iface = gr.Interface(fn=generate, inputs="text", outputs="text")
45
  iface.launch()
 
16
  pipeline = transformers.pipeline(
17
  "text-generation",
18
  model=model,
19
+ #torch_dtype=torch.float16,
20
+ torch_dtype=torch.float32,
21
  device_map="auto",
22
  )
23
 
24
  print("pipeline: " + pipeline.model.name_or_path)
25
 
26
  def generate(prompt):
 
27
  output = ""
 
28
  sequences = pipeline(
29
  prompt,
30
  do_sample=True,
 
34
  eos_token_id=tokenizer.eos_token_id,
35
  max_length=1000,
36
  )
37
+ return sequences[0]['generated_text']
 
 
 
 
38
 
39
  iface = gr.Interface(fn=generate, inputs="text", outputs="text")
40
  iface.launch()