import torch
from difflib import SequenceMatcher

from transformers import AutoTokenizer, T5ForConditionalGeneration
|
class EndpointHandler:
    """Inference-endpoint handler for a CoEdIT T5 grammar-correction model.

    Accepts a bare string, a list of strings, or a dict of the form
    ``{"inputs": ..., "parameters": {...}}``, prepends the model's
    "Fix the grammar: " task prefix, generates corrected sentences with
    beam search, and returns each correction together with a
    character-level diff of the edits.
    """

    def __init__(self, path=""):
        """Load tokenizer and model.

        Args:
            path: Local model directory provided by the endpoint runtime;
                falls back to the public CoEdIT checkpoint when empty.
        """
        model_name = path if path else "grammarly/coedit-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference-only handler: switch off dropout / training-mode layers.
        self.model.eval()

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        """Generate grammar-corrected versions of *sentences*.

        Args:
            sentences: List of input sentences (without the task prefix).
            num_return_sequences: Candidates to return per sentence.
            temperature: Forwarded to ``generate``.

        Returns:
            A flat list of corrected strings (one per sentence) when
            ``num_return_sequences == 1``; otherwise a list of lists,
            one inner list of candidates per input sentence.
        """
        prefix = "Fix the grammar: "
        prompts = [prefix + sentence for sentence in sentences]

        encoded = self.tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        # No gradients are needed for generation; avoids building the
        # autograd graph and cuts peak memory.
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_length=512,
                num_beams=5,
                # NOTE(review): temperature has no effect under pure beam
                # search (do_sample is not enabled); transformers will warn.
                # Kept for caller compatibility — confirm whether sampling
                # was intended before removing.
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                early_stopping=True,
            )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # generate() emits candidates contiguously per input, so slice
            # the flat list into one group per original sentence.
            return [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
        return decoded

    def compute_changes(self, original, enhanced):
        """Diff *original* against *enhanced* at character level.

        Returns:
            A list of change dicts, one per non-equal opcode, with the
            original/new phrase, character offsets into *original*, and a
            coarse whitespace-vs-other classification. Token offsets are
            not computed here and are left as ``None``.
        """
        changes = []
        matcher = SequenceMatcher(None, original, enhanced)

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = original[i1:i2]
                new_phrase = enhanced[j1:j2]
                # Pure-whitespace edits get a dedicated error type and tip.
                is_whitespace = original_phrase.isspace() or new_phrase.isspace()
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": i1,
                    "char_end": i2,
                    "token_start": None,
                    "token_end": None,
                    "explanation": f"{tag} change",
                    "error_type": "whitespace" if is_whitespace else "",
                    "tip": "Avoid extra spaces between words." if is_whitespace else "",
                })
        return changes

    def __call__(self, inputs):
        """Endpoint entry point.

        Args:
            inputs: A string, a list of strings, or a dict with an
                ``"inputs"`` key (string or list) and an optional
                ``"parameters"`` dict (``num_return_sequences``,
                ``temperature``).

        Returns:
            A dict with ``success`` plus either ``results``/count fields
            or an ``error`` message. Never raises: generation failures are
            caught and reported in the payload.
        """
        # Bug fix: the error message below promises string support, but the
        # original code only accepted list/dict — bare strings were rejected.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }

        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []

            if num_return_sequences > 1:
                # One result entry per (sentence, candidate) pair.
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })

            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            # Boundary handler: report the failure in the response payload
            # rather than letting the endpoint 500.
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }