JiRack_empty / TestLoadModel.py
kgrabko's picture
Upload TestLoadModel.py
c109c31 verified
import torch
from pathlib import Path
# ========================================
# Конфигурация (должна совпадать с моделью)
# ========================================
VOCAB_SIZE = 50257
MODEL_DIM = 768
NUM_LAYERS = 8
NUM_HEADS = 8
TRAIN_SEQ_LEN = 256
HEAD_DIM = MODEL_DIM // NUM_HEADS
# Путь к сохраненному файлу
JIT_SAVE_PATH = Path("models/gpt_pytorch_L8_H8_base.script.pt")
# Проверка устройства (для соответствия JIT-трассировке)
if torch.cuda.is_available():
device = torch.device("cuda")
elif hasattr(torch, 'hip') and torch.hip.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
def test_jit_model():
"""Загружает и тестирует модель TorchScript."""
if not JIT_SAVE_PATH.exists():
print(f"🚨 Ошибка: Файл JIT-модели не найден по пути: {JIT_SAVE_PATH}")
return
print(f"--- Тестирование TorchScript (JIT) модели ---")
print(f"Загрузка модели с {JIT_SAVE_PATH} на {device}...")
try:
# 1. Загрузка трассированной модели
# torch.jit.load загружает модель и ее веса.
loaded_jit_model = torch.jit.load(str(JIT_SAVE_PATH), map_location=device)
loaded_jit_model.eval()
# 2. Создание тестового ввода
# Входные данные должны соответствовать конфигурации, использованной при трассировке (T=256, B=1).
test_input = torch.randint(0, VOCAB_SIZE, (1, TRAIN_SEQ_LEN), device=device, dtype=torch.long)
# 3. Выполнение инференса
with torch.no_grad():
# Поскольку мы трассировали обертку NoCache, модель принимает только один вход (input_ids)
jit_logits = loaded_jit_model(test_input)
# 4. Проверки
expected_shape = torch.Size([1, TRAIN_SEQ_LEN, VOCAB_SIZE])
assert jit_logits.shape == expected_shape, (
f"Неверная форма вывода. Ожидалось: {expected_shape}, "
f"Получено: {jit_logits.shape}"
)
print("\n✅ Тест успешно пройден!")
print(f"Модель JIT загружена и работает корректно.")
print(f"Форма логитов: {jit_logits.shape}")
print(f"Устройство: {jit_logits.device}")
except Exception as e:
print(f"\n🚨 Критическая ошибка при тестировании JIT-модели: {e}")
print("Проверьте, что конфигурация (VOCAB_SIZE, MODEL_DIM, T) соответствует трассировке.")
if __name__ == "__main__":
test_jit_model()