UX fix: demo.py (download instructions / safetensors / pickle-free demo)
Browse files
demo.py
CHANGED
|
@@ -26,13 +26,17 @@ import numpy as np
|
|
| 26 |
import torch
|
| 27 |
|
| 28 |
ROOT = Path(__file__).resolve().parent
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
sys.path.insert(0, str(ROOT))
|
| 31 |
from yaz import YazConfig, YazLM
|
| 32 |
from yaz.semantic_router import SemanticRouter
|
| 33 |
from scripts.gen_paraphrase_data import TRAIN_TEMPLATES
|
| 34 |
|
| 35 |
-
CKPT = ROOT / "checkpoints" / "yaz_gen_semantic_v2.pt"
|
|
|
|
|
|
|
| 36 |
DEFAULT_THRESHOLD = 0.08 # margin below this -> abstain (tunable; see s3 risk-coverage)
|
| 37 |
|
| 38 |
|
|
@@ -42,10 +46,22 @@ def ids(s):
|
|
| 42 |
|
| 43 |
class YazDemo:
|
| 44 |
def __init__(self, threshold=DEFAULT_THRESHOLD):
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
self.threshold = threshold
|
| 50 |
# capitals (full word) for resolving --edit RHS and for nicer output
|
| 51 |
self.capof = {}
|
|
|
|
| 26 |
import torch
|
| 27 |
|
| 28 |
ROOT = Path(__file__).resolve().parent
|
| 29 |
+
_emb_path = os.environ.get("YAZ_EMBEDDER_PATH", "")
|
| 30 |
+
if _emb_path: # only add a real path (avoid inserting "" == cwd)
|
| 31 |
+
sys.path.insert(0, _emb_path)
|
| 32 |
sys.path.insert(0, str(ROOT))
|
| 33 |
from yaz import YazConfig, YazLM
|
| 34 |
from yaz.semantic_router import SemanticRouter
|
| 35 |
from scripts.gen_paraphrase_data import TRAIN_TEMPLATES
|
| 36 |
|
| 37 |
+
CKPT = ROOT / "checkpoints" / "yaz_gen_semantic_v2.pt" # PyTorch pickle (fallback)
|
| 38 |
+
SAFETENSORS = ROOT / "model.safetensors" # recommended, pickle-free
|
| 39 |
+
META = ROOT / "yaz_meta.json"
|
| 40 |
DEFAULT_THRESHOLD = 0.08 # margin below this -> abstain (tunable; see s3 risk-coverage)
|
| 41 |
|
| 42 |
|
|
|
|
| 46 |
|
| 47 |
class YazDemo:
|
| 48 |
def __init__(self, threshold=DEFAULT_THRESHOLD):
|
| 49 |
+
# Prefer the pickle-free safetensors + JSON sidecar (the recommended artifact);
|
| 50 |
+
# fall back to the PyTorch pickle checkpoint if they're not present.
|
| 51 |
+
import json
|
| 52 |
+
if SAFETENSORS.exists() and META.exists():
|
| 53 |
+
from safetensors.torch import load_file
|
| 54 |
+
meta = json.loads(META.read_text())
|
| 55 |
+
self.cfg = YazConfig(**meta["cfg"])
|
| 56 |
+
self.model = YazLM(self.cfg); self.model.load_state_dict(load_file(str(SAFETENSORS)))
|
| 57 |
+
self.model.eval()
|
| 58 |
+
self.c2i = meta["country_to_target_atom"]
|
| 59 |
+
else:
|
| 60 |
+
ck = torch.load(CKPT, map_location="cpu", weights_only=False)  # NOTE(review): weights_only=False performs full unpickling -- only load a checkpoint from a trusted source
|
| 61 |
+
self.cfg = YazConfig(**ck["cfg"])
|
| 62 |
+
self.model = YazLM(self.cfg); self.model.load_state_dict(ck["model"]); self.model.eval()
|
| 63 |
+
self.c2i = ck["country_to_target_atom"]
|
| 64 |
+
self.order = list(self.c2i.keys())
|
| 65 |
self.threshold = threshold
|
| 66 |
# capitals (full word) for resolving --edit RHS and for nicer output
|
| 67 |
self.capof = {}
|