"""Grammar-checking utilities for free text, uploaded files, and Q&A spreadsheets.

Every LLM call in this module goes through the module-level ``llm`` client.
"""

import asyncio
import io
import json
import os
import re
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import Annotated, Any, Dict, List, Optional, Tuple, TypedDict

import docx
import pandas as pd
from langchain.prompts import ChatPromptTemplate
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from loguru import logger
from pydantic import BaseModel, Field
from rich.box import ROUNDED
from rich.console import Console
from rich.markup import escape
from rich.table import Table

from models import Error, Grammar

# Shared chat model used by every grammar check in this module.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-001",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # other params...
)

# Alternative Azure OpenAI configuration (credentials from environment variables):
# AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
# AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
# AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
# AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
# llm = AzureChatOpenAI(
#     temperature=0,
#     api_key=AZURE_OPENAI_API_KEY,
#     azure_endpoint=AZURE_OPENAI_ENDPOINT,
#     azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
#     api_version=AZURE_OPENAI_API_VERSION,
# )

# Constants for text splitting
CHUNK_SIZE = 1000  # Approximate characters per chunk
CHUNK_OVERLAP = 0  # Passed to the splitter; SentenceBasedTextSplitter does not use it

# Common tech terms and proper nouns that should not be flagged as errors
DEFAULT_PROPER_NOUNS = """
API, APIs, HTML, CSS, JavaScript, TypeScript, Python, Java, C++, SQL, NoSQL,
MongoDB, PostgreSQL, MySQL, Redis, Docker, Kubernetes, AWS, Azure, GCP, HTTP,
HTTPS, REST, GraphQL, JSON, XML, YAML, React, Angular, Vue, Node.js, Express,
Flask, Django, Spring, TensorFlow, PyTorch, Scikit-learn, npm, pip, GitHub,
GitLab, Bitbucket, Jira, Confluence, Slack, OAuth, JWT, SSL, TLS
"""

# Spreadsheet columns whose text is sent for grammar correction.
_QA_TEXT_FIELDS = (
    "Question",
    "Answer Option A",
    "Answer Option B",
    "Answer Option C",
    "Answer Option D",
)


class GrammarResult(BaseModel):
    """Structured LLM output for a single question/answer dictionary."""

    output: Dict[str, str] = Field(
        ..., description="A dictionary with same keys as the input dictionary."
    )
    wrong_locations: Optional[str] = Field(
        None, description="point out errors briefly. Leave blank if there are no errors."
    )


class BatchGrammarResult(BaseModel):
    """Structured LLM output for one question checked by the batch pipeline."""

    output: Dict[str, Any] = Field(..., description="Dictionary with corrected text")
    wrong_locations: str = Field(..., description="Error descriptions for the question")


def _extract_qa_text_fields(qa_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return the Question / Answer Option A-D entries of *qa_dict* that hold a value.

    Missing keys and NaN cells (as produced by pandas for empty spreadsheet
    cells) are omitted. Key order follows ``_QA_TEXT_FIELDS``.
    """
    return {
        field: qa_dict[field]
        for field in _QA_TEXT_FIELDS
        if field in qa_dict and not pd.isna(qa_dict[field])
    }


def check_grammar_question(data: Dict[str, Any]) -> GrammarResult:
    """Ask the LLM to correct the grammar of one question/answer dictionary.

    Args:
        data: Mapping of field name (e.g. "Question") to the text to check.

    Returns:
        A GrammarResult whose ``output`` is expected to have the same keys as
        *data* and whose ``wrong_locations`` briefly describes the errors
        (None when there are none). An empty *data* short-circuits without
        calling the LLM.
    """
    if not data:
        # Nothing to check: avoid a pointless (billable) LLM round-trip.
        return GrammarResult(output={}, wrong_locations=None)

    system_message = """
    You are a spellchecker for a question and answer pair. Related to IT and programming.
    You will be given a question and answer pair.
    You will need to check the grammar of the question and answer pair.
    You will need to return the corrected question and answer pair in a dictionary.
    If any of the fields are not errors, you should return the original value.
    Output should be a dictionary with same keys as the input dictionary.
    """
    input_message = """
    Here are input dictionary:
    {data}
    """
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", input_message)]
    )
    chain = prompt | llm.with_structured_output(GrammarResult)
    return chain.invoke({"data": data})


def check_grammar_qa(
    qa_dict: Dict[str, Any], proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> GrammarResult:
    """Check grammar for the Question and Answer Option A-D fields of a QA row.

    Args:
        qa_dict: Dictionary containing the question and answer options; other
            keys are ignored.
        proper_nouns: Accepted for API compatibility; currently not forwarded
            to the LLM prompt.

    Returns:
        The GrammarResult produced by ``check_grammar_question``.
    """
    return check_grammar_question(_extract_qa_text_fields(qa_dict))


def extract_text_from_docx(file_content: bytes) -> str:
    """Extract paragraph text from a .docx file.

    Args:
        file_content: The bytes content of the .docx file.

    Returns:
        The text of every paragraph, joined with newlines.

    Raises:
        Exception: If the document cannot be parsed.
    """
    try:
        # python-docx accepts a file-like object, so no temporary file (and no
        # cleanup path that could leak it on error) is needed.
        document = docx.Document(io.BytesIO(file_content))
        return "\n".join(paragraph.text for paragraph in document.paragraphs)
    except Exception as e:
        logger.error(f"Error extracting text from docx: {str(e)}")
        raise Exception(f"Failed to extract text from docx: {str(e)}") from e


def extract_text_from_file(file_content: bytes, file_extension: str) -> str:
    """Extract text from a file based on its extension.

    Args:
        file_content: The bytes content of the file.
        file_extension: The file extension (".txt" or ".docx", case-insensitive).

    Returns:
        The extracted text as a string.

    Raises:
        ValueError: If the extension is not supported.
    """
    extension = file_extension.lower()
    if extension == ".txt":
        # Invalid UTF-8 bytes are replaced rather than aborting the check.
        return file_content.decode("utf-8", errors="replace")
    if extension == ".docx":
        return extract_text_from_docx(file_content)
    raise ValueError(f"Unsupported file extension: {file_extension}")


class SentenceBasedTextSplitter(RecursiveCharacterTextSplitter):
    """Split text into chunks of whole sentences of at most ``chunk_size`` chars.

    A sentence ends at '.', '!' or '?' followed by whitespace. A single
    sentence longer than ``chunk_size`` becomes its own (oversized) chunk.
    ``chunk_overlap`` is accepted for compatibility but not used.
    """

    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

    def __init__(self, chunk_size: int, chunk_overlap: int = 0):
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.chunk_size = chunk_size

    def split_text(self, text: str) -> List[str]:
        """Return the sentence-aligned chunks of *text*; empty chunks are dropped."""
        # Keep each sentence's trailing whitespace attached so newlines and
        # paragraph/markdown layout survive; this lets LLM "before" snippets
        # still match the original text.
        pieces: List[str] = []
        start = 0
        for boundary in self._SENTENCE_BOUNDARY.finditer(text):
            pieces.append(text[start:boundary.end()])
            start = boundary.end()
        pieces.append(text[start:])

        chunks: List[str] = []
        current = ""
        for piece in pieces:
            if current.strip() and len(current) + len(piece.rstrip()) > self.chunk_size:
                chunks.append(current.strip())
                current = ""
            current += piece
        if current.strip():
            chunks.append(current.strip())
        return chunks


def split_text(text: str) -> List[str]:
    """Split text into sentence-aligned chunks of about ``CHUNK_SIZE`` characters.

    Args:
        text: The full text to split.

    Returns:
        A list of non-empty text chunks (empty for blank input).
    """
    splitter = SentenceBasedTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    chunks = splitter.split_text(text)
    logger.debug(f"Split text into {len(chunks)} chunks")
    return chunks


def create_grammar_prompt(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> str:
    """Create a grammar-checking prompt for *text* that lists protected proper nouns.

    Args:
        text: The text to check for grammar issues.
        proper_nouns: A string of proper nouns to preserve.

    Returns:
        The formatted prompt string.
    """
    return f"""
    Rewrite the provided text to be clear and grammatically correct while preserving technical accuracy.

    Focus on:
    1. Correcting spelling, punctuation, and grammar errors
    2. Maintaining technical terminology and code snippets
    3. Ensuring consistent tense, voice, and formatting
    4. Clarifying function descriptions, parameters, and return values
    5. Proper use of capitalization, acronyms, and abbreviations
    6. Improving clarity and conciseness
    7. Respect markdown and code formatting such as underscores, asterisks, backticks, code blocks, and links
    8. Ensure proper nouns and acronyms are correctly spelled and capitalized

    Here's a list of proper nouns and technical terms you should preserve:
    {proper_nouns}

    Preserve code-specific formatting and syntax. Prioritize original text if unsure about technical terms.
    Make sure when you show the before vs after text, include a larger phrase or sentence for context.

    In the response:
    - For 'spelling', 'punctuation', and 'grammar' keys: Provide only changed items with original text, corrected text, and explanation.

    Ensure that the original text is actually referenced from the given text below:
    {text}
    """


def process_api_response(content: str) -> Dict[str, List[Dict[str, str]]]:
    """Extract the JSON object embedded in an LLM response.

    Everything between the first '{' and the last '}' is parsed, which
    tolerates surrounding prose or code fences.

    Args:
        content: The raw LLM response content.

    Returns:
        The parsed dictionary. If the JSON is malformed or not an object, an
        empty ``{"spelling": [], "punctuation": [], "grammar": []}`` structure
        is returned (the problem is logged).

    Raises:
        ValueError: If the response contains no '{...}' span at all.
    """
    json_start = content.find("{")
    json_end = content.rfind("}") + 1
    if json_start == -1 or json_end == 0:
        logger.error(f"Could not find JSON in response: {content}")
        raise ValueError("API response did not contain valid JSON")

    json_str = content[json_start:json_end]
    try:
        result = json.loads(json_str)
    except json.JSONDecodeError as je:
        logger.error(f"JSON decode error: {str(je)}")
        logger.error(f"JSON string was: {json_str}")
        return {"spelling": [], "punctuation": [], "grammar": []}

    if not isinstance(result, dict):
        logger.error(f"Expected a JSON object, got: {type(result).__name__}")
        return {"spelling": [], "punctuation": [], "grammar": []}
    return result


def merge_grammar_results(
    results: List[Dict[str, List[Dict[str, str]]]],
) -> Dict[str, List[Dict[str, str]]]:
    """Concatenate the per-category error lists of several grammar results.

    Args:
        results: Grammar check results, e.g. one per text chunk.

    Returns:
        A single result with 'spelling', 'punctuation' and 'grammar' lists.
        Other keys are dropped; missing or null categories contribute nothing.
    """
    merged: Dict[str, List[Dict[str, str]]] = {
        "spelling": [],
        "punctuation": [],
        "grammar": [],
    }
    for result in results:
        for category, errors in merged.items():
            errors.extend(result.get(category) or [])
    return merged


def validate_corrections(
    result: Dict[str, List[Dict[str, str]]],
) -> Dict[str, List[Dict[str, str]]]:
    """Keep only well-formed corrections that actually change the text.

    An entry is dropped when it is not a dict, when 'before'/'after' are
    missing or not strings, or when the two differ only in surrounding
    whitespace. A missing 'explanation' is filled with "" so that
    ``Error(**entry)`` can be built.

    Args:
        result: The grammar check result.

    Returns:
        A new result containing only the validated entries.
    """
    validated: Dict[str, List[Dict[str, str]]] = {
        "spelling": [],
        "punctuation": [],
        "grammar": [],
    }
    for category, kept in validated.items():
        for error in result.get(category) or []:
            if not isinstance(error, dict):
                continue
            before, after = error.get("before"), error.get("after")
            if not isinstance(before, str) or not isinstance(after, str):
                logger.warning(f"Skipping malformed {category} correction: {error}")
                continue
            # Covers identical strings as well as whitespace-only changes.
            if before.strip() == after.strip():
                continue
            if "explanation" not in error:
                error = {**error, "explanation": ""}
            kept.append(error)
    return validated


def apply_corrections(original_text: str, errors: List[Error]) -> str:
    """Apply grammar corrections to the original text.

    Each correction replaces one occurrence of its ``before`` text. Earlier
    entries in *errors* take priority: a correction whose span would overlap
    an already-claimed span is moved to the next non-overlapping occurrence,
    or skipped if there is none. Corrections with an empty ``before`` or
    whose ``before`` is not found are skipped.

    Args:
        original_text: The original text with errors.
        errors: Error objects with before/after corrections.

    Returns:
        The corrected text.
    """
    claimed: List[Tuple[int, int, str]] = []  # (start, end, replacement)
    for error in errors:
        before = error.before
        if not before:
            # find("") would match at 0 and insert text at the very start.
            continue
        start = original_text.find(before)
        while start != -1:
            end = start + len(before)
            if all(end <= s or start >= e for s, e, _ in claimed):
                claimed.append((start, end, error.after))
                break
            start = original_text.find(before, start + 1)

    claimed.sort(key=lambda span: span[0])
    pieces: List[str] = []
    cursor = 0
    for start, end, replacement in claimed:
        pieces.append(original_text[cursor:start])
        pieces.append(replacement)
        cursor = end
    pieces.append(original_text[cursor:])
    return "".join(pieces)


def check_grammar(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> Grammar:
    """Check the grammar of *text* using the module-level LLM.

    The text is split into sentence-aligned chunks. Each chunk is sent to the
    LLM (batched when there are several), and the JSON replies are merged,
    validated and applied to the original text.

    Args:
        text: The text to check for grammar issues.
        proper_nouns: A string of proper nouns to preserve.

    Returns:
        A Grammar object containing categorized errors and the corrected text.

    Raises:
        Exception: Wraps any failure during analysis.
    """
    try:
        chunks = split_text(text)

        system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
        The JSON object must use the schema that has 'spelling', 'punctuation', and 'grammar' keys, each with a list of objects containing 'before', 'after', and 'explanation'.
        It is strictly imperative that you return as JSON. DO NOT return any other characters other than valid JSON as your response."""
        prompt = ChatPromptTemplate.from_messages(
            [("system", system_message), ("user", "{prompt}")]
        )
        chain = prompt | llm

        logger.debug(f"Processing {len(chunks)} chunks in batch...")
        prompt_batch = [
            {"prompt": create_grammar_prompt(chunk, proper_nouns)} for chunk in chunks
        ]
        if not prompt_batch:
            responses = []  # blank input: nothing to send
        elif len(prompt_batch) == 1:
            responses = [chain.invoke(prompt_batch[0])]
        else:
            responses = chain.batch(prompt_batch)
            logger.debug(f"Received {len(responses)} batch responses from API")

        result = merge_grammar_results(
            [process_api_response(response.content) for response in responses]
        )
        validated_result = validate_corrections(result)

        spelling_errors = [Error(**err) for err in validated_result.get("spelling", [])]
        punctuation_errors = [
            Error(**err) for err in validated_result.get("punctuation", [])
        ]
        grammar_errors = [Error(**err) for err in validated_result.get("grammar", [])]

        corrected_text = apply_corrections(
            text, spelling_errors + punctuation_errors + grammar_errors
        )

        return Grammar(
            spelling=spelling_errors,
            punctuation=punctuation_errors,
            grammar=grammar_errors,
            file_path="",  # Filled in by check_grammar_from_file for uploads
            corrected_text=corrected_text,
        )
    except Exception as e:
        logger.error(f"Error checking grammar: {str(e)}")
        raise Exception(f"Failed to analyze text: {str(e)}") from e


def check_grammar_from_file(
    file_content: bytes, filename: str, proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Grammar:
    """Check grammar of an uploaded .txt or .docx file.

    Args:
        file_content: The bytes content of the file.
        filename: The name of the uploaded file; its extension selects the parser.
        proper_nouns: A string of proper nouns to preserve.

    Returns:
        A Grammar object whose ``file_path`` is set to *filename*.

    Raises:
        Exception: Wraps any extraction or analysis failure.
    """
    try:
        _, file_extension = os.path.splitext(filename)
        text = extract_text_from_file(file_content, file_extension)
        grammar_result = check_grammar(text, proper_nouns)
        grammar_result.file_path = filename
        return grammar_result
    except Exception as e:
        logger.error(f"Error checking grammar from file: {str(e)}")
        raise Exception(f"Failed to analyze file: {str(e)}") from e


def display_results(response: Grammar, path: str = "", repo_link: str = "") -> int:
    """Render the grammar check results to the terminal using Rich.

    Args:
        response: The Grammar object with check results.
        path: Path to display for the checked file.
        repo_link: Optional repository URL; when set, the path is shown as a
            GitHub ``blob/main`` link.

    Returns:
        The total number of corrections displayed (entries whose before and
        after differ).
    """
    if repo_link and response.file_path:
        # NOTE(review): only the base name is used, so files in subdirectories
        # get an incorrect URL; the repo-relative path is not known here.
        relative_path = os.path.basename(response.file_path)
        path = f"{repo_link.rstrip('/')}/blob/main/{relative_path}"
    elif not path:
        path = response.file_path or "Text input"

    # Builtin print() would show the Table object's repr, not the table.
    console = Console()
    console.print(f"File: {path}", markup=False)

    total_errors = 0
    for category in ["spelling", "punctuation", "grammar"]:
        table = Table(title=f"{category.capitalize()} Corrections", box=ROUNDED)
        table.add_column("Original", justify="left", style="bold red")
        table.add_column("Corrected", justify="left", style="bold green")
        table.add_column("Explanation", justify="left", style="italic")

        shown = 0
        for error in getattr(response, category):
            if error.before == error.after:
                continue
            # Escape so code such as `arr[i]` is not parsed as Rich markup.
            table.add_row(
                escape(error.before),
                escape(error.after),
                escape(error.explanation or ""),
            )
            table.add_row("", "", "")  # Empty row for spacing
            shown += 1
        total_errors += shown

        if shown:
            console.print(table)
        else:
            console.print(f"No {category} errors found.")

    return total_errors


def _merge_batch_result(sent: Dict[str, Any], result: Any) -> Dict[str, Any]:
    """Combine one LLM batch result with the row that was sent.

    Only the Question/Answer Option fields are taken from the LLM output
    (falling back to the sent value when the LLM omitted a key). Metadata
    such as "No." or "Answer" is always kept from the sent row, so the LLM
    cannot alter it. A failed call (an exception or None) keeps the original
    text and records the failure in 'wrong_locations'.
    """
    if isinstance(result, Exception) or result is None:
        reason = result if result is not None else "no structured output returned"
        logger.error(f"Grammar check failed for question {sent.get('No.')}: {reason}")
        return {**sent, "wrong_locations": f"Grammar check failed: {reason}"}

    corrected = result.output or {}
    merged = {
        key: corrected.get(key, value) if key in _QA_TEXT_FIELDS else value
        for key, value in sent.items()
    }
    merged["wrong_locations"] = result.wrong_locations
    return merged


def check_grammar_questions_batch(
    questions: List[Dict[str, Any]], batch_size: int = 5
) -> List[Dict[str, Any]]:
    """Grammar-check many QA rows, sending ``batch_size`` LLM requests at a time.

    Args:
        questions: Question dictionaries (e.g. spreadsheet rows).
        batch_size: Number of questions sent per ``chain.batch`` call.

    Returns:
        One dictionary per question, in input order. Each holds the corrected
        Question/Answer Option fields, the original "No.", "Training content"
        and "Answer" values, and a "wrong_locations" description.

    Raises:
        ValueError: If *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    system_message = """
    You are a spellchecker for a batch of questions and answers related to IT and programming.
    You will be given multiple question and answer pairs.
    Check the grammar of each question and answer pair.
    Return a list of dictionaries with the same structure as the input, but with corrected text.
    If any fields have no errors, return the original value.
    """
    input_message = """
    Here is a question to check:
    {data}
    """
    # Built once and reused for every batch.
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", input_message)]
    )
    chain = prompt | llm.with_structured_output(BatchGrammarResult)

    def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        prompts = [{"data": question} for question in batch]
        logger.debug(f"prompt {prompts}")
        # return_exceptions keeps one failed request from aborting the whole run.
        results = chain.batch(prompts, return_exceptions=True)
        return [_merge_batch_result(sent, result) for sent, result in zip(batch, results)]

    # Send only the checked fields plus the metadata columns to the LLM.
    preprocessed_questions = [
        {
            **_extract_qa_text_fields(qa_dict),
            "No.": qa_dict.get("No."),
            "Training content": qa_dict.get("Training content"),
            "Answer": qa_dict.get("Answer"),
        }
        for qa_dict in questions
    ]

    results: List[Dict[str, Any]] = []
    total_batches = (len(preprocessed_questions) + batch_size - 1) // batch_size
    logger.info(
        f"Processing {len(preprocessed_questions)} questions in {total_batches} batches"
    )
    for i in range(0, len(preprocessed_questions), batch_size):
        batch = preprocessed_questions[i:i + batch_size]
        batch_num = (i // batch_size) + 1
        logger.info(f"Processing batch {batch_num}/{total_batches} with {len(batch)} questions")
        results.extend(process_batch(batch))

    return results


def process_grammar_check(
    input_file: str,
    output_file: str,
    limit: Optional[int] = None,
    sheet_name: str = "Sheet1",
    batch_size: int = 30,
) -> str:
    """Grammar-check the QA rows of an Excel sheet and write the results to a new file.

    Args:
        input_file: Path to the input Excel file.
        output_file: Path of the output Excel file.
        limit: Process only the first *limit* rows; None processes all rows.
        sheet_name: Worksheet to read from *input_file*.
        batch_size: Number of questions per LLM batch.

    Returns:
        The path of the written output file.
    """
    df = pd.read_excel(input_file, sheet_name=sheet_name)
    records = df.to_dict(orient="records")
    if limit is not None:
        records = records[:limit]

    processed_records = check_grammar_questions_batch(records, batch_size=batch_size)

    output_df = pd.DataFrame(processed_records)
    output_df.to_excel(output_file, index=False)
    return output_file