kgrabko commited on
Commit
c109c31
·
verified ·
1 Parent(s): 224631b

Upload TestLoadModel.py

Browse files
Files changed (1) hide show
  1. TestLoadModel.py +69 -0
TestLoadModel.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+
4
+ # ========================================
5
+ # Конфигурация (должна совпадать с моделью)
6
+ # ========================================
7
+ VOCAB_SIZE = 50257
8
+ MODEL_DIM = 768
9
+ NUM_LAYERS = 8
10
+ NUM_HEADS = 8
11
+ TRAIN_SEQ_LEN = 256
12
+ HEAD_DIM = MODEL_DIM // NUM_HEADS
13
+
14
+ # Путь к сохраненному файлу
15
+ JIT_SAVE_PATH = Path("models/gpt_pytorch_L8_H8_base.script.pt")
16
+
17
+ # Проверка устройства (для соответствия JIT-трассировке)
18
+ if torch.cuda.is_available():
19
+ device = torch.device("cuda")
20
+ elif hasattr(torch, 'hip') and torch.hip.is_available():
21
+ device = torch.device("cuda")
22
+ else:
23
+ device = torch.device("cpu")
24
+
25
+ def test_jit_model():
26
+ """Загружает и тестирует модель TorchScript."""
27
+
28
+ if not JIT_SAVE_PATH.exists():
29
+ print(f"🚨 Ошибка: Файл JIT-модели не найден по пути: {JIT_SAVE_PATH}")
30
+ return
31
+
32
+ print(f"--- Тестирование TorchScript (JIT) модели ---")
33
+ print(f"Загрузка модели с {JIT_SAVE_PATH} на {device}...")
34
+
35
+ try:
36
+ # 1. Загрузка трассированной модели
37
+ # torch.jit.load загружает модель и ее веса.
38
+ loaded_jit_model = torch.jit.load(str(JIT_SAVE_PATH), map_location=device)
39
+ loaded_jit_model.eval()
40
+
41
+ # 2. Создание тестового ввода
42
+ # Входные данные должны соответствовать конфигурации, использованной при трассировке (T=256, B=1).
43
+ test_input = torch.randint(0, VOCAB_SIZE, (1, TRAIN_SEQ_LEN), device=device, dtype=torch.long)
44
+
45
+ # 3. Выполнение инференса
46
+ with torch.no_grad():
47
+ # Поскольку мы трассировали обертку NoCache, модель принимает только один вход (input_ids)
48
+ jit_logits = loaded_jit_model(test_input)
49
+
50
+ # 4. Проверки
51
+ expected_shape = torch.Size([1, TRAIN_SEQ_LEN, VOCAB_SIZE])
52
+
53
+ assert jit_logits.shape == expected_shape, (
54
+ f"Неверная форма вывода. Ожидалось: {expected_shape}, "
55
+ f"Получено: {jit_logits.shape}"
56
+ )
57
+
58
+ print("\n✅ Тест успешно пройден!")
59
+ print(f"Модель JIT загружена и работает корректно.")
60
+ print(f"Форма логитов: {jit_logits.shape}")
61
+ print(f"Устройство: {jit_logits.device}")
62
+
63
+ except Exception as e:
64
+ print(f"\n🚨 Критическая ошибка при тестировании JIT-модели: {e}")
65
+ print("Проверьте, что конфигурация (VOCAB_SIZE, MODEL_DIM, T) соответствует трассировке.")
66
+
67
+
68
+ if __name__ == "__main__":
69
+ test_jit_model()