File size: 1,634 Bytes
e663138
 
 
 
738d2af
e663138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import json
import torch
# from pathlib import Path
from src.services.transformer import TinyTransformer

# Internal constants for file paths. Model weights and vocabulary live in
# sibling "model" and "data" directories under the parent of this module's
# directory.
_BASE_DIR = os.path.dirname(os.path.dirname(__file__))
_MODEL_DIR = os.path.join(_BASE_DIR, "model")
_VOCAB_DIR = os.path.join(_BASE_DIR, "data")
_MODEL_PATH = os.path.join(_MODEL_DIR, "rellow-2.pt")
_VOCAB_PATH = os.path.join(_VOCAB_DIR, "vocab.json")

# Internal device selection: use the MPS backend when available, else CPU.
_mps_ok = torch.backends.mps.is_available()
_DEVICE = torch.device("mps") if _mps_ok else torch.device("cpu")

# Save the model state and vocabulary to disk.
# def save_model(model, vocab):
#   # Create necessary directories
#   os.makedirs(_MODEL_DIR, exist_ok=True)
#   os.makedirs(_VOCAB_DIR, exist_ok=True)
  
#   # Save model state
#   torch.save(model.state_dict(), _MODEL_PATH)
  
#   # Save vocabulary
#   with open(_VOCAB_PATH, "w", encoding="utf-8") as f:
#     json.dump(vocab, f, ensure_ascii=False, indent=2)
  
#   print(f"Model saved to {_MODEL_PATH}")
#   print(f"Vocabulary saved to {_VOCAB_PATH}")

# Load the model and its vocabulary from disk.
def load_model():
  """Load the vocabulary and the trained TinyTransformer weights from disk.

  Reads the vocabulary JSON at ``_VOCAB_PATH`` and builds the reverse
  id -> token lookup. Constructs a ``TinyTransformer`` whose ``vocab_size``
  is the number of vocabulary entries, places it on ``_DEVICE``, loads the
  state dict stored at ``_MODEL_PATH`` and switches the model to eval mode.

  Returns:
    tuple: ``(model, vocab, inv_vocab)``, where ``vocab`` is the object
    exactly as parsed from JSON and ``inv_vocab`` maps ``int(value)`` to the
    corresponding key of ``vocab``.

  Raises:
    FileNotFoundError: if the vocabulary or weights file does not exist.
    RuntimeError: from ``load_state_dict`` if the saved keys or shapes do
      not match the freshly constructed model.
  """
  # Load vocabulary. Values go through int() because they may have been
  # serialized as strings -- NOTE(review): confirm against the saving code.
  with open(_VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = json.load(f)
  inv_vocab = {int(v): k for k, v in vocab.items()}

  # Initialize the model on the target device, then load its weights.
  model = TinyTransformer(vocab_size=len(vocab)).to(_DEVICE)
  # The checkpoint is a plain state_dict (tensors only), so restrict
  # torch.load to weights instead of allowing arbitrary unpickling. This
  # matches PyTorch's default from 2.6 onward and avoids the FutureWarning
  # of earlier releases. It requires torch >= 1.13.
  state_dict = torch.load(_MODEL_PATH, map_location=_DEVICE, weights_only=True)
  model.load_state_dict(state_dict)
  model.eval()

  return model, vocab, inv_vocab

# Get the device being used for model operations.
def get_device():
  """Return the module-wide ``torch.device`` chosen at import time.

  The same object is used by ``load_model`` when placing the model and
  mapping the checkpoint.
  """
  selected = _DEVICE
  return selected