Fix dtype processing in README example
#5
opened by
orrzohar
README.md
CHANGED
|
@@ -173,7 +173,7 @@ inputs = processor.apply_chat_template(
|
|
| 173 |
tokenize=True,
|
| 174 |
return_dict=True,
|
| 175 |
return_tensors="pt",
|
| 176 |
-
).to(model.device)
|
| 177 |
|
| 178 |
generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
|
| 179 |
generated_texts = processor.batch_decode(
|
|
|
|
| 173 |
tokenize=True,
|
| 174 |
return_dict=True,
|
| 175 |
return_tensors="pt",
|
| 176 |
+
).to(model.device, dtype=torch.bfloat16)
|
| 177 |
|
| 178 |
generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=64)
|
| 179 |
generated_texts = processor.batch_decode(
|