import os
import json
import torch
# from pathlib import Path
from src.services.transformer import TinyTransformer
# Internal constants for file paths
# _MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "model")
# _VOCAB_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
_MODEL_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "model", "rellow-2.pt")
_VOCAB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "vocab.json")
# Internal device selection
_DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Save the model state and vocabulary to disk.
# def save_model(model, vocab):
#     # Create necessary directories
#     os.makedirs(_MODEL_DIR, exist_ok=True)
#     os.makedirs(_VOCAB_DIR, exist_ok=True)
#     # Save model state
#     torch.save(model.state_dict(), _MODEL_PATH)
#     # Save vocabulary
#     with open(_VOCAB_PATH, "w", encoding="utf-8") as f:
#         json.dump(vocab, f, ensure_ascii=False, indent=2)
#     print(f"Model saved to {_MODEL_PATH}")
#     print(f"Vocabulary saved to {_VOCAB_PATH}")
# Load the model and its vocabulary from disk.
def load_model():
    # Load vocabulary
    with open(_VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    inv_vocab = {int(v): k for k, v in vocab.items()}
    # Initialize and load model
    model = TinyTransformer(vocab_size=len(vocab)).to(_DEVICE)
    model.load_state_dict(torch.load(_MODEL_PATH, map_location=_DEVICE))
    model.eval()
    return model, vocab, inv_vocab
# Get the device being used for model operations.
def get_device():
    return _DEVICE
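# Minimal usage sketch (an illustrative addition, not part of the original module):
# loads the checkpoint and reports the active device and vocabulary size.
# It assumes the rellow-2.pt checkpoint and vocab.json referenced above exist on disk.
if __name__ == "__main__":
    model, vocab, inv_vocab = load_model()
    print(f"Loaded model on {get_device()} with a vocabulary of {len(vocab)} tokens")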