Buckets:
| from typing import Dict, List, Any | |
| import torch | |
| from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
| import re | |
class EndpointHandler:
    """
    Inference endpoint handler that paraphrases text with a Pegasus model.

    The input text is split into paragraphs and then sentences; each sentence
    that is long enough is paraphrased on its own and the results are joined
    back together with the paragraph structure kept.
    """

    # Token budget used both for truncating the input and capping the output.
    MAX_LENGTH = 80
    # Minimum number of beams used for beam search.
    MIN_BEAMS = 10
    # Sentences with fewer whitespace-separated words are passed through as-is.
    MIN_WORDS_TO_PARAPHRASE = 3
    # A paraphrase containing any of these (lower-cased) phrases is discarded
    # and the original sentence is kept instead.
    UNWANTED_PHRASES = ("it's like", "in other words")

    # A paragraph break is a blank line; `\s*` also accepts CRLF line endings
    # and separator lines that contain only whitespace.
    _PARAGRAPH_BREAK = re.compile(r'\n\s*\n')
    # A sentence ends at '.', '!' or '?' followed by whitespace.
    _SENTENCE_BREAK = re.compile(r'(?<=[.!?])\s+')

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the model and tokenizer.

        :param path: Path to the model weights
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Load tokenizer and model from the same location.
        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines.

        Paragraphs are stripped of surrounding whitespace and empty ones are
        dropped, so runs of blank lines do not produce empty entries.

        :param text: Input text
        :return: List of non-empty paragraphs
        """
        pieces = (piece.strip() for piece in self._PARAGRAPH_BREAK.split(text))
        return [piece for piece in pieces if piece]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split paragraph into sentences using regex.

        :param paragraph: Input paragraph
        :return: List of non-empty sentences, each keeping its end punctuation
        """
        pieces = (piece.strip() for piece in self._SENTENCE_BREAK.split(paragraph))
        return [piece for piece in pieces if piece]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate paraphrased text for a single input.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of alternative paraphrases to generate
        :return: The first generated paraphrase (the other alternatives, if
            any were requested, are discarded)
        """
        # Call the tokenizer directly: `prepare_seq2seq_batch` is deprecated
        # in favour of the tokenizer's `__call__`.
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        ).to(self.torch_device)
        translated = self.model.generate(
            **batch,
            # Beam search cannot return more sequences than it has beams.
            num_beams=max(self.MIN_BEAMS, num_return_sequences),
            num_return_sequences=num_return_sequences,
            temperature=1.0,
            repetition_penalty=2.8,
            length_penalty=1.2,
            max_length=self.MAX_LENGTH,
            min_length=5,
            no_repeat_ngram_size=3,
        )
        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return tgt_text[0]

    def _paraphrase_sentence(self, sentence: str) -> str:
        """
        Paraphrase one sentence, falling back to the original when needed.

        :param sentence: Sentence to paraphrase
        :return: The paraphrase, or the unchanged sentence if it is too short,
            generation fails, or the paraphrase contains an unwanted phrase
        """
        if len(sentence.split()) < self.MIN_WORDS_TO_PARAPHRASE:
            return sentence
        try:
            paraphrased = self.get_response(sentence)
        except Exception as e:
            # Best effort: one failing sentence must not fail the whole request.
            print(f"Error processing sentence: {e}")
            return sentence
        lowered = paraphrased.lower()
        if any(phrase in lowered for phrase in self.UNWANTED_PHRASES):
            return sentence
        return paraphrased

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the incoming request and generate paraphrased text.

        :param data: Request payload containing the input text under "inputs"
        :return: Dict with the paraphrased text under "outputs"
        :raises ValueError: If the input is not a string
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")
        paraphrased_paragraphs = [
            ' '.join(
                self._paraphrase_sentence(sentence)
                for sentence in self.split_into_sentences(paragraph)
            )
            for paragraph in self.split_into_paragraphs(inputs)
        ]
        # Paragraphs are always re-joined with a single blank line.
        return {"outputs": '\n\n'.join(paraphrased_paragraphs)}
Xet Storage Details
- Size: 4.23 kB
- Xet hash: a3468c579efb3ee4ca2951f96db117138432ce0102a6af323b4703f7e53ecfba
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.