bifrost-1.2b / run_example.py
NodeNester's picture
Super-squash branch 'main' using huggingface_hub
7bf218e
Raw
History Blame Contribute Delete
1.34 kB
"""
Minimal end-to-end example for the Nordic Translator.
Usage:
python run_example.py "Hello, how are you?" sv
python run_example.py "Var bor du?" en
Two ways to run the model exist; this script demonstrates only the first:
  1. Standalone (pure torch, fast KV-cached) -> NordicTranslator.translate()
  2. HuggingFace AutoModel (trust_remote_code) -> model.generate()  [not shown here]
"""
import os
import sys

import torch
import sentencepiece as spm

from modeling_nordic import NordicTranslator
# Directory containing this script. abspath() guards against a relative
# __file__ (e.g. "run_example.py"): its dirname would be "" and plain string
# concatenation would then yield root paths such as "/model.safetensors".
HERE = os.path.dirname(os.path.abspath(__file__))

# Model checkpoint and SentencePiece model, both expected next to this script.
WEIGHTS = os.path.join(HERE, "model.safetensors")
TOKENIZER = os.path.join(HERE, "nordic_unigram_65k.model")

# Special ids handed to NordicTranslator.translate() as bos / eos / eos_src.
BOS, EOS, EOS_SRC = 1, 2, 65007

# Target-language code -> id passed to translate() to select the output
# language (presumably reserved ids in the tokenizer vocabulary; verify
# against modeling_nordic).
LANG = {
    "en": 65000,
    "sv": 65001,
    "da": 65002,
    "nb": 65003,
    "nn": 65004,
    "fi": 65005,
    "is": 65006,
}
def main():
    """Translate one sentence given on the command line and print it.

    The source text is read from ``sys.argv[1]`` and the target language code
    from ``sys.argv[2]``; when omitted they default to an English greeting and
    ``"sv"``. The tokenizer and checkpoint are loaded from ``TOKENIZER`` and
    ``WEIGHTS``, and the result is printed as ``[<tgt>] <translation>``.

    Raises:
        SystemExit: if the target language code is not a key of ``LANG``.
    """
    text = sys.argv[1] if len(sys.argv) > 1 else "Hello, how are you today?"
    tgt = sys.argv[2] if len(sys.argv) > 2 else "sv"

    # Validate explicitly instead of with `assert`: assertions are stripped
    # under `python -O`, which would let a bad code through to a bare KeyError
    # at the LANG lookup below. Checked before the (slow) model load.
    if tgt not in LANG:
        raise SystemExit(f"target must be one of {list(LANG)}")

    # Use the GPU when one is available, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    sp = spm.SentencePieceProcessor()
    sp.load(TOKENIZER)
    model = NordicTranslator.from_checkpoint(WEIGHTS, device=device, dtype=torch.bfloat16)

    # Text -> ids -> translated ids -> text.
    src_ids = sp.encode(text, out_type=int)
    out_ids = model.translate(src_ids, LANG[tgt], bos=BOS, eos=EOS, eos_src=EOS_SRC)
    print(f"[{tgt}] {sp.decode(out_ids)}")
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()