LiamKhoaLe commited on
Commit
38513a2
·
1 Parent(s): b7090b2

Upd sentencepiece

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -0
  2. vi/download.py +3 -1
requirements.txt CHANGED
@@ -13,3 +13,5 @@ ftfy
13
  langid
14
  transformers
15
  torch
 
 
 
13
  langid
14
  transformers
15
  torch
16
+ sentencepiece
17
+ sacremoses
vi/download.py CHANGED
@@ -9,6 +9,7 @@ import os
9
  import sys
10
  import logging
11
  from pathlib import Path
 
12
  from transformers import MarianMTModel, MarianTokenizer
13
 
14
  # Setup logging
@@ -56,7 +57,8 @@ def download_model(model_name: str = "Helsinki-NLP/opus-mt-en-vi", cache_dir: st
56
  logger.info("Testing model...")
57
  test_text = "Hello, how are you?"
58
  inputs = tokenizer(f">>vie<< {test_text}", return_tensors="pt")
59
- with model.eval():
 
60
  outputs = model.generate(**inputs, max_length=50, num_beams=4)
61
  translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
62
  logger.info(f"Test translation: '{test_text}' -> '{translated}'")
 
9
  import sys
10
  import logging
11
  from pathlib import Path
12
+ import torch
13
  from transformers import MarianMTModel, MarianTokenizer
14
 
15
  # Setup logging
 
57
  logger.info("Testing model...")
58
  test_text = "Hello, how are you?"
59
  inputs = tokenizer(f">>vie<< {test_text}", return_tensors="pt")
60
+ model.eval()
61
+ with torch.no_grad():
62
  outputs = model.generate(**inputs, max_length=50, num_beams=4)
63
  translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
  logger.info(f"Test translation: '{test_text}' -> '{translated}'")