Frenchizer committed
Commit 36a1938 · verified · 1 Parent(s): 5ad0807

Update app.py

Files changed (1): app.py +4 -4
app.py CHANGED
@@ -74,7 +74,7 @@ def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05):
     return top_contexts if top_contexts else ["general"]
 
 def translate_text(input_text):
-    tokenized_input = tokenizer(
+    tokenized_input = translation_tokenizer(
         input_text, return_tensors="np",
         padding=True, truncation=True, max_length=512
     )
@@ -82,7 +82,7 @@ def translate_text(input_text):
     input_ids = tokenized_input["input_ids"].astype(np.int64)
     attention_mask = tokenized_input["attention_mask"].astype(np.int64)
 
-    decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
+    decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
     decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
 
     for _ in range(512):
@@ -101,10 +101,10 @@ def translate_text(input_text):
             [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
         )
 
-        if next_token_id == tokenizer.eos_token_id:
+        if next_token_id == translation_tokenizer.eos_token_id:
             break
 
-    return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
+    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
 
 def process_request(input_text):
     context = detect_context(input_text)
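
The hunks above elide the model forward pass and the token-selection step (old lines 88-100), so here is a minimal sketch of translate_text as it reads after this commit. This is a sketch under stated assumptions, not the repo's actual code: the ONNX Runtime session name (translation_session), its input names, the model path, and the tokenizer checkpoint are hypothetical stand-ins for whatever app.py actually loads.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Assumed checkpoint and model path -- placeholders, not taken from this commit.
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
translation_session = ort.InferenceSession("translation_model.onnx")

def translate_text(input_text):
    tokenized_input = translation_tokenizer(
        input_text, return_tensors="np",
        padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)

    # Seed the decoder with a start token; Marian-style tokenizers have no
    # cls_token_id, so the `or` falls through to pad_token_id.
    decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

    for _ in range(512):
        # Assumed forward pass: input names follow the common seq2seq ONNX export.
        logits = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )[0]
        # Greedy decoding: pick the highest-scoring token at the last position.
        next_token_id = int(np.argmax(logits[0, -1]))
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )
        if next_token_id == translation_tokenizer.eos_token_id:
            break

    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

The rename from tokenizer to translation_tokenizer applied here is the whole substance of the commit; it disambiguates the translation tokenizer from whatever tokenizer detect_context uses.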