mario-sanz committed · verified
Commit 4afa9e9 · Parent: b3744c9

Update inference examples to use the correct chat template

Hey there! 👋

I noticed that the current Python examples for `transformers` and `vllm` aren't using the chat template. It seems like these examples might have been intended for the base model, but since this is the Think version, skipping the specific formatting causes the model to generate unexpected or low-quality outputs.

I’ve updated the code snippets to use `apply_chat_template` so the prompts are formatted exactly the way the model expects (handling the `<|im_start|>` and `<|think|>` tokens automatically). This should make the examples work much more smoothly for new users!
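For anyone curious what `apply_chat_template` actually does to the message list, here is a hand-rolled sketch of ChatML-style formatting. This is purely illustrative: the real template ships inside the tokenizer config, and `apply_chatml` is a hypothetical helper, not a `transformers` API.

```python
# Illustrative sketch of ChatML-style prompt formatting.
# Assumption: the model's real template lives in its tokenizer config;
# this hand-rolled version only shows the general shape.
def apply_chatml(messages, add_generation_prompt=True):
    text = ""
    for m in messages:
        text += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Leave the assistant turn open so the model continues from here.
        text += "<|im_start|>assistant\n"
    return text

prompt = apply_chatml([{"role": "user", "content": "Hi"}])
print(prompt)
```

Skipping this formatting means the model never sees the turn markers it was trained on, which is why the unformatted examples produce low-quality output.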

Thanks for releasing the model! 🚀

Files changed (1)
README.md +8 -8

README.md CHANGED
@@ -46,13 +46,13 @@ You can use OLMo with the standard HuggingFace transformers library:
 from transformers import AutoModelForCausalLM, AutoTokenizer
 olmo = AutoModelForCausalLM.from_pretrained("allenai/Olmo-3-7B-Think")
 tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Think")
-message = ["Who would win in a fight - a dinosaur or a cow named Moo Moo?"]
-inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
+message = [{"role": "user", "content": "Who would win in a fight - a dinosaur or a cow named Moo Moo?"}]
+inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors='pt', return_dict=True)
 # optional verifying cuda
 # inputs = {k: v.to('cuda') for k,v in inputs.items()}
 # olmo = olmo.to('cuda')
 response = olmo.generate(**inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
-print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
+print(tokenizer.decode(response[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
 >> '<think>Okay, so the question is who would win in a fight...'
 ```
@@ -182,8 +182,8 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto",
 )
 
-prompt = "Who would win in a fight - a dinosaur or a cow named MooMoo?"
-inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+message = [{"role": "user", "content": "Who would win in a fight - a dinosaur or a cow named Moo Moo?"}]
+inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors='pt', return_dict=True).to(model.device)
 
 outputs = model.generate(
     **inputs,
@@ -192,7 +192,7 @@ outputs = model.generate(
     max_new_tokens=32768,
 )
 
-print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
 ```
 
 ### vllm Example
@@ -208,8 +208,8 @@ sampling_params = SamplingParams(
     max_tokens=32768,
 )
 
-prompt = "Who would win in a fight - a dinosaur or a cow named MooMoo?"
-outputs = llm.generate(prompt, sampling_params)
+message = [{"role": "user", "content": "Who would win in a fight - a dinosaur or a cow named Moo Moo?"}]
+outputs = llm.chat(message, sampling_params)
 print(outputs[0].outputs[0].text)
 ```
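A side note on the `outputs[0][inputs.input_ids.shape[1]:]` slice in the updated snippets: `generate` returns the prompt tokens followed by the newly generated tokens, so slicing from the prompt length keeps only the model's reply. A minimal sketch of the same idea with plain lists and made-up token IDs:

```python
# Hypothetical token IDs, purely for illustration.
prompt_ids = [101, 2040, 2052]            # the formatted prompt fed to generate()
full_output = prompt_ids + [3000, 3001]   # generate() echoes the prompt, then appends new tokens
new_tokens = full_output[len(prompt_ids):]  # keep only the model's reply
print(new_tokens)
```

Without the slice, decoding would re-print the whole chat-templated prompt in front of the answer.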