"""Gradio demo: grammatical error detection (GED) + correction (GEC).

An ELECTRA token classifier tags every word-initial token of a sentence with
one of C / R / M / U; R, M and U mark an error (replace / missing /
unnecessary), C marks everything else.  When the extra GED components can be
loaded, the tag string is encoded by a second T5 encoder and mixed into the
source encoding through a learned sigmoid gate before the T5 decoder
generates the corrected sentence.
"""

import copy
import difflib
import html
import re

import gradio as gr
import nltk
import numpy as np  # noqa: F401  (kept from the original file)
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tqdm import tqdm  # noqa: F401  (kept from the original file)
from transformers import (
    ElectraForTokenClassification,
    ElectraTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

nltk.download("punkt")

DEFAULT_GEC_MODEL = "Zlovoblachko/REAEC_GEC_2step_test"
DEFAULT_GED_MODEL = "Zlovoblachko/4tag-electra-grammar-error-detection"

# ELECTRA class id -> GED tag; any other id falls back to "C".
GED_LABELS = {0: "C", 1: "R", 2: "M", 3: "U"}
ERROR_TAGS = frozenset({"R", "M", "U"})
SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})
# Error tag -> (highlight colour, tooltip label) used in the HTML report.
ERROR_STYLES = {
    "R": ("#FFCCCC", "Replace"),  # light red
    "M": ("#CCFFCC", "Missing"),  # light green
    "U": ("#CCCCFF", "Unnecessary"),  # light blue
}


def _group_error_spans(tagged_tokens):
    """Group consecutive error-tagged tokens into spans.

    ``tagged_tokens`` is a list of ``(token, tag, electra_index)`` triples.
    A span is a maximal run of adjacent triples sharing the same R/M/U tag;
    a "C" tag, or a different error tag, ends the current run.  Each span is
    returned as a dict with keys ``text`` (tokens joined by single spaces),
    ``type``, ``tokens`` and ``token_indices``.
    """
    runs = []
    current = None
    for token, tag, index in tagged_tokens:
        if tag not in ERROR_TAGS:
            current = None
            continue
        if current is None or current["type"] != tag:
            current = {"type": tag, "tokens": [], "token_indices": []}
            runs.append(current)
        current["tokens"].append(token)
        current["token_indices"].append(index)
    return [
        {
            "text": " ".join(run["tokens"]),
            "type": run["type"],
            "tokens": run["tokens"],
            "token_indices": run["token_indices"],
        }
        for run in runs
    ]


class T5WithGED(nn.Module):
    """T5 grammar corrector optionally conditioned on ELECTRA GED tags.

    Attributes set in ``__init__``:
        t5, t5_tokenizer: the GEC model and its tokenizer.
        ged_encoder, gate: an independent copy of the T5 encoder that reads the
            tag string, and the linear layer producing the mixing weight.  Only
            used when ``has_ged`` is True, i.e. when both were loaded from
            ``ged_components.pt`` in the GEC model repository.
        ged_model, ged_tokenizer: the ELECTRA tagger, or None if loading failed.
    """

    def __init__(self, model_path=DEFAULT_GEC_MODEL, ged_model_path=DEFAULT_GED_MODEL):
        super().__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_path)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_path)

        self.has_ged = False
        try:
            # The GED encoder must be an independent copy: aliasing
            # ``self.t5.encoder`` would make the load_state_dict() below
            # overwrite the source encoder's weights with the GED encoder's.
            self.ged_encoder = copy.deepcopy(self.t5.encoder)
            self.gate = nn.Linear(2 * self.t5.config.d_model, 1)
            try:
                ged_components_path = hf_hub_download(
                    repo_id=model_path, filename="ged_components.pt"
                )
                # NOTE(review): torch.load unpickles the file -- only point
                # model_path at trusted repositories (or pass weights_only=True
                # on torch versions that support it).
                ged_components = torch.load(
                    ged_components_path, map_location=torch.device("cpu")
                )
                self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
                self.gate.load_state_dict(ged_components["gate"])
                self.has_ged = True
            except Exception as e:
                print(f"Could not load GED components: {e}")
        except Exception as e:
            print(f"Error setting up GED integration: {e}")

        self.ged_model = None
        self.ged_tokenizer = None
        try:
            self.ged_tokenizer = ElectraTokenizer.from_pretrained(ged_model_path)
            self.ged_model = ElectraForTokenClassification.from_pretrained(ged_model_path)
            self.ged_model.eval()
        except Exception as e:
            print(f"Could not load GED model: {e}")

    def get_ged_predictions(self, text):
        """Get GED predictions for a sentence.

        Returns None when the GED model is unavailable, otherwise a tuple
        ``(ged_tags, error_spans, input_tokens)``:

        * ged_tags -- space-separated tag string with one tag per kept token
          (``##`` continuation pieces and [CLS]/[SEP]/[PAD] are dropped);
        * error_spans -- list of span dicts (see ``_group_error_spans``);
        * input_tokens -- every ELECTRA token of the input, special tokens
          included.
        """
        if self.ged_model is None or self.ged_tokenizer is None:
            return None

        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = self.ged_model(**inputs).logits
        predictions = torch.argmax(logits, dim=2)[0].cpu().tolist()
        tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

        # (token, tag, position in the ELECTRA sequence) for word-initial tokens.
        tagged = [
            (token, GED_LABELS.get(pred, "C"), index)
            for index, (token, pred) in enumerate(zip(tokens, predictions))
            if not token.startswith("##") and token not in SPECIAL_TOKENS
        ]
        ged_tags = " ".join(tag for _, tag, _ in tagged)
        return ged_tags, _group_error_spans(tagged), tokens

    def correct(self, text, use_ged=True, max_length=128):
        """Correct grammatical errors in text.

        Args:
            text: Sentence to correct.
            use_ged: Condition generation on GED tags when they are available.
            max_length: Truncation length for the T5 inputs and the maximum
                generation length.

        Returns:
            ``(corrected_text, ged_tags, error_spans)``; ``ged_tags`` and
            ``error_spans`` are None when plain T5 generation was used.
        """
        inputs = self.t5_tokenizer(
            text, return_tensors="pt", truncation=True, max_length=max_length
        )

        ged_info = None
        if self.has_ged and use_ged and self.ged_model is not None:
            ged_info = self.get_ged_predictions(text)

        # Inference only: do not build an autograd graph for encoders or gate.
        with torch.no_grad():
            if ged_info is None:
                output_ids = self.t5.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                )
                corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
                return corrected_text, None, None

            ged_tags, error_spans, _ = ged_info
            ged_inputs = self.t5_tokenizer(
                ged_tags, return_tensors="pt", truncation=True, max_length=max_length
            )
            src_outputs = self.t5.encoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                return_dict=True,
            )
            ged_hidden = self.ged_encoder(
                input_ids=ged_inputs.input_ids,
                attention_mask=ged_inputs.attention_mask,
                return_dict=True,
            ).last_hidden_state

            # Source and tag strings are tokenized independently and may differ
            # in length; only their common prefix is mixed.
            min_len = min(src_outputs.last_hidden_state.size(1), ged_hidden.size(1))
            src_hidden = src_outputs.last_hidden_state[:, :min_len, :]
            ged_hidden = ged_hidden[:, :min_len, :]

            # lambda = sigmoid(W [h_src; h_ged]);  h = lambda*h_src + (1-lambda)*h_ged
            gate_scores = torch.sigmoid(self.gate(torch.cat([src_hidden, ged_hidden], dim=2)))
            src_outputs.last_hidden_state = (
                gate_scores * src_hidden + (1 - gate_scores) * ged_hidden
            )

            output_ids = self.t5.generate(
                encoder_outputs=src_outputs,
                # Cross-attention mask must match the truncated encoder length.
                attention_mask=inputs.attention_mask[:, :min_len],
                max_length=max_length,
            )

        corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text, ged_tags, error_spans


def find_differences(source, corrected):
    """Find word-level differences between source and corrected text.

    Returns a list of ``{"type": "deletion" | "addition", "text": word,
    "position": i}`` dicts, where ``i`` indexes the ``difflib.ndiff`` output.
    """
    changes = []
    for position, line in enumerate(difflib.ndiff(source.split(), corrected.split())):
        if line.startswith("- "):
            changes.append({"type": "deletion", "text": line[2:], "position": position})
        elif line.startswith("+ "):
            changes.append({"type": "addition", "text": line[2:], "position": position})
    return changes


def _highlight_error_spans(sentence, error_spans):
    """Return *sentence* as escaped HTML with every GED error span highlighted.

    Spans are located with a case-insensitive regex over the raw sentence
    (GED tokens are word pieces and may not match the original casing).
    Every occurrence is highlighted; an occurrence overlapping one already
    highlighted is skipped so the generated markup stays well-formed.
    """
    found = []
    for span in sorted(error_spans, key=lambda s: s["token_indices"][0]):
        color, label = ERROR_STYLES[span["type"]]
        # Escape each token on its own: escaping after inserting a whitespace
        # pattern would turn it into a literal and multi-token spans would never
        # match.  "\s*" because the tokenizer splits punctuation off words.
        pattern = r"\s*".join(re.escape(token) for token in span["tokens"])
        if re.match(r"\w", span["tokens"][0]):
            # Anchor word-initial tokens so e.g. "a" does not match inside "was".
            pattern = r"\b" + pattern
        for match in re.finditer(pattern, sentence, re.IGNORECASE):
            found.append((match.start(), match.end(), color, label))

    found.sort(key=lambda item: item[0])
    pieces = []
    cursor = 0
    for start, end, color, label in found:
        if start < cursor:
            continue  # overlaps a span that is already highlighted
        pieces.append(html.escape(sentence[cursor:start]))
        pieces.append(
            f'<span style="background-color: {color}; padding: 0 2px; border-radius: 3px;" '
            f'title="{label}">{html.escape(sentence[start:end])}</span>'
        )
        cursor = end
    pieces.append(html.escape(sentence[cursor:]))
    return "".join(pieces)


def _render_result(original, corrected, error_spans):
    """Render one sentence's GED/GEC result as an HTML fragment (text escaped)."""
    parts = ["<div class='gec-result' style='margin-bottom: 1em;'>"]
    if error_spans:
        parts.append("<div><strong>Original sentence:</strong></div>")
        parts.append(f"<div>{_highlight_error_spans(original, error_spans)}</div>")
    else:
        parts.append(f"<div><strong>Original sentence:</strong> {html.escape(original)}</div>")
    parts.append(f"<div><strong>Corrected:</strong> {html.escape(corrected)}</div>")

    changes = find_differences(original, corrected)
    if changes:
        parts.append("<div><strong>Changes:</strong></div><ul>")
        for change in changes:
            word = html.escape(change["text"])
            tag = "del" if change["type"] == "deletion" else "ins"
            parts.append(f"<li><{tag}>{word}</{tag}></li>")
        parts.append("</ul>")
    parts.append("</div>")
    return "".join(parts)


def process_text(text, model):
    """Process input text by splitting into sentences and applying the model.

    Args:
        text: Raw user input; may contain several sentences.
        model: Object with ``correct(sentence)`` returning
            ``(corrected, ged_tags, error_spans)`` (e.g. ``T5WithGED``).

    Returns:
        An HTML report string, or a plain prompt message when *text* is blank.
    """
    if not text.strip():
        return "Please enter some text."
    try:
        sentences = nltk.sent_tokenize(text)
    except LookupError:
        # Fallback resource for NLTK versions that look up "punkt_tab".
        nltk.download("punkt_tab")
        sentences = nltk.sent_tokenize(text)

    fragments = []
    for sentence in sentences:
        corrected, _ged_tags, error_spans = model.correct(sentence)
        fragments.append(_render_result(sentence, corrected, error_spans))
    return "<div class='gec-results'>" + "".join(fragments) + "</div>"


def create_gradio_app():
    """Load the models and build the Gradio interface."""
    model = T5WithGED(DEFAULT_GEC_MODEL, DEFAULT_GED_MODEL)
    return gr.Interface(
        fn=lambda text: process_text(text, model),
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter text to correct grammatical errors...",
            label="Input Text",
        ),
        outputs=gr.HTML(label="Corrected Text"),
        title="Grammar Error Correction with Detection",
        description="""
This app corrects grammatical errors in text using an ensemble of models:
1. An ELECTRA-based Grammatical Error Detection (GED) model identifies error spans
2. A T5-based Grammatical Error Correction (GEC) model corrects the errors

Enter your text and see the corrections with highlighted error spans:
- Red: Replacement needed
- Green: Missing word
- Blue: Unnecessary word
""",
        examples=[
            ["First of all, we can see increasing tendency of overweighting during the hole period."],
            ["Food products were mostly transportaded by the road."],
            ["I have went to the store yesterday. She dont like to study for exams."],
            ["The company have announced a new policy. I am living in London since 2010."],
            ["He didnt studied for the test. They was at the party last night."],
            [
                "The chart illustrates the number in percents of overweight children in Canada "
                "throughout a 20-years period from 1985 to 2005, while the table demonstrates the "
                "percentage of children doing sport exercises regulary over the period from 1990 "
                "to 2005. Overall, it can be seen that despite the fact that the number of boys "
                "and girls performing exercises has grown considerably by the end of the period, "
                "percent of overweight children has increased too. According to the graph, boys "
                "are more likely to have extra weight in period of 2000-2005, a quater of them had "
                "problems with weight in 2005. Girls were going ahead of boys in 1985-1990, then "
                "they maintained the same level in 1995, but then the number of outweight boys "
                "went up more rapidly. The table allows to see that interest in physical activity "
                "has grown by more than 25% both within boys and girls by 2005."
            ],
        ],
        allow_flagging="never",
    )


# Built at module level so importers can still reach ``iface``; the server is
# only started when this file is executed as a script.
iface = create_gradio_app()

if __name__ == "__main__":
    iface.launch()