|
|
import torch
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model / tokenizer configuration ---
# These values must match the configuration the TorchScript model was
# traced with; a mismatch shows up as a shape error at inference time.

# Vocabulary size (50257 — presumably the GPT-2 BPE tokenizer; confirm
# against the training script).
VOCAB_SIZE = 50257

# Transformer embedding / hidden dimension.
MODEL_DIM = 768

# Number of transformer layers (also encoded in the model filename: L8).
NUM_LAYERS = 8

# Attention heads per layer (filename: H8).
NUM_HEADS = 8

# Sequence length used for the test input.
TRAIN_SEQ_LEN = 256

# Per-head dimension: 768 // 8 = 96.
HEAD_DIM = MODEL_DIM // NUM_HEADS

# Location of the serialized TorchScript (torch.jit.save) model.
JIT_SAVE_PATH = Path("models/gpt_pytorch_L8_H8_base.script.pt")
|
|
|
|
|
|
|
|
|
# Select the compute device.
#
# FIX: the original had an `elif hasattr(torch, 'hip') and
# torch.hip.is_available()` branch, but PyTorch exposes no `torch.hip`
# module, so that branch was dead code. ROCm (AMD/HIP) builds of PyTorch
# report their GPUs through the CUDA API — torch.cuda.is_available() is
# True on both NVIDIA and ROCm builds, and both use the "cuda" device
# string — so a single check covers both backends.
# (torch.version.hip is the way to detect a ROCm build, if ever needed.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
|
def test_jit_model():
    """Load the TorchScript model from JIT_SAVE_PATH and smoke-test it.

    Runs a single no-grad forward pass on random token ids and verifies
    the logits have shape (1, TRAIN_SEQ_LEN, VOCAB_SIZE). All failures
    are reported to stdout; the function never raises.
    """
    # Guard clause: nothing to test if the artifact is missing.
    if not JIT_SAVE_PATH.exists():
        print(f"🚨 Ошибка: Файл JIT-модели не найден по пути: {JIT_SAVE_PATH}")
        return

    print("--- Тестирование TorchScript (JIT) модели ---")
    print(f"Загрузка модели с {JIT_SAVE_PATH} на {device}...")

    try:
        # map_location relocates all weights to the selected device at load.
        loaded_jit_model = torch.jit.load(str(JIT_SAVE_PATH), map_location=device)
        loaded_jit_model.eval()

        # Random token ids in [0, VOCAB_SIZE), shaped (batch=1, seq=TRAIN_SEQ_LEN).
        test_input = torch.randint(0, VOCAB_SIZE, (1, TRAIN_SEQ_LEN), device=device, dtype=torch.long)

        with torch.no_grad():
            jit_logits = loaded_jit_model(test_input)

        expected_shape = torch.Size([1, TRAIN_SEQ_LEN, VOCAB_SIZE])
        # FIX: explicit check instead of `assert` — asserts are stripped
        # under `python -O`, which would silently disable this validation.
        # The raise is caught by the boundary handler below, matching the
        # original print-and-return behavior on shape mismatch.
        if jit_logits.shape != expected_shape:
            raise RuntimeError(
                f"Неверная форма вывода. Ожидалось: {expected_shape}, "
                f"Получено: {jit_logits.shape}"
            )

        print("\n✅ Тест успешно пройден!")
        print("Модель JIT загружена и работает корректно.")
        print(f"Форма логитов: {jit_logits.shape}")
        print(f"Устройство: {jit_logits.device}")

    except Exception as e:
        # Script-level boundary: report the failure instead of crashing.
        print(f"\n🚨 Критическая ошибка при тестировании JIT-модели: {e}")
        print("Проверьте, что конфигурация (VOCAB_SIZE, MODEL_DIM, T) соответствует трассировке.")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test_jit_model()