File size: 591 Bytes
a0d6ae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from src.model import GTransformerConfig, GTransformerForCausalLM
from safetensors.torch import load_file

# Demo script: build a GTransformer causal-LM, load trained weights from a
# safetensors checkpoint, and run a single forward pass on a fixed input.

# Build the model from a default configuration.
config = GTransformerConfig()
model = GTransformerForCausalLM(config)
# The checkpoint path is relative, so it resolves against the current working
# directory; the loaded tensors are copied into the model's state.
model.load_state_dict(load_file("pytorch_model.safetensors"))
# Switch to evaluation mode before running inference.
model.eval()

# Example input: one sequence of four hard-coded token ids, shape (1, 4).
# NOTE(review): these ids presumably correspond to "<s> information energy </s>"
# in the model's vocabulary — confirm against the tokenizer.
input_ids = torch.tensor([1, 11, 12, 2]).unsqueeze(0)

# Simple inference with gradient tracking disabled.
with torch.no_grad():
    outputs = model(input_ids)
    # Assumes the forward pass returns an object exposing `.logits`, presumably
    # shaped (batch, seq_len, vocab_size) — TODO confirm.
    print(f"✅ Output logits shape: {outputs.logits.shape}")
    # Prints the first five raw logits at the final position. This is not a
    # decoded token: no softmax or argmax is applied.
    print(f"Token terakhir prediksi: {outputs.logits[0][-1][:5]}")