File size: 3,106 Bytes
c109c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from pathlib import Path

# ========================================
# Конфигурация (должна совпадать с моделью)
# ========================================
VOCAB_SIZE = 50257
MODEL_DIM = 768
NUM_LAYERS = 8     
NUM_HEADS = 8
TRAIN_SEQ_LEN = 256
HEAD_DIM = MODEL_DIM // NUM_HEADS

# Путь к сохраненному файлу
JIT_SAVE_PATH = Path("models/gpt_pytorch_L8_H8_base.script.pt")

# Проверка устройства (для соответствия JIT-трассировке)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, 'hip') and torch.hip.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

def test_jit_model():
    """Загружает и тестирует модель TorchScript."""
    
    if not JIT_SAVE_PATH.exists():
        print(f"🚨 Ошибка: Файл JIT-модели не найден по пути: {JIT_SAVE_PATH}")
        return

    print(f"--- Тестирование TorchScript (JIT) модели ---")
    print(f"Загрузка модели с {JIT_SAVE_PATH} на {device}...")

    try:
        # 1. Загрузка трассированной модели
        # torch.jit.load загружает модель и ее веса.
        loaded_jit_model = torch.jit.load(str(JIT_SAVE_PATH), map_location=device)
        loaded_jit_model.eval()
        
        # 2. Создание тестового ввода
        # Входные данные должны соответствовать конфигурации, использованной при трассировке (T=256, B=1).
        test_input = torch.randint(0, VOCAB_SIZE, (1, TRAIN_SEQ_LEN), device=device, dtype=torch.long)
        
        # 3. Выполнение инференса
        with torch.no_grad():
            # Поскольку мы трассировали обертку NoCache, модель принимает только один вход (input_ids)
            jit_logits = loaded_jit_model(test_input)
        
        # 4. Проверки
        expected_shape = torch.Size([1, TRAIN_SEQ_LEN, VOCAB_SIZE])

        assert jit_logits.shape == expected_shape, (
            f"Неверная форма вывода. Ожидалось: {expected_shape}, "
            f"Получено: {jit_logits.shape}"
        )

        print("\n✅ Тест успешно пройден!")
        print(f"Модель JIT загружена и работает корректно.")
        print(f"Форма логитов: {jit_logits.shape}")
        print(f"Устройство: {jit_logits.device}")

    except Exception as e:
        print(f"\n🚨 Критическая ошибка при тестировании JIT-модели: {e}")
        print("Проверьте, что конфигурация (VOCAB_SIZE, MODEL_DIM, T) соответствует трассировке.")


if __name__ == "__main__":
    test_jit_model()