# Copyright (c) 2025 CMS Manhattan # All rights reserved. # Author: Konstantin Vladimirovich Grabko # Email: grabko@cmsmanhattan.com # Phone: +1(516)777-0945 # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, version 3 of the License. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Additional terms: # Any commercial use or distribution of this software or derivative works # requires explicit written permission from the copyright holder. import torch from pathlib import Path # ======================================== # Конфигурация (должна совпадать с моделью) # ======================================== VOCAB_SIZE = 50257 MODEL_DIM = 768 NUM_LAYERS = 8 NUM_HEADS = 8 TRAIN_SEQ_LEN = 256 HEAD_DIM = MODEL_DIM // NUM_HEADS # Путь к сохраненному файлу JIT_SAVE_PATH = Path("models/gpt_pytorch_L8_H8_base.script.pt") # Проверка устройства (для соответствия JIT-трассировке) if torch.cuda.is_available(): device = torch.device("cuda") elif hasattr(torch, 'hip') and torch.hip.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") def test_jit_model(): """Загружает и тестирует модель TorchScript.""" if not JIT_SAVE_PATH.exists(): print(f"🚨 Ошибка: Файл JIT-модели не найден по пути: {JIT_SAVE_PATH}") return print(f"--- Тестирование TorchScript (JIT) модели ---") print(f"Загрузка модели с {JIT_SAVE_PATH} на {device}...") try: # 1. Загрузка трассированной модели # torch.jit.load загружает модель и ее веса. loaded_jit_model = torch.jit.load(str(JIT_SAVE_PATH), map_location=device) loaded_jit_model.eval() # 2. Создание тестового ввода # Входные данные должны соответствовать конфигурации, использованной при трассировке (T=256, B=1). test_input = torch.randint(0, VOCAB_SIZE, (1, TRAIN_SEQ_LEN), device=device, dtype=torch.long) # 3. Выполнение инференса with torch.no_grad(): # Поскольку мы трассировали обертку NoCache, модель принимает только один вход (input_ids) jit_logits = loaded_jit_model(test_input) # 4. Проверки expected_shape = torch.Size([1, TRAIN_SEQ_LEN, VOCAB_SIZE]) assert jit_logits.shape == expected_shape, ( f"Неверная форма вывода. Ожидалось: {expected_shape}, " f"Получено: {jit_logits.shape}" ) print("\n✅ Тест успешно пройден!") print(f"Модель JIT загружена и работает корректно.") print(f"Форма логитов: {jit_logits.shape}") print(f"Устройство: {jit_logits.device}") except Exception as e: print(f"\n🚨 Критическая ошибка при тестировании JIT-модели: {e}") print("Проверьте, что конфигурация (VOCAB_SIZE, MODEL_DIM, T) соответствует трассировке.") if __name__ == "__main__": test_jit_model()