student2222333051 commited on
Commit
8e796c0
·
verified ·
1 Parent(s): 6ed32cf

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +19 -0
models.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MarianMTModel, MarianTokenizer
2
+
3
+ cache = {}
4
+
5
+ def load_translation_model(model_name):
6
+ if model_name in cache:
7
+ return cache[model_name]
8
+
9
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
10
+ model = MarianMTModel.from_pretrained(model_name)
11
+
12
+ cache[model_name] = (model, tokenizer)
13
+ return model, tokenizer
14
+
15
+ def translate_text(text, model, tokenizer):
16
+ batch = tokenizer([text], return_tensors="pt", padding=True)
17
+ gen = model.generate(**batch)
18
+ result = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
19
+ return result