Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import clueai
|
|
| 4 |
import torch
|
| 5 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 6 |
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v2")
|
| 7 |
-
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2")
|
| 8 |
# 使用
|
| 9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 10 |
model.to(device)
|
|
@@ -25,11 +25,11 @@ def answer(text, sample=True, top_p=0.9, temperature=0.7):
|
|
| 25 |
top_p:0-1之间,生成的内容越多样'''
|
| 26 |
text = preprocess(text)
|
| 27 |
encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
out=model.generate(**encoding, **generate_config)
|
| 33 |
out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
|
| 34 |
return postprocess(out_text[0])
|
| 35 |
|
|
|
|
# Model bootstrap for the ChatYuan-large-v2 chat demo.
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Hub repository holding both the tokenizer files and the seq2seq weights.
_MODEL_NAME = "ClueAI/ChatYuan-large-v2"

tokenizer = T5Tokenizer.from_pretrained(_MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(_MODEL_NAME)

# Prefer a GPU when one is available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
|
|
|
|
def answer(text, sample=True, top_p=0.9, temperature=0.7):
    """Generate a chat reply with ChatYuan-large-v2.

    Args:
        text: Raw user prompt; normalized by ``preprocess`` before encoding.
        sample: When True, decode with nucleus sampling; when False, use
            deterministic single-beam decoding with a length penalty.
        top_p: Nucleus-sampling cutoff in (0, 1]; larger values make the
            output more diverse. Ignored when ``sample`` is False.
        temperature: Sampling temperature. Ignored when ``sample`` is False.

    Returns:
        The decoded, post-processed reply string.
    """
    text = preprocess(text)
    encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
    # Both decoding modes share these arguments; only the search strategy differs.
    common = dict(return_dict_in_generate=True, output_scores=False, max_new_tokens=1024)
    if not sample:
        # Deterministic decoding: single beam with a mild length penalty.
        out = model.generate(**encoding, **common, num_beams=1, length_penalty=0.6)
    else:
        # Stochastic decoding: nucleus sampling with 3-gram repetition blocking.
        out = model.generate(**encoding, **common, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return postprocess(out_text[0])
|
| 35 |
|