import torch
from cog import BasePredictor, Input
from transformers import T5ForConditionalGeneration, T5Tokenizer


class Predictor(BasePredictor):
    """Cog predictor that rewrites input text with a T5 grammar-correction model."""

    # Hugging Face Hub id shared by the model and the tokenizer so the two cannot drift apart.
    MODEL_ID = "aaurelions/t5-grammar-corrector"
    # Task prefix prepended to every request. Presumably this is the prefix the model
    # was fine-tuned with — NOTE(review): confirm against the model card.
    TASK_PREFIX = "fix grammar: "

    def setup(self):
        """Load the model and tokenizer once so repeated predictions reuse them.

        The model is placed on the GPU when CUDA is available, otherwise on the CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = T5ForConditionalGeneration.from_pretrained(self.MODEL_ID).to(self.device)
        # Inference only: make sure dropout and other train-time behaviour is disabled.
        self.model.eval()
        self.tokenizer = T5Tokenizer.from_pretrained(self.MODEL_ID)

    def predict(
        self,
        text: str = Input(description="Text to correct"),
        max_length: int = Input(
            description="Maximum length (in tokens) of the corrected output",
            default=128,
            ge=1,
        ),
    ) -> str:
        """Return a grammar-corrected version of ``text``.

        Args:
            text: Text to correct.
            max_length: Upper bound on the generated sequence length, in tokens.
                Defaults to 128, the value that was previously hard-coded. Raise it
                for long inputs, whose corrections would otherwise be cut off.

        Returns:
            The decoded model output, with special tokens removed.
        """
        # Keep the whole BatchEncoding so generate() receives the attention mask
        # along with the input ids, and move it to the model's device.
        encoded = self.tokenizer(self.TASK_PREFIX + text, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(**encoded, max_length=max_length)
        # A single input text was encoded, so decode the first (only) sequence.
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)