|
|
""" |
|
|
export.py |
|
|
|
|
|
Exporting a trained MiniTransformer PyTorch model to TorchScript format for inference. |
|
|
|
|
|
Steps: |
|
|
1. Loading the MiniTransformer model architecture. |
|
|
2. Loading trained weights from 'phish_model.pt'. |
|
|
3. Converting the model to TorchScript using torch.jit.script. |
|
|
4. Saving the TorchScript model as 'phish_model_ts.pt'. |
|
|
5. Printing the file size of the exported model. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from pathlib import Path |
|
|
import sys |
|
|
|
|
|
|
|
|
# Resolve the repository root: two directory levels above this file.
# NOTE(review): parents[2] implies this script lives two levels deep
# (e.g. <root>/scripts/export/export.py) — confirm against the repo layout.
project_root = Path(__file__).resolve().parents[2]

# Make <root>/src importable BEFORE the 'from model.model import ...' below;
# this line is order-sensitive and must stay ahead of that import.
sys.path.insert(0, str(project_root / "src"))
|
|
|
|
|
from model.model import MiniTransformer |
|
|
|
|
|
def export():
    """
    Export the trained MiniTransformer to TorchScript.

    Requirements:
        - 'phish_model.pt' (trained state_dict) must exist in
          <project_root>/models.

    Output:
        - Saves 'phish_model_ts.pt' in <project_root>/models.
        - Prints the path and file size of the exported model in KB.

    Raises:
        FileNotFoundError: if the trained weights file is missing.
    """
    model = MiniTransformer()

    model_path = project_root / "models" / "phish_model.pt"
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # map_location="cpu" keeps the export host-independent (no GPU needed).
    model.load_state_dict(torch.load(model_path, map_location="cpu"))

    # Switch to inference mode (disables dropout / batch-norm updates)
    # before scripting so the exported graph reflects eval behavior.
    model.eval()

    with torch.no_grad():
        scripted = torch.jit.script(model)

    # BUG FIX: the original wrote to 'phish_model.pt', silently clobbering
    # the trained weights it had just loaded. Per the module docstring
    # (step 4), the TorchScript artifact is 'phish_model_ts.pt'.
    output_path = project_root / "models" / "phish_model_ts.pt"
    scripted.save(str(output_path))

    # Report the serialized size without re-reading the file from disk.
    size_kb = len(scripted.save_to_buffer()) / 1024
    print(f"Exported TorchScript model to '{output_path}' | size: {size_kb:.1f} KB")
|
|
|
|
|
# Script entry point: run the export only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    export()