TilelliLab commited on
Commit
e7289a8
·
verified ·
1 Parent(s): dd920f9

UX fix: demo.py (download instructions / safetensors / pickle-free demo)

Browse files
Files changed (1) hide show
  1. demo.py +22 -6
demo.py CHANGED
@@ -26,13 +26,17 @@ import numpy as np
26
  import torch
27
 
28
  ROOT = Path(__file__).resolve().parent
29
- sys.path.insert(0, os.environ.get("YAZ_EMBEDDER_PATH", ""))
 
 
30
  sys.path.insert(0, str(ROOT))
31
  from yaz import YazConfig, YazLM
32
  from yaz.semantic_router import SemanticRouter
33
  from scripts.gen_paraphrase_data import TRAIN_TEMPLATES
34
 
35
- CKPT = ROOT / "checkpoints" / "yaz_gen_semantic_v2.pt"
 
 
36
  DEFAULT_THRESHOLD = 0.08 # margin below this -> abstain (tunable; see s3 risk-coverage)
37
 
38
 
@@ -42,10 +46,22 @@ def ids(s):
42
 
43
  class YazDemo:
44
  def __init__(self, threshold=DEFAULT_THRESHOLD):
45
- ck = torch.load(CKPT, map_location="cpu", weights_only=False)
46
- self.cfg = YazConfig(**ck["cfg"])
47
- self.model = YazLM(self.cfg); self.model.load_state_dict(ck["model"]); self.model.eval()
48
- self.c2i = ck["country_to_target_atom"]; self.order = list(self.c2i.keys())
 
 
 
 
 
 
 
 
 
 
 
 
49
  self.threshold = threshold
50
  # capitals (full word) for resolving --edit RHS and for nicer output
51
  self.capof = {}
 
26
  import torch
27
 
28
  ROOT = Path(__file__).resolve().parent
29
+ _emb_path = os.environ.get("YAZ_EMBEDDER_PATH", "")
30
+ if _emb_path: # only add a real path (avoid inserting "" == cwd)
31
+ sys.path.insert(0, _emb_path)
32
  sys.path.insert(0, str(ROOT))
33
  from yaz import YazConfig, YazLM
34
  from yaz.semantic_router import SemanticRouter
35
  from scripts.gen_paraphrase_data import TRAIN_TEMPLATES
36
 
37
+ CKPT = ROOT / "checkpoints" / "yaz_gen_semantic_v2.pt" # PyTorch pickle (fallback)
38
+ SAFETENSORS = ROOT / "model.safetensors" # recommended, pickle-free
39
+ META = ROOT / "yaz_meta.json"
40
  DEFAULT_THRESHOLD = 0.08 # margin below this -> abstain (tunable; see s3 risk-coverage)
41
 
42
 
 
46
 
47
  class YazDemo:
48
  def __init__(self, threshold=DEFAULT_THRESHOLD):
49
+ # Prefer the pickle-free safetensors + JSON sidecar (the recommended artifact);
50
+ # fall back to the PyTorch pickle checkpoint if they're not present.
51
+ import json
52
+ if SAFETENSORS.exists() and META.exists():
53
+ from safetensors.torch import load_file
54
+ meta = json.loads(META.read_text())
55
+ self.cfg = YazConfig(**meta["cfg"])
56
+ self.model = YazLM(self.cfg); self.model.load_state_dict(load_file(str(SAFETENSORS)))
57
+ self.model.eval()
58
+ self.c2i = meta["country_to_target_atom"]
59
+ else:
60
+ ck = torch.load(CKPT, map_location="cpu", weights_only=False)
61
+ self.cfg = YazConfig(**ck["cfg"])
62
+ self.model = YazLM(self.cfg); self.model.load_state_dict(ck["model"]); self.model.eval()
63
+ self.c2i = ck["country_to_target_atom"]
64
+ self.order = list(self.c2i.keys())
65
  self.threshold = threshold
66
  # capitals (full word) for resolving --edit RHS and for nicer output
67
  self.capof = {}