Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import pickle | |
| import torch | |
| import torch.nn as nn | |
| import pandas as pd | |
| import json | |
| from sentence_transformers import SentenceTransformer | |
# Default hyper-parameters picked up by FlatEmbedMLP's constructor.
HIDDEN_DIM = 768
DROPOUT = 0.1

# Use the GPU whenever torch reports one as available; otherwise stay on CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
class FlatEmbedMLP(nn.Module):
    """Feed-forward classifier applied to a flat embedding vector.

    Two hidden ``Linear -> ReLU -> Dropout`` stages of equal width are
    followed by a final linear layer that emits one raw score per class.

    Args:
        input_dim: Number of features in each input vector.
        n_classes: Number of output scores.
        hidden_dim: Width of both hidden layers.
        dropout: Dropout probability used after each hidden activation.
    """

    def __init__(self, input_dim, n_classes, hidden_dim=HIDDEN_DIM, dropout=DROPOUT):
        super().__init__()
        # The layers live in a Sequential stored as ``net``, so parameters are
        # named ``net.<index>.*``; keep this order to stay compatible with
        # state dicts saved from the same layout.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the raw output scores."""
        scores = self.net(x)
        return scores
# Module-level cache for the dict built by load_artifacts(); it stays None
# until the first successful load and is returned as-is on later calls.
_artifacts = None
def _artifacts_root():
    """Return the ``training/artifacts`` path next to this file's package.

    The path is anchored at the parent of the directory that contains this
    file (after resolving symlinks), not at the current working directory.
    """
    this_file = Path(__file__).resolve()
    base_dir = this_file.parents[1]  # parent of the directory holding this file
    return base_dir.joinpath("training", "artifacts")
def load_artifacts():
    """Load the inference artifacts once and return them as a dict.

    The first call reads the label maps, the embedder metadata, the model
    weights and the title lookup from the directory returned by
    ``_artifacts_root()``, builds the sentence embedder and the
    ``FlatEmbedMLP`` classifier on ``DEVICE``, and stores the result in the
    module-level ``_artifacts`` cache. Later calls return that same dict
    without reading anything again.

    Returns:
        dict: Keys ``"device"``, ``"embedder"``, ``"model"``,
        ``"label_maps"``, ``"embed_metadata"`` and ``"y6_title_lookup"``.

    Raises:
        Any exception raised by the file, pickle, JSON, embedder or torch
        loading calls propagates unchanged; in that case the cache is left
        unset, so the next call retries the full load.
    """
    global _artifacts
    if _artifacts is not None:
        return _artifacts

    artifacts_dir = _artifacts_root()

    # NOTE: pickle.load (and torch.load below) can execute arbitrary code
    # while deserialising. Only point this loader at artifact files produced
    # by a trusted training run.
    with open(artifacts_dir / "label_maps" / "label_maps_embed.pkl", "rb") as f:
        label_maps = pickle.load(f)
    with open(artifacts_dir / "embedder" / "embed_metadata.pkl", "rb") as f:
        embed_metadata = pickle.load(f)

    embedder_model_name = embed_metadata["model_name"]
    embedder = SentenceTransformer(embedder_model_name, device=DEVICE)

    # The classifier's shape comes from the saved metadata: one output per
    # "y6" class, one input feature per embedding dimension.
    n_classes = len(label_maps["y6"]["classes"])
    input_dim = int(embed_metadata["embedding_dim"])
    model = FlatEmbedMLP(
        input_dim=input_dim,
        n_classes=n_classes,
    ).to(DEVICE)

    model_path = artifacts_dir / "models" / "flat_embed_best.pt"
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()  # switch to evaluation mode for inference

    # Decode the JSON explicitly as UTF-8 instead of the platform's locale
    # encoding, so non-ASCII titles load the same way on every OS.
    title_lookup_path = artifacts_dir / "label_maps" / "y6_title_lookup.json"
    with open(title_lookup_path, "r", encoding="utf-8") as f:
        y6_title_lookup = json.load(f)

    _artifacts = {
        "device": DEVICE,
        "embedder": embedder,
        "model": model,
        "label_maps": label_maps,
        "embed_metadata": embed_metadata,
        "y6_title_lookup": y6_title_lookup,
    }
    return _artifacts