import torch
import gradio as gr
import nltk
import torch.nn as nn
import numpy as np
import re
import difflib
import copy
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer, ElectraTokenizer, ElectraForTokenClassification
from huggingface_hub import hf_hub_download

nltk.download('punkt')


class T5WithGED(nn.Module):
    def __init__(self, model_path="Zlovoblachko/REAEC_GEC_2step_test",
                 ged_model_path="Zlovoblachko/4tag-electra-grammar-error-detection"):
        super().__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_path)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.has_ged = False
        try:
            # Copy the T5 encoder so that loading the saved GED weights below does
            # not overwrite the main encoder used for the source sentence.
            self.ged_encoder = copy.deepcopy(self.t5.encoder)
            self.gate = nn.Linear(2 * self.t5.config.d_model, 1)
            try:
                ged_components_path = hf_hub_download(repo_id=model_path, filename="ged_components.pt")
                ged_components = torch.load(ged_components_path, map_location=torch.device('cpu'))
                self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
                self.gate.load_state_dict(ged_components["gate"])
                self.has_ged = True
            except Exception as e:
                print(f"Could not load GED components: {e}")
        except Exception as e:
            print(f"Error setting up GED integration: {e}")
        self.ged_model = None
        self.ged_tokenizer = None
        try:
            self.ged_tokenizer = ElectraTokenizer.from_pretrained(ged_model_path)
            self.ged_model = ElectraForTokenClassification.from_pretrained(ged_model_path)
            self.ged_model.eval()
        except Exception as e:
            print(f"Could not load GED model: {e}")

    def get_ged_predictions(self, text):
        """Get GED predictions for a sentence."""
        if self.ged_model is None or self.ged_tokenizer is None:
            return None
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.ged_model(**inputs)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=2)
        token_predictions = predictions[0].cpu().numpy().tolist()
        input_tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        token_pred_pairs = []
        for i, (token, pred) in enumerate(zip(input_tokens, token_predictions)):
            # Skip sub-word pieces and special tokens.
            if token.startswith("##") or token in ["[CLS]", "[SEP]", "[PAD]"]:
                continue
            # Map class ids onto the 4-tag scheme: C(orrect), R(eplace), M(issing), U(nnecessary).
            if pred == 1:
                tag = "R"
            elif pred == 2:
                tag = "M"
            elif pred == 3:
                tag = "U"
            else:
                tag = "C"
            token_pred_pairs.append((token, tag, i))
        ged_tags = [pair[1] for pair in token_pred_pairs]
        # Group consecutive tokens sharing the same error tag into spans.
        error_spans = []
        current_span = None
        for i, (token, tag, token_idx) in enumerate(token_pred_pairs):
            if tag in ["R", "M", "U"]:
                if current_span is None:
                    current_span = {"start_idx": i, "error_type": tag,
                                    "tokens": [token], "token_indices": [token_idx]}
                elif current_span["error_type"] == tag:
                    current_span["tokens"].append(token)
                    current_span["token_indices"].append(token_idx)
                else:
                    error_spans.append(current_span)
                    current_span = {"start_idx": i, "error_type": tag,
                                    "tokens": [token], "token_indices": [token_idx]}
            else:
                if current_span is not None:
                    error_spans.append(current_span)
                    current_span = None
        if current_span is not None:
            error_spans.append(current_span)
        formatted_spans = []
        for span in error_spans:
            formatted_spans.append({
                "text": " ".join(span["tokens"]),
                "type": span["error_type"],
                "tokens": span["tokens"],
                "token_indices": span["token_indices"],
            })
        return " ".join(ged_tags), formatted_spans, input_tokens
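
    # Illustrative note (added, not taken from the model card): the detector emits one
    # tag per token: C = correct, R = needs replacement, M = a word is missing,
    # U = the token is unnecessary. For "He go to to school", one plausible tagging is
    #   He/C go/R to/C to/U school/C
    # which get_ged_predictions() above would group into two spans,
    #   {"text": "go", "type": "R"} and {"text": "to", "type": "U"}.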

    def correct(self, text, use_ged=True, max_length=128):
        """Correct grammatical errors in text."""
        inputs = self.t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        ged_tags = None
        error_spans = None
        if self.has_ged and use_ged and self.ged_model is not None:
            ged_info = self.get_ged_predictions(text)
            if ged_info is not None:
                ged_tags, error_spans, input_tokens = ged_info
            if ged_tags is None:
                # GED produced no prediction: fall back to plain T5 generation.
                output_ids = self.t5.generate(input_ids=inputs.input_ids,
                                              attention_mask=inputs.attention_mask,
                                              max_length=max_length)
                corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
                return corrected_text, None, None
            # Encode the source sentence and the GED tag sequence separately.
            ged_inputs = self.t5_tokenizer(ged_tags, return_tensors="pt",
                                           truncation=True, max_length=max_length)
            src_encoder_outputs = self.t5.encoder(input_ids=inputs.input_ids,
                                                  attention_mask=inputs.attention_mask,
                                                  return_dict=True)
            ged_encoder_outputs = self.ged_encoder(input_ids=ged_inputs.input_ids,
                                                   attention_mask=ged_inputs.attention_mask,
                                                   return_dict=True)
            src_hidden_states = src_encoder_outputs.last_hidden_state
            ged_hidden_states = ged_encoder_outputs.last_hidden_state
            # The two sequences may differ in length; fuse only the overlapping prefix.
            min_len = min(src_hidden_states.size(1), ged_hidden_states.size(1))
            combined = torch.cat([src_hidden_states[:, :min_len, :],
                                  ged_hidden_states[:, :min_len, :]], dim=2)
            gate_scores = torch.sigmoid(self.gate(combined))
            # Gated fusion: λ*src_hidden + (1-λ)*ged_hidden
            combined_hidden = (gate_scores * src_hidden_states[:, :min_len, :] +
                               (1 - gate_scores) * ged_hidden_states[:, :min_len, :])
            src_encoder_outputs.last_hidden_state = combined_hidden
            output_ids = self.t5.generate(encoder_outputs=src_encoder_outputs, max_length=max_length)
        else:
            # GED disabled or unavailable: use the plain T5 model.
            output_ids = self.t5.generate(input_ids=inputs.input_ids,
                                          attention_mask=inputs.attention_mask,
                                          max_length=max_length)
        corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text, ged_tags, error_spans


def find_differences(source, corrected):
    """Find word-level differences between source and corrected text."""
    diff = difflib.ndiff(source.split(), corrected.split())
    changes = []
    for i, s in enumerate(diff):
        if s.startswith('- '):
            changes.append({"type": "deletion", "text": s[2:], "position": i})
        elif s.startswith('+ '):
            changes.append({"type": "addition", "text": s[2:], "position": i})
    return changes
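

# Minimal sketch (added for illustration, not part of the app): how the gated fusion
# inside T5WithGED.correct() combines the two encoder states. The shapes and random
# values below are assumptions chosen only to make the formula concrete.
def _gated_fusion_sketch():
    d_model = 4
    src = torch.randn(1, 3, d_model)  # source-sentence encoder states
    ged = torch.randn(1, 3, d_model)  # GED-tag-sequence encoder states
    gate = nn.Linear(2 * d_model, 1)
    lam = torch.sigmoid(gate(torch.cat([src, ged], dim=2)))  # per-position λ in (0, 1)
    return lam * src + (1 - lam) * ged  # λ*src_hidden + (1-λ)*ged_hidden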


def process_text(text, model):
    """Process input text by splitting into sentences and applying the model."""
    if not text.strip():
        return "Please enter some text."
    try:
        sentences = nltk.sent_tokenize(text)
    except LookupError:
        nltk.download('punkt_tab')
        sentences = nltk.sent_tokenize(text)
    results = []
    for sentence in sentences:
        corrected, ged_tags, error_spans = model.correct(sentence)
        # Create result dictionary
        results.append({
            "original": sentence,
            "corrected": corrected,
            "ged_tags": ged_tags,
            "error_spans": error_spans,
        })
    # Generate HTML output with highlighted errors
    html_output = "<div>"
    for result in results:
        original = result["original"]
        error_spans = result["error_spans"]
        if error_spans:
            # Sort spans by token index for proper display
            error_spans.sort(key=lambda x: x["token_indices"][0])
            # Create a visualization of the original text with error spans marked
            marked_original = original
            replacements = []
            for span in error_spans:
                error_type = span["type"]
                span_text = span["text"]
                # Set color based on error type
                if error_type == "R":
                    color = "#FFCCCC"  # Light red for replacement
                    label = "Replace"
                elif error_type == "M":
                    color = "#CCFFCC"  # Light green for missing
                    label = "Missing"
                elif error_type == "U":
                    color = "#CCCCFF"  # Light blue for unnecessary
                    label = "Unnecessary"
                # Find the span in the original text. Escape first, then relax the
                # spaces, so the \s+ wildcard is not itself escaped.
                pattern = re.escape(span_text).replace(" ", r"\s+")
                matches = list(re.finditer(pattern, marked_original, re.IGNORECASE))
                for match in matches:
                    replacements.append((
                        match.start(),
                        match.end(),
                        f"<span style='background-color: {color};' title='{label}'>{match.group(0)}</span>"
                    ))
            # Apply replacements from end to start to avoid index shifting
            replacements.sort(key=lambda x: x[0], reverse=True)
            for start, end, replacement in replacements:
                marked_original = marked_original[:start] + replacement + marked_original[end:]
            html_output += f"<p><b>Original sentence:</b> {marked_original}</p>"
        else:
            html_output += f"<p><b>Original sentence:</b> {original}</p>"
        # Corrected sentence
        html_output += f"<p><b>Corrected:</b> {result['corrected']}</p>"
        # Find differences for additional visualization
        changes = find_differences(original, result["corrected"])
        if changes:
            html_output += "<p><b>Changes:</b></p>"