Prepend the grammar-correction prefix to each sentence before it is fed into the model
Browse files
- handler.py +7 -2
handler.py
CHANGED
|
@@ -12,14 +12,18 @@ class EndpointHandler:
|
|
| 12 |
self.model.to(self.device)
|
| 13 |
|
| 14 |
def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
| 16 |
inputs = self.tokenizer(
|
| 17 |
-
|
| 18 |
padding=True,
|
| 19 |
truncation=True,
|
| 20 |
max_length=512,
|
| 21 |
return_tensors="pt"
|
| 22 |
).to(self.device)
|
|
|
|
| 23 |
outputs = self.model.generate(
|
| 24 |
**inputs,
|
| 25 |
max_length=512,
|
|
@@ -28,6 +32,7 @@ class EndpointHandler:
|
|
| 28 |
num_return_sequences=num_return_sequences,
|
| 29 |
early_stopping=True
|
| 30 |
)
|
|
|
|
| 31 |
decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 32 |
if num_return_sequences > 1:
|
| 33 |
grouped = [
|
|
|
|
| 12 |
self.model.to(self.device)
|
| 13 |
|
| 14 |
def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
|
| 15 |
+
# Add the grammar correction prefix to each sentence
|
| 16 |
+
prefix = "correct grammar for this sentence: "
|
| 17 |
+
sentences_with_prefix = [prefix + s for s in sentences]
|
| 18 |
+
|
| 19 |
inputs = self.tokenizer(
|
| 20 |
+
sentences_with_prefix,
|
| 21 |
padding=True,
|
| 22 |
truncation=True,
|
| 23 |
max_length=512,
|
| 24 |
return_tensors="pt"
|
| 25 |
).to(self.device)
|
| 26 |
+
|
| 27 |
outputs = self.model.generate(
|
| 28 |
**inputs,
|
| 29 |
max_length=512,
|
|
|
|
| 32 |
num_return_sequences=num_return_sequences,
|
| 33 |
early_stopping=True
|
| 34 |
)
|
| 35 |
+
|
| 36 |
decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 37 |
if num_return_sequences > 1:
|
| 38 |
grouped = [
|