| """Minimal end-to-end example: load the HF-hosted weights and translate. |
| |
| Run from the parent project directory (so `src` is importable): |
| python example.py --text "Hello world, how are you?" |
| """ |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import sentencepiece as spm |
| import torch |
|
|
| |
| from src.model import Transformer |
| from src.inference.translate import batched_beam_search |
| from src.data.tokenizer import BOS_ID, EOS_ID, PAD_ID |
|
|
|
|
def main():
    """Translate one English sentence with the pretrained model and print it.

    Parses the command-line options, rebuilds the ``Transformer`` from the
    JSON config, loads the checkpoint state dict, tokenizes ``--text`` with
    SentencePiece, runs beam search, and prints the decoded hypothesis to
    stdout.

    Command-line options:
        --weights: path to the checkpoint state dict
            (default ``pytorch_model.bin``).
        --spm: path to the SentencePiece model
            (default ``sentencepiece.model``).
        --config: path to the JSON hyper-parameter file
            (default ``config.json``).
        --text: English sentence to translate (required).
        --beam: beam size passed to the beam search (default 5).
        --device: torch device; defaults to ``cuda`` when available,
            otherwise ``cpu``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--weights", default="pytorch_model.bin")
    ap.add_argument("--spm", default="sentencepiece.model")
    ap.add_argument("--config", default="config.json")
    ap.add_argument("--text", required=True, help="English sentence to translate.")
    ap.add_argument("--beam", type=int, default=5)
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    # Decode the config as UTF-8 explicitly so parsing does not depend on the
    # platform's default locale encoding.
    cfg = json.loads(Path(args.config).read_text(encoding="utf-8"))

    model = Transformer(
        vocab_size=cfg["vocab_size"],
        d_model=cfg["d_model"],
        n_heads=cfg["n_heads"],
        n_encoder_layers=cfg["n_encoder_layers"],
        n_decoder_layers=cfg["n_decoder_layers"],
        d_ff=cfg["d_ff"],
        dropout=0.0,  # inference only: no dropout
        max_seq_len=cfg["max_seq_len"],
        share_embeddings=cfg["share_embeddings"],
        pad_idx=PAD_ID,
    ).to(args.device)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point --weights at checkpoints you trust; on recent
    # PyTorch versions consider passing weights_only=True.
    model.load_state_dict(torch.load(args.weights, map_location=args.device))
    model.eval()

    sp = spm.SentencePieceProcessor()
    sp.load(args.spm)

    # Wrap the sentence in BOS/EOS and add a batch dimension of size one.
    # NOTE(review): presumably this matches the format used at training time
    # -- confirm against the data pipeline.
    ids = [BOS_ID] + sp.encode(args.text, out_type=int) + [EOS_ID]
    src = torch.tensor([ids], dtype=torch.long, device=args.device)

    # Decoding never needs gradients; disabling autograd keeps the search
    # from retaining activations for every decoding step.
    with torch.no_grad():
        # [0] selects the result for the single sentence in the batch
        # (presumably one hypothesis per source row -- confirm against
        # batched_beam_search).
        hyp_ids = batched_beam_search(
            model,
            src,
            beam_size=args.beam,
            max_len=cfg["max_seq_len"],
            length_penalty=1.0,
        )[0]

    # Drop the special tokens before detokenizing.
    hyp_ids = [t for t in hyp_ids if t not in (BOS_ID, EOS_ID, PAD_ID)]
    print(sp.decode(hyp_ids))
|
|
|
|
# Run the example only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|