Update README.md
Browse files
README.md
CHANGED
|
@@ -22,11 +22,15 @@ To load and use this model with Hugging Face Transformers:
|
|
| 22 |
import torch
|
| 23 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 24 |
|
| 25 |
-
model_name = "maomaocun/LLaDA-Prometheus
|
| 26 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 27 |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 32 |
input_ids = inputs['input_ids']
|
|
@@ -38,10 +42,10 @@ for chunk in model.generate(
|
|
| 38 |
block_length=64,
|
| 39 |
threshold=0.9,
|
| 40 |
streaming=True,
|
| 41 |
-
eos_token_id=
|
| 42 |
):
|
| 43 |
all_generated_ids = torch.cat([input_ids, chunk], dim=-1)
|
| 44 |
-
text = tokenizer.batch_decode(all_generated_ids, skip_special_tokens=
|
| 45 |
print(text, end='', flush=True)
|
| 46 |
```
|
| 47 |
|
|
|
|
| 22 |
import torch
|
| 23 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 24 |
|
| 25 |
+
model_name = "maomaocun/LLaDA-Prometheus"
|
| 26 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 27 |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
|
| 28 |
|
| 29 |
+
# Build the prompt with the model's chat template
|
| 30 |
+
messages = [
|
| 31 |
+
{"role": "user", "content": "Can you tell me an engaging short story about a brave young astronaut who discovers an ancient alien civilization on a distant planet? Make it adventurous and heartwarming, with a twist at the end."}
|
| 32 |
+
]
|
| 33 |
+
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 34 |
|
| 35 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 36 |
input_ids = inputs['input_ids']
|
|
|
|
| 42 |
block_length=64,
|
| 43 |
threshold=0.9,
|
| 44 |
streaming=True,
|
| 45 |
+
eos_token_id=126348  # NOTE: the keyword here is eos_token_id
|
| 46 |
):
|
| 47 |
all_generated_ids = torch.cat([input_ids, chunk], dim=-1)
|
| 48 |
+
text = tokenizer.batch_decode(all_generated_ids, skip_special_tokens=True)[0]  # changed to skip_special_tokens=True and removed the manual split
|
| 49 |
print(text, end='', flush=True)
|
| 50 |
```
|
| 51 |
|