"""
export.py
Exporting a trained MiniTransformer PyTorch model to TorchScript format for inference.
Steps:
1. Loading the MiniTransformer model architecture.
2. Loading trained weights from 'phish_model.pt'.
3. Converting the model to TorchScript using torch.jit.script.
4. Saving the TorchScript model as 'phish_model_ts.pt'.
5. Printing the file size of the exported model.
"""
import torch
from pathlib import Path
import sys
# Make <root>/src importable so the local 'model' package resolves below.
# NOTE(review): parents[2] assumes this file sits two directories below the
# project root (e.g. <root>/src/scripts/export.py) — confirm repo layout.
project_root = Path(__file__).resolve().parents[2] # go up to project root
sys.path.insert(0, str(project_root / "src"))
# This import must stay AFTER the sys.path tweak above, or it will fail.
from model.model import MiniTransformer
def export():
    """
    Export the trained MiniTransformer to TorchScript for inference.

    Requirements:
        - 'models/phish_model.pt' (trained weights) exists under the project root.

    Output:
        - Saves 'models/phish_model_ts.pt' under the project root.
        - Prints the exported model's file size in KB.

    Raises:
        FileNotFoundError: if the trained weights file is missing.
    """
    # Initialize the architecture; trained weights are loaded next.
    model = MiniTransformer()

    # Load weights onto CPU so the export works on machines without a GPU.
    model_path = project_root / "models" / "phish_model.pt"
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    model.load_state_dict(torch.load(model_path, map_location="cpu"))

    # Evaluation mode: freezes dropout/batch-norm behavior for inference.
    model.eval()

    # torch.jit.script compiles the module's Python code directly, so no
    # example inputs are needed (unlike torch.jit.trace).
    with torch.no_grad():
        scripted = torch.jit.script(model)

    # BUG FIX: previously saved to 'phish_model.pt', silently overwriting the
    # trained weights just loaded. Use the distinct TorchScript filename that
    # the module docstring promises.
    output_path = project_root / "models" / "phish_model_ts.pt"
    scripted.save(str(output_path))

    # Measure the file just written instead of re-serializing the whole
    # model through save_to_buffer().
    size_kb = output_path.stat().st_size / 1024
    print(f"Exported TorchScript model to '{output_path}' | size: {size_kb:.1f} KB")
# Script entry point (stray trailing '|' paste artifact removed).
if __name__ == "__main__":
    export()