# nanochat_d20base / base_nano.py
# Bajju360's picture
# Upload folder using huggingface_hub
# 3728678 verified
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
# Greedy-decoding demo: load the checkpoint stored in ``model_dir``, extend a
# fixed prompt by ``max_new_tokens`` tokens, and print the decoded text.
model_dir = "./d20/nanochat_d20base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: trust_remote_code=True lets transformers execute Python code shipped
# with the checkpoint; only point model_dir at a source you trust.
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
model = model.to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

prompt = "The capital of Belgium is "
# `encode(..., prepend=...)` and `get_bos_token_id()` come from the
# checkpoint's custom tokenizer; presumably this places the BOS id in front of
# the prompt tokens -- confirm against the tokenizer code in model_dir.
input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
# Wrapping the id list in another list adds a leading dimension of size 1.
ids = torch.tensor([input_ids], dtype=torch.long, device=device)

max_new_tokens = 50
with torch.inference_mode():
    for _ in range(max_new_tokens):
        # Every step feeds the whole sequence generated so far to the model.
        outputs = model(input_ids=ids)
        # Accept either a mapping with a "logits" key or an object exposing
        # a `.logits` attribute.
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
        # Greedy choice: highest-scoring entry at the last sequence position;
        # keepdim=True keeps a trailing dimension so it can be concatenated.
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        ids = torch.cat([ids, next_token], dim=1)

# Decode the first (only) row: prompt tokens followed by the generated ones.
decoded = tokenizer.decode(ids[0].tolist())
print(decoded)