File size: 1,368 Bytes
b14638e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Load Yaz from the safetensors weights + JSON sidecar in this Hugging Face repo.

    from load_yaz import load_yaz
    model, cfg, meta = load_yaz()          # uses files in this directory
    # `model` is a YazLM in eval mode; `meta` is the parsed JSON sidecar: its "cfg" entry plus
    # the "country_to_target_atom" and "country_to_capital_first" maps.

No pickle is loaded — weights come from `model.safetensors`. The `yaz/` package (model code) is
vendored alongside this file. See `demo.py` for routing + abstention + live-edit usage.
"""
from __future__ import annotations
import json
import os

import torch
from safetensors.torch import load_file

from yaz import YazConfig, YazLM

HERE = os.path.dirname(os.path.abspath(__file__))


def load_yaz(weights="model.safetensors", meta="yaz_meta.json"):
    """Build a ``YazLM`` from the JSON sidecar and safetensors weights beside this file.

    Both file names are joined onto ``HERE`` (the directory containing this
    module) with ``os.path.join``.

    Args:
        weights: File name of the safetensors state dict.
        meta: File name of the JSON sidecar. Its ``"cfg"`` entry is expanded
            as keyword arguments into ``YazConfig``.

    Returns:
        A ``(model, cfg, sidecar)`` tuple: the model after ``eval()`` has been
        called on it, the ``YazConfig`` it was built from, and the complete
        parsed sidecar dict.
    """
    sidecar_path = os.path.join(HERE, meta)
    with open(sidecar_path, "r", encoding="utf-8") as fh:
        sidecar = json.load(fh)

    # The architecture is rebuilt from the stored config before any weights are read.
    cfg = YazConfig(**sidecar["cfg"])
    model = YazLM(cfg)

    weights_path = os.path.join(HERE, weights)
    model.load_state_dict(load_file(weights_path))
    model.eval()

    return model, cfg, sidecar


if __name__ == "__main__":
    # Smoke test: load the model and print its size plus one sample lookup.
    model, cfg, meta = load_yaz()

    param_count = 0
    for tensor in model.parameters():
        param_count += tensor.numel()

    atom_count = len(meta["country_to_target_atom"])
    print(f"Loaded Yaz: {param_count:,} parameters | {atom_count} fact-atoms")

    # Sample entry: the atom and capital-first-byte values stored for France.
    france_atom = meta["country_to_target_atom"]["France"]
    france_first = meta["country_to_capital_first"]["France"]
    print("France -> atom", france_atom, "| capital first byte", france_first)