[fix]: lowercase input and end with a period.
Browse files
app.py
CHANGED
|
@@ -45,6 +45,7 @@ def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights
|
|
| 45 |
@torch.inference_mode()
|
| 46 |
def run(input: str) -> Tuple[str, plt.Figure]:
|
| 47 |
"""Run inference on a single sentence. Returns prediction and attention heatmap."""""
|
|
|
|
| 48 |
input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
|
| 49 |
output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
|
| 50 |
output = target_spm.decode(output.detach().tolist())
|
|
|
|
| 45 |
@torch.inference_mode()
|
| 46 |
def run(input: str) -> Tuple[str, plt.Figure]:
|
| 47 |
"""Run inference on a single sentence. Returns prediction and attention heatmap."""""
|
| 48 |
+
input = input.lower().strip().rstrip(".") + "."
|
| 49 |
input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
|
| 50 |
output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
|
| 51 |
output = target_spm.decode(output.detach().tolist())
|