Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -74,7 +74,7 @@ def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05):
|
|
| 74 |
return top_contexts if top_contexts else ["general"]
|
| 75 |
|
| 76 |
def translate_text(input_text):
|
| 77 |
-
tokenized_input =
|
| 78 |
input_text, return_tensors="np",
|
| 79 |
padding=True, truncation=True, max_length=512
|
| 80 |
)
|
|
@@ -82,7 +82,7 @@ def translate_text(input_text):
|
|
| 82 |
input_ids = tokenized_input["input_ids"].astype(np.int64)
|
| 83 |
attention_mask = tokenized_input["attention_mask"].astype(np.int64)
|
| 84 |
|
| 85 |
-
decoder_start_token_id =
|
| 86 |
decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
|
| 87 |
|
| 88 |
for _ in range(512):
|
|
@@ -101,10 +101,10 @@ def translate_text(input_text):
|
|
| 101 |
[decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
|
| 102 |
)
|
| 103 |
|
| 104 |
-
if next_token_id ==
|
| 105 |
break
|
| 106 |
|
| 107 |
-
return
|
| 108 |
|
| 109 |
def process_request(input_text):
|
| 110 |
context = detect_context(input_text)
|
|
|
|
| 74 |
return top_contexts if top_contexts else ["general"]
|
| 75 |
|
| 76 |
def translate_text(input_text):
|
| 77 |
+
tokenized_input = translation_tokenizer(
|
| 78 |
input_text, return_tensors="np",
|
| 79 |
padding=True, truncation=True, max_length=512
|
| 80 |
)
|
|
|
|
| 82 |
input_ids = tokenized_input["input_ids"].astype(np.int64)
|
| 83 |
attention_mask = tokenized_input["attention_mask"].astype(np.int64)
|
| 84 |
|
| 85 |
+
decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
|
| 86 |
decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
|
| 87 |
|
| 88 |
for _ in range(512):
|
|
|
|
| 101 |
[decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
|
| 102 |
)
|
| 103 |
|
| 104 |
+
if next_token_id == translation_tokenizer.eos_token_id:
|
| 105 |
break
|
| 106 |
|
| 107 |
+
return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
|
| 108 |
|
| 109 |
def process_request(input_text):
|
| 110 |
context = detect_context(input_text)
|