#!/usr/bin/env python3
"""
Example: generate text from QED-75M on Hugging Face.
Run:
python generate_gravity_example.py
"""
from __future__ import annotations
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def main() -> None:
    """Download QED-75M from the Hugging Face Hub and print one sampled completion.

    Loads the tokenizer and model (custom architecture, hence
    ``trust_remote_code=True``), moves the model to GPU when available, and
    samples up to 64 new tokens from a fixed prompt.
    """
    repo_id = "levossadtchi/QED-75M"
    prompt = "Explain gravity in one sentence. \n<|assistant|>"

    # trust_remote_code=True is required because QED is a custom architecture.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # add_special_tokens=False: the prompt already carries its own
    # chat-template marker (<|assistant|>), so don't prepend BOS etc.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            eos_token_id=tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token;
            # passing pad_token_id=None makes transformers warn and pick
            # a fallback implicitly — be explicit instead.
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        )

    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    print(text)


if __name__ == "__main__":
    main()