sxtran committed
Commit 2f16beb · verified · 1 Parent(s): 7a92108

Update the prefix added to each sentence before it is fed into the model

Files changed (1): handler.py (+7, -2)
handler.py
@@ -12,14 +12,18 @@ class EndpointHandler:
         self.model.to(self.device)
 
     def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
-        # Your existing paraphrase_batch logic
+        # Add the grammar correction prefix to each sentence
+        prefix = "correct grammar for this sentence: "
+        sentences_with_prefix = [prefix + s for s in sentences]
+
         inputs = self.tokenizer(
-            sentences,
+            sentences_with_prefix,
             padding=True,
             truncation=True,
             max_length=512,
             return_tensors="pt"
         ).to(self.device)
+
         outputs = self.model.generate(
             **inputs,
             max_length=512,
@@ -28,6 +32,7 @@ class EndpointHandler:
             num_return_sequences=num_return_sequences,
             early_stopping=True
         )
+
         decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
         if num_return_sequences > 1:
             grouped = [
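
For orientation, here is a minimal usage sketch of the handler after this commit. The constructor call and the example sentences are assumptions, since the diff only shows paraphrase_batch; Hugging Face Inference Endpoints handlers conventionally accept a model path in their constructor.

# Hypothetical usage sketch; EndpointHandler's constructor is not shown in this
# diff, so the path argument below is an assumption based on the usual
# Hugging Face Inference Endpoints handler convention.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

# Callers now pass raw sentences; paraphrase_batch itself prepends
# "correct grammar for this sentence: " to each one before tokenizing,
# instead of expecting callers to add the prefix themselves.
sentences = ["she go to school yesterday", "he dont like apples"]
corrected = handler.paraphrase_batch(sentences, num_return_sequences=1)
print(corrected)  # one corrected string per input sentence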