G-Transformer / test_inference.py
Syamsuddin's picture
Upload 14 files
a0d6ae6 verified
raw
history blame contribute delete
591 Bytes
import torch
from src.model import GTransformerConfig, GTransformerForCausalLM
from safetensors.torch import load_file
# Load konfigurasi dan model
config = GTransformerConfig()
model = GTransformerForCausalLM(config)
model.load_state_dict(load_file("pytorch_model.safetensors"))
model.eval()
# Token input contoh
input_ids = torch.tensor([[1, 11, 12, 2]]) # <s> information energy </s>
# Inferensi sederhana
with torch.no_grad():
outputs = model(input_ids)
print("✅ Output logits shape:", outputs.logits.shape)
print("Token terakhir prediksi:", outputs.logits[0, -1, :5])