Upload TestLoadModel.py
Browse files- TestLoadModel.py +69 -0
TestLoadModel.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
# ========================================
|
| 5 |
+
# Конфигурация (должна совпадать с моделью)
|
| 6 |
+
# ========================================
|
| 7 |
+
VOCAB_SIZE = 50257
|
| 8 |
+
MODEL_DIM = 768
|
| 9 |
+
NUM_LAYERS = 8
|
| 10 |
+
NUM_HEADS = 8
|
| 11 |
+
TRAIN_SEQ_LEN = 256
|
| 12 |
+
HEAD_DIM = MODEL_DIM // NUM_HEADS
|
| 13 |
+
|
| 14 |
+
# Путь к сохраненному файлу
|
| 15 |
+
JIT_SAVE_PATH = Path("models/gpt_pytorch_L8_H8_base.script.pt")
|
| 16 |
+
|
| 17 |
+
# Проверка устройства (для соответствия JIT-трассировке)
|
| 18 |
+
if torch.cuda.is_available():
|
| 19 |
+
device = torch.device("cuda")
|
| 20 |
+
elif hasattr(torch, 'hip') and torch.hip.is_available():
|
| 21 |
+
device = torch.device("cuda")
|
| 22 |
+
else:
|
| 23 |
+
device = torch.device("cpu")
|
| 24 |
+
|
| 25 |
+
def test_jit_model():
|
| 26 |
+
"""Загружает и тестирует модель TorchScript."""
|
| 27 |
+
|
| 28 |
+
if not JIT_SAVE_PATH.exists():
|
| 29 |
+
print(f"🚨 Ошибка: Файл JIT-модели не найден по пути: {JIT_SAVE_PATH}")
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
print(f"--- Тестирование TorchScript (JIT) модели ---")
|
| 33 |
+
print(f"Загрузка модели с {JIT_SAVE_PATH} на {device}...")
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
# 1. Загрузка трассированной модели
|
| 37 |
+
# torch.jit.load загружает модель и ее веса.
|
| 38 |
+
loaded_jit_model = torch.jit.load(str(JIT_SAVE_PATH), map_location=device)
|
| 39 |
+
loaded_jit_model.eval()
|
| 40 |
+
|
| 41 |
+
# 2. Создание тестового ввода
|
| 42 |
+
# Входные данные должны соответствовать конфигурации, использованной при трассировке (T=256, B=1).
|
| 43 |
+
test_input = torch.randint(0, VOCAB_SIZE, (1, TRAIN_SEQ_LEN), device=device, dtype=torch.long)
|
| 44 |
+
|
| 45 |
+
# 3. Выполнение инференса
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
# Поскольку мы трассировали обертку NoCache, модель принимает только один вход (input_ids)
|
| 48 |
+
jit_logits = loaded_jit_model(test_input)
|
| 49 |
+
|
| 50 |
+
# 4. Проверки
|
| 51 |
+
expected_shape = torch.Size([1, TRAIN_SEQ_LEN, VOCAB_SIZE])
|
| 52 |
+
|
| 53 |
+
assert jit_logits.shape == expected_shape, (
|
| 54 |
+
f"Неверная форма вывода. Ожидалось: {expected_shape}, "
|
| 55 |
+
f"Получено: {jit_logits.shape}"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
print("\n✅ Тест успешно пройден!")
|
| 59 |
+
print(f"Модель JIT загружена и работает корректно.")
|
| 60 |
+
print(f"Форма логитов: {jit_logits.shape}")
|
| 61 |
+
print(f"Устройство: {jit_logits.device}")
|
| 62 |
+
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"\n🚨 Критическая ошибка при тестировании JIT-модели: {e}")
|
| 65 |
+
print("Проверьте, что конфигурация (VOCAB_SIZE, MODEL_DIM, T) соответствует трассировке.")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
test_jit_model()
|