# NOTE(review): non-code page residue ("Spaces: / Sleeping") removed from the top of this module.
import json
import os

import torch

# from pathlib import Path  # unused; kept from original
from src.services.transformer import TinyTransformer
| # Internal constants for file paths | |
| # _MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "model") | |
| # _VOCAB_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data") | |
| _MODEL_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "model", "rellow-2.pt") | |
| _VOCAB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "vocab.json") | |
| # Internal device selection | |
| _DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu") | |
# Save the model state and vocabulary to disk.
# def save_model(model, vocab):
#     # Create necessary directories
#     os.makedirs(_MODEL_DIR, exist_ok=True)
#     os.makedirs(_VOCAB_DIR, exist_ok=True)
#     # Save model state
#     torch.save(model.state_dict(), _MODEL_PATH)
#     # Save vocabulary
#     with open(_VOCAB_PATH, "w", encoding="utf-8") as f:
#         json.dump(vocab, f, ensure_ascii=False, indent=2)
#     print(f"Model saved to {_MODEL_PATH}")
#     print(f"Vocabulary saved to {_VOCAB_PATH}")
def load_model():
    """Load the trained model and its vocabulary from disk.

    Returns:
        tuple: ``(model, vocab, inv_vocab)`` where ``model`` is a
        ``TinyTransformer`` in eval mode on ``_DEVICE``, ``vocab`` maps
        token -> id, and ``inv_vocab`` maps id -> token.
    """
    # Load the vocabulary mapping (token -> id) from JSON.
    with open(_VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    # Build the inverse mapping (id -> token) for decoding model output;
    # int() guards against ids having been serialized as JSON strings.
    inv_vocab = {int(v): k for k, v in vocab.items()}
    # Initialize the model and restore the trained weights directly onto
    # the target device.
    model = TinyTransformer(vocab_size=len(vocab)).to(_DEVICE)
    model.load_state_dict(torch.load(_MODEL_PATH, map_location=_DEVICE))
    model.eval()  # inference mode: disables dropout etc.
    return model, vocab, inv_vocab
def get_device():
    """Return the ``torch.device`` used for model operations (MPS or CPU)."""
    return _DEVICE