# naics_embeddings/api/load_artifacts.py
# Author: Joseph Warth
# Load y6 title lookup from JSON artifact
# Commit: 7ca09fd
from pathlib import Path
import pickle
import torch
import torch.nn as nn
import pandas as pd
import json
from sentence_transformers import SentenceTransformer
# Prefer GPU when one is available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# MLP hyperparameters. NOTE(review): these must match the values used when the
# checkpoint was trained, or load_state_dict below will fail on shape mismatch.
HIDDEN_DIM = 768
DROPOUT = 0.1
class FlatEmbedMLP(nn.Module):
    """Two-hidden-layer MLP classifier over fixed-size embedding vectors.

    Maps an input of shape (..., input_dim) to raw logits of shape
    (..., n_classes); no softmax is applied.
    """

    def __init__(self, input_dim, n_classes, hidden_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        # Assemble the layer stack explicitly, then wrap it in a Sequential.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the stack on *x* and return the per-class logits."""
        return self.net(x)
_artifacts = None
def _artifacts_root():
return Path(__file__).resolve().parents[1] / "training" / "artifacts"
def load_artifacts():
    """Load and memoize all inference artifacts.

    Returns a dict with keys: "device", "embedder", "model", "label_maps",
    "embed_metadata", "y6_title_lookup". The result is cached in the
    module-level ``_artifacts`` so repeated calls are cheap and return the
    same objects.
    """
    global _artifacts
    if _artifacts is not None:
        return _artifacts
    artifacts_dir = _artifacts_root()
    # NOTE(review): pickle.load executes arbitrary code on malicious input.
    # These artifacts are assumed to be trusted, locally-produced files —
    # never point this loader at files from an untrusted source.
    with open(artifacts_dir / "label_maps" / "label_maps_embed.pkl", "rb") as f:
        label_maps = pickle.load(f)
    with open(artifacts_dir / "embedder" / "embed_metadata.pkl", "rb") as f:
        embed_metadata = pickle.load(f)
    # Rebuild the sentence embedder from the recorded model name so serving
    # uses the same embedding model as training.
    embedder_model_name = embed_metadata["model_name"]
    embedder = SentenceTransformer(embedder_model_name, device=DEVICE)
    # Classifier dimensions come from the saved label maps / metadata.
    n_classes = len(label_maps["y6"]["classes"])
    input_dim = int(embed_metadata["embedding_dim"])
    model = FlatEmbedMLP(
        input_dim=input_dim,
        n_classes=n_classes,
    ).to(DEVICE)
    model_path = artifacts_dir / "models" / "flat_embed_best.pt"
    # weights_only=True restricts unpickling to tensors and primitive
    # containers — sufficient for a state_dict and closes the
    # arbitrary-code-execution hole in plain torch.load.
    model.load_state_dict(
        torch.load(model_path, map_location=DEVICE, weights_only=True)
    )
    model.eval()
    title_lookup_path = artifacts_dir / "label_maps" / "y6_title_lookup.json"
    # Explicit encoding: JSON artifacts are UTF-8 regardless of platform locale.
    with open(title_lookup_path, "r", encoding="utf-8") as f:
        y6_title_lookup = json.load(f)
    _artifacts = {
        "device": DEVICE,
        "embedder": embedder,
        "model": model,
        "label_maps": label_maps,
        "embed_metadata": embed_metadata,
        "y6_title_lookup": y6_title_lookup,
    }
    return _artifacts