from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch


# Load the checkpoint's custom model code (trust_remote_code) and move it to the GPU if one is available
model_dir = "./d20/nanochat_d20base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
model = model.to(device)
model.eval()


# The tokenizer ships with the checkpoint as well and is loaded the same way
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)


# Encode the prompt with the BOS token prepended, then move the ids to the model's device
prompt = "The capital of Belgium is "
input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
ids = torch.tensor([input_ids], dtype=torch.long, device=device)


# Greedy decoding: at each step, take the most likely next token and append it to the sequence
max_new_tokens = 50
with torch.inference_mode():
    for _ in range(max_new_tokens):
        outputs = model(input_ids=ids)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        ids = torch.cat([ids, next_token], dim=1)


# Decode the full sequence (prompt + generated tokens) back to text
decoded = tokenizer.decode(ids[0].tolist())
print(decoded)
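# A possible variation, not part of the original snippet: swap the greedy argmax above for
# temperature / top-k sampling. This is a minimal sketch that reuses `model`, `tokenizer`,
# `input_ids`, `device`, and `max_new_tokens` from above; the sampling parameters below are
# illustrative assumptions, not values taken from the source.
temperature = 0.8  # assumed value; higher means more random completions
top_k = 50         # assumed value; sample only from the 50 most likely tokens

# Restart from the original prompt rather than continuing the greedy sequence
ids = torch.tensor([input_ids], dtype=torch.long, device=device)
with torch.inference_mode():
    for _ in range(max_new_tokens):
        outputs = model(input_ids=ids)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
        # Scale the last-position logits by the temperature
        logits = logits[:, -1, :] / temperature
        # Keep only the top-k candidates, renormalize, and sample one token
        topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
        probs = torch.softmax(topk_vals, dim=-1)
        next_token = topk_idx.gather(-1, torch.multinomial(probs, num_samples=1))
        ids = torch.cat([ids, next_token], dim=1)

print(tokenizer.decode(ids[0].tolist()))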