TiberiuCristianLeon committed on
Commit
18293ea
·
verified ·
1 Parent(s): 544f6db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -69,8 +69,16 @@ class Translators:
69
  # model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
70
  # model.half() # recommended for GPU
71
  model.eval()
 
72
  # Translating from one or several sentences to a sole language
73
  src_tokens = tokenizer.encode_source_tokens_to_input_ids(self.input_text, target_language=self.tl)
 
 
 
 
 
 
 
74
  # src_tokens = src_tokens.to(self.device)
75
  # generated_tokens = model.generate(src_tokens)
76
  # return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
@@ -78,9 +86,10 @@ class Translators:
78
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
79
  # generated_tokens = model.generate(src_tokens.to(self.device))
80
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
81
- with torch.no_grad():
82
  generated_tokens = model.generate(src_tokens)
83
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
84
 
85
  def hplt(self, opus = False):
86
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']
 
69
  # model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
70
  # model.half() # recommended for GPU
71
  model.eval()
72
+ model.float()
73
  # Translating from one or several sentences to a sole language
74
  src_tokens = tokenizer.encode_source_tokens_to_input_ids(self.input_text, target_language=self.tl)
75
+ # src_tokens may be a torch.Tensor or dict depending on tokenizer; ensure it's a tensor
76
+ if isinstance(src_tokens, torch.Tensor):
77
+ src_tokens = src_tokens.to(self.device)
78
+ else:
79
+ # if tokenizer returns dict-like inputs (input_ids, attention_mask)
80
+ for k, v in src_tokens.items():
81
+ src_tokens[k] = v.to(self.device)
82
  # src_tokens = src_tokens.to(self.device)
83
  # generated_tokens = model.generate(src_tokens)
84
  # return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
86
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
87
  # generated_tokens = model.generate(src_tokens.to(self.device))
88
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
89
+ with torch.inference_mode(): # no_grad inference_mode
90
  generated_tokens = model.generate(src_tokens)
91
+ result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
92
+ return result
93
 
94
  def hplt(self, opus = False):
95
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']