maomaocun commited on
Commit
b1fecec
·
verified ·
1 Parent(s): 79c3655

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -4
README.md CHANGED
@@ -22,11 +22,15 @@ To load and use this model with Hugging Face Transformers:
22
  import torch
23
  from transformers import AutoTokenizer, AutoModelForCausalLM
24
 
25
- model_name = "maomaocun/LLaDA-Prometheus-no-template"
26
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
27
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
28
 
29
- prompt = "Can you tell me an engaging short story about a brave young astronaut who discovers an ancient alien civilization on a distant planet? Make it adventurous and heartwarming, with a twist at the end."
 
 
 
 
30
 
31
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
32
  input_ids = inputs['input_ids']
@@ -38,10 +42,10 @@ for chunk in model.generate(
38
  block_length=64,
39
  threshold=0.9,
40
  streaming=True,
41
- eos_token_id=tokenizer.eos_token,
42
  ):
43
  all_generated_ids = torch.cat([input_ids, chunk], dim=-1)
44
- text = tokenizer.batch_decode(all_generated_ids, skip_special_tokens=False)[0].split(tokenizer.eos_token)[0]
45
  print(text, end='', flush=True)
46
  ```
47
 
 
22
  import torch
23
  from transformers import AutoTokenizer, AutoModelForCausalLM
24
 
25
+ model_name = "maomaocun/LLaDA-Prometheus"
26
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
27
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
28
 
29
+ # Use the chat template to format the conversation
30
+ messages = [
31
+ {"role": "user", "content": "Can you tell me an engaging short story about a brave young astronaut who discovers an ancient alien civilization on a distant planet? Make it adventurous and heartwarming, with a twist at the end."}
32
+ ]
33
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
34
 
35
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
36
  input_ids = inputs['input_ids']
 
42
  block_length=64,
43
  threshold=0.9,
44
  streaming=True,
45
+ eos_token_id=126348,  # Note: pass the numeric eos_token_id here (not the token string)
46
  ):
47
  all_generated_ids = torch.cat([input_ids, chunk], dim=-1)
48
+ text = tokenizer.batch_decode(all_generated_ids, skip_special_tokens=True)[0]  # skip_special_tokens=True makes the manual split on the EOS token unnecessary
49
  print(text, end='', flush=True)
50
  ```
51