Spaces:
Sleeping
Sleeping
| import os | |
| from loguru import logger | |
| import json | |
| import tempfile | |
| from typing import List, Dict, Any, Annotated, Optional | |
| from langchain_openai import AzureChatOpenAI | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from models import Grammar, Error | |
| import docx | |
| from rich.table import Table | |
| from rich.box import ROUNDED | |
| import re | |
| import pandas as pd | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from langchain_google_genai.chat_models import ChatGoogleGenerativeAI | |
| from pydantic import BaseModel, Field | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
# Module-level LLM client shared by every grammar-checking helper below.
# NOTE(review): max_tokens/timeout are unbounded — confirm this is acceptable
# for production workloads.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-001",
    temperature=0,  # deterministic output so corrections are reproducible
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # other params...
)
# Previous Azure OpenAI configuration, kept for reference while the module
# uses the Gemini client above.
# AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
# AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
# AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
# AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
# llm = AzureChatOpenAI(
#     temperature=0,
#     api_key=AZURE_OPENAI_API_KEY,
#     azure_endpoint=AZURE_OPENAI_ENDPOINT,
#     azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
#     api_version=AZURE_OPENAI_API_VERSION,
# )

# Constants for text splitting
CHUNK_SIZE = 1000  # Approximate characters per chunk (roughly one page)
CHUNK_OVERLAP = 0  # Overlap between chunks to maintain context

# Common tech terms and proper nouns that should not be flagged as errors.
# Injected verbatim into the grammar prompt — keep the wording stable.
DEFAULT_PROPER_NOUNS = """
API, APIs, HTML, CSS, JavaScript, TypeScript, Python, Java, C++, SQL, NoSQL,
MongoDB, PostgreSQL, MySQL, Redis, Docker, Kubernetes, AWS, Azure, GCP,
HTTP, HTTPS, REST, GraphQL, JSON, XML, YAML, React, Angular, Vue, Node.js,
Express, Flask, Django, Spring, TensorFlow, PyTorch, Scikit-learn, npm, pip,
GitHub, GitLab, Bitbucket, Jira, Confluence, Slack, OAuth, JWT, SSL, TLS
"""
| from typing import TypedDict, Dict | |
class GrammarQuestionResult(BaseModel):
    """Structured-output schema for a single grammar-checked question."""

    # Corrected text, keyed exactly like the input dictionary.
    output: Dict[str, str] = Field(
        ..., description="A dictionary with same keys as the input dictionary."
    )
    # Brief free-text description of the errors found (None when clean).
    wrong_locations: Optional[str] = Field(
        None, description="point out errors briefly. Leave blank if there are no errors."
    )


def check_grammar_question(data: Dict[str, Any]) -> GrammarQuestionResult:
    """
    Run one question/answer dictionary through the LLM spellchecker.

    Args:
        data: Mapping of field name (e.g. "Question", "Answer Option A")
            to the text that should be checked.

    Returns:
        A GrammarQuestionResult whose ``output`` mirrors the input keys with
        corrected text and whose ``wrong_locations`` briefly lists any errors.
        (Bug fix: the previous ``Dict[str, str]`` return annotation was wrong —
        the structured-output chain returns the pydantic model, not a dict.)
    """
    system_message = """
    You are a spellchecker for a question and answer pair. Related to IT and programming.
    You will be given a question and answer pair.
    You will need to check the grammar of the question and answer pair.
    You will need to return the corrected question and answer pair in a dictionary. If any of the fields are not errors, you should return the original value.
    Output should be a dictionary with same keys as the input dictionary.
    """
    input_message = """
    Here are input dictionary:
    {data}
    """
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", input_message)]
    )
    # Bind the schema so the LLM is forced to emit a GrammarQuestionResult.
    chain = prompt | llm.with_structured_output(GrammarQuestionResult)
    result = chain.invoke({"data": data})
    return result
def check_grammar_qa(
    qa_dict: Dict[str, Any], proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Dict[str, str]:
    """
    Check grammar for a QA dictionary and return corrected text.

    Args:
        qa_dict: Dictionary containing question and answer options
        proper_nouns: A string of proper nouns to preserve

    Returns:
        Dictionary with corrected text for each field
    """
    # Only the question and the four answer options are sent for checking;
    # fields that are absent or NaN (pandas missing values) are dropped.
    checked_fields = [
        "Question",
        "Answer Option A",
        "Answer Option B",
        "Answer Option C",
        "Answer Option D",
    ]
    payload = {
        field: qa_dict[field]
        for field in checked_fields
        if field in qa_dict and not pd.isna(qa_dict[field])
    }
    return check_grammar_question(payload)
def extract_text_from_docx(file_content: bytes) -> str:
    """
    Extract text from a .docx file.

    Args:
        file_content: The bytes content of the .docx file

    Returns:
        The extracted text as a string (paragraphs joined with newlines)

    Raises:
        Exception: If the document cannot be written or parsed.
    """
    try:
        # Persist the bytes to a temporary file because python-docx opens
        # documents from a path / file object, not raw bytes.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as temp_file:
            temp_file.write(file_content)
            temp_file_path = temp_file.name
        try:
            doc = docx.Document(temp_file_path)
            full_text = [para.text for para in doc.paragraphs]
        finally:
            # Bug fix: always remove the temp file, even when parsing raises;
            # previously a parse failure leaked the file on disk.
            os.unlink(temp_file_path)
        return "\n".join(full_text)
    except Exception as e:
        logger.error(f"Error extracting text from docx: {str(e)}")
        raise Exception(f"Failed to extract text from docx: {str(e)}") from e
def extract_text_from_file(file_content: bytes, file_extension: str) -> str:
    """
    Dispatch text extraction based on the file extension.

    Args:
        file_content: The bytes content of the file
        file_extension: The file extension including the dot (.txt, .docx, ...)

    Returns:
        The extracted text as a string

    Raises:
        ValueError: For any extension other than .txt or .docx.
    """
    ext = file_extension.lower()
    if ext == ".txt":
        # Plain text: decode directly, substituting undecodable bytes.
        return file_content.decode("utf-8", errors="replace")
    if ext == ".docx":
        # Word documents go through the python-docx based extractor.
        return extract_text_from_docx(file_content)
    raise ValueError(f"Unsupported file extension: {file_extension}")
class SentenceBasedTextSplitter(RecursiveCharacterTextSplitter):
    """Splitter that packs whole sentences into chunks of roughly chunk_size chars."""

    # Whitespace that immediately follows a sentence-terminating ., ! or ?
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

    def __init__(self, chunk_size: int, chunk_overlap: int = 0):
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.chunk_size = chunk_size

    def split_text(self, text: str):
        """Greedily pack sentences into chunks, never cutting a sentence in half."""
        sentences = self._SENTENCE_BOUNDARY.split(text)
        chunks = []
        buffer = ""
        for sentence in sentences:
            # Start a new chunk once adding this sentence would overflow.
            if len(buffer) + len(sentence) <= self.chunk_size:
                buffer += sentence + " "
            else:
                if buffer:
                    chunks.append(buffer.strip())
                buffer = sentence + " "
        # Flush whatever is left in the buffer as the final chunk.
        if buffer:
            chunks.append(buffer.strip())
        return chunks
def split_text(text: str) -> List[str]:
    """
    Split text into chunks of appropriate size for processing.

    Args:
        text: The full text to split

    Returns:
        A list of text chunks produced by the sentence-based splitter
    """
    # Sentence-aware splitting keeps each chunk readable for the LLM.
    splitter = SentenceBasedTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    pieces = splitter.split_text(text)
    logger.debug(f"Split text into {len(pieces)} chunks")
    return pieces
def create_grammar_prompt(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> str:
    """
    Create a grammar checking prompt for the given text with proper nouns.

    Args:
        text: The text to check for grammar issues
        proper_nouns: A string of proper nouns to preserve

    Returns:
        A formatted prompt string with the proper-noun list and the target
        text interpolated in
    """
    # The prompt body is sent verbatim to the LLM; keep the wording stable.
    return f"""
    Rewrite the provided text to be clear and grammatically correct while preserving technical accuracy. Focus on:
    1. Correcting spelling, punctuation, and grammar errors
    2. Maintaining technical terminology and code snippets
    3. Ensuring consistent tense, voice, and formatting
    4. Clarifying function descriptions, parameters, and return values
    5. Proper use of capitalization, acronyms, and abbreviations
    6. Improving clarity and conciseness
    7. Respect markdown and code formatting such as underscores, asterisks, backticks, code blocks, and links
    8. Ensure proper nouns and acronyms are correctly spelled and capitalized
    Here's a list of proper nouns and technical terms you should preserve:
    {proper_nouns}
    Preserve code-specific formatting and syntax. Prioritize original text if unsure about technical terms.
    Make sure when you show the before vs after text, include a larger phrase or sentence for context.
    In the response:
    - For 'spelling', 'punctuation', and 'grammar' keys: Provide only changed items with original text, corrected text, and explanation.
    Ensure that the original text is actually referenced from the given text below:
    {text}
    """
def process_api_response(content: str) -> Dict[str, List[Dict[str, str]]]:
    """
    Extract and parse the JSON payload from a raw LLM response.

    Args:
        content: The API response content, possibly with extra text around
            the JSON object

    Returns:
        A dictionary with grammar error categories; an empty-but-well-formed
        structure when the located span is not valid JSON

    Raises:
        ValueError: When no JSON object can be found in the response.
    """
    # Treat everything between the first '{' and the last '}' as the payload.
    start = content.find("{")
    end = content.rfind("}") + 1
    if start == -1 or end == 0:
        logger.error(f"Could not find JSON in response: {content}")
        raise ValueError("API response did not contain valid JSON")
    candidate = content[start:end]
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as je:
        logger.error(f"JSON decode error: {str(je)}")
        logger.error(f"JSON string was: {candidate}")
        # Fall back to an empty result rather than aborting the whole run.
        return {"spelling": [], "punctuation": [], "grammar": []}
def merge_grammar_results(
    results: List[Dict[str, List[Dict[str, str]]]],
) -> Dict[str, List[Dict[str, str]]]:
    """
    Merge multiple grammar check results into a single result.

    Args:
        results: A list of grammar check results (one per text chunk)

    Returns:
        A merged grammar check result; within each category, errors keep
        the order of the input results
    """
    merged = {"spelling": [], "punctuation": [], "grammar": []}
    for category, bucket in merged.items():
        for partial in results:
            # Missing categories in a partial result are simply skipped.
            bucket.extend(partial.get(category, []))
    return merged
def validate_corrections(
    result: Dict[str, List[Dict[str, str]]],
) -> Dict[str, List[Dict[str, str]]]:
    """
    Validate grammar corrections to ensure they're meaningful.

    Drops any correction whose before/after texts are identical or differ
    only in surrounding whitespace.

    Args:
        result: The grammar check result

    Returns:
        Validated grammar check result containing only real changes
    """
    validated = {"spelling": [], "punctuation": [], "grammar": []}
    for category, bucket in validated.items():
        for error in result.get(category, []):
            # A whitespace-insensitive comparison also rejects exact
            # duplicates, so one check covers both no-op cases.
            if error["before"].strip() != error["after"].strip():
                bucket.append(error)
    return validated
def apply_corrections(original_text: str, errors: List[Error]) -> str:
    """
    Apply all grammar corrections to the original text.

    Args:
        original_text: The original text with errors
        errors: List of Error objects with before/after corrections

    Returns:
        Fully corrected text; errors whose 'before' text is not found are
        silently skipped
    """
    corrected = original_text
    # Locate the first occurrence of each error's 'before' text.
    located = []
    for err in errors:
        idx = corrected.find(err.before)
        if idx != -1:
            located.append((idx, err))
    # Substitute from the end of the string backwards so that earlier
    # replacements never shift the offsets of the ones still pending.
    for idx, err in sorted(located, key=lambda pair: pair[0], reverse=True):
        corrected = corrected[:idx] + err.after + corrected[idx + len(err.before):]
    return corrected
def check_grammar(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> Grammar:
    """
    Check the grammar of the given text using the module-level LLM client.

    The text is split into sentence-aligned chunks, each chunk is sent to the
    LLM with a JSON-only system prompt, and the per-chunk results are merged,
    validated, and applied back onto the original text.
    (Doc fix: previously claimed Azure OpenAI; the module is configured with
    the Gemini client.)

    Args:
        text: The text to check for grammar issues
        proper_nouns: A string of proper nouns to preserve

    Returns:
        Grammar object containing categorized errors and the corrected text

    Raises:
        Exception: If any step of the analysis fails.
    """
    try:
        # Split text into chunks if it's too long
        chunks = split_text(text)
        # System prompt forcing a strict JSON response shape.
        system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
The JSON object must use the schema that has 'spelling', 'punctuation', and 'grammar' keys, each with a list of objects containing 'before', 'after', and 'explanation'.
It is strictly imperative that you return as JSON. DO NOT return any other characters other than valid JSON as your response."""
        # Create a prompt template and chain using the pipe syntax
        prompt = ChatPromptTemplate.from_messages(
            [("system", system_message), ("user", "{prompt}")]
        )
        chain = prompt | llm
        logger.debug(f"Processing {len(chunks)} chunks in batch...")
        if len(chunks) == 1:
            # Single chunk: one direct call is cheaper than a batch.
            prompt_text = create_grammar_prompt(chunks[0], proper_nouns)
            response = chain.invoke({"prompt": prompt_text})
            result = process_api_response(response.content)
        else:
            # Multiple chunks: fan out with batch processing, then merge.
            prompt_batch = [
                {"prompt": create_grammar_prompt(chunk, proper_nouns)}
                for chunk in chunks
            ]
            responses = chain.batch(prompt_batch)
            logger.debug(f"Received {len(responses)} batch responses from API")
            results = [process_api_response(r.content) for r in responses]
            result = merge_grammar_results(results)
        # Drop no-op corrections before building Error objects.
        validated_result = validate_corrections(result)
        spelling_errors = [Error(**err) for err in validated_result.get("spelling", [])]
        punctuation_errors = [
            Error(**err) for err in validated_result.get("punctuation", [])
        ]
        grammar_errors = [Error(**err) for err in validated_result.get("grammar", [])]
        # Apply all corrections to produce the fully corrected text.
        corrected_text = apply_corrections(
            text, spelling_errors + punctuation_errors + grammar_errors
        )
        return Grammar(
            spelling=spelling_errors,
            punctuation=punctuation_errors,
            grammar=grammar_errors,
            file_path="",  # Filled in later for file uploads
            corrected_text=corrected_text,
        )
    except Exception as e:
        logger.error(f"Error checking grammar: {str(e)}")
        # Chain the cause so the original traceback isn't lost.
        raise Exception(f"Failed to analyze text: {str(e)}") from e
def check_grammar_from_file(
    file_content: bytes, filename: str, proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Grammar:
    """
    Check grammar from an uploaded file.

    Args:
        file_content: The bytes content of the file
        filename: The name of the uploaded file; its extension selects the
            text extractor (.txt or .docx)
        proper_nouns: A string of proper nouns to preserve

    Returns:
        Grammar object containing categorized errors, with file_path set to
        the uploaded filename

    Raises:
        Exception: If extraction or analysis fails.
    """
    try:
        _, file_extension = os.path.splitext(filename)
        text = extract_text_from_file(file_content, file_extension)
        # Check grammar on the extracted text
        grammar_result = check_grammar(text, proper_nouns)
        # Record which file the results came from
        grammar_result.file_path = filename
        return grammar_result
    except Exception as e:
        logger.error(f"Error checking grammar from file: {str(e)}")
        # Chain the cause so the original traceback isn't lost.
        raise Exception(f"Failed to analyze file: {str(e)}") from e
def display_results(response: Grammar, path: str = "", repo_link: str = "") -> int:
    """
    Display the grammar check results using Rich.

    Args:
        response: The Grammar object with check results
        path: Path to the file that was checked
        repo_link: Optional repository link (for GitHub URLs)

    Returns:
        Total number of errors found
    """
    # Resolve a display path for the checked content.
    if repo_link and response.file_path:
        # NOTE(review): only the basename is used, so files in subdirectories
        # all map to the repository root — confirm the repo layout is flat.
        relative_path = os.path.basename(response.file_path)
        path = f"{repo_link.rstrip('/')}/blob/main/{relative_path}"
    elif path:
        # Caller-supplied path wins.
        pass
    elif response.file_path:
        # Fall back to the file path recorded on the result.
        path = response.file_path
    else:
        # Default label for raw text input.
        path = "Text input"
    total_errors = 0
    # Display each error category in its own table.
    for category in ["spelling", "punctuation", "grammar"]:
        table = Table(title=f"{category.capitalize()} Corrections", box=ROUNDED)
        table.add_column("Original", justify="left", style="bold red")
        table.add_column("Corrected", justify="left", style="bold green")
        table.add_column("Explanation", justify="left", style="italic")
        errors = getattr(response, category)
        for error in errors:
            if error.before != error.after:
                table.add_row(error.before, error.after, error.explanation)
                table.add_row("", "", "")  # Add an empty row for spacing
                total_errors += 1
        if errors:
            print(table)
        else:
            # Bug fix: this message was previously built but never shown.
            print(f"No {category} errors found.")
    return total_errors
def check_grammar_questions_batch(questions: List[Dict[str, Any]], batch_size: int = 5) -> List[Dict[str, Any]]:
    """
    Process multiple questions in batches for grammar checking.

    Args:
        questions: List of question dictionaries to process
        batch_size: Number of questions to process in each batch

    Returns:
        List of processed question dictionaries with grammar corrections,
        each augmented with a "wrong_locations" field describing the errors
    """
    system_message = """
    You are a spellchecker for a batch of questions and answers related to IT and programming.
    You will be given multiple question and answer pairs.
    Check the grammar of each question and answer pair.
    Return a list of dictionaries with the same structure as the input, but with corrected text.
    If any fields have no errors, return the original value.
    """

    def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # One prompt per question; chain.batch fans them out in parallel below.
        input_message = """
        Here is a question to check:
        {data}
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("user", input_message)
        ])

        class BatchGrammarResult(BaseModel):
            # Corrected question/answer fields, mirroring the input keys.
            output: Dict[str, Any] = Field(
                ..., description="Dictionary with corrected text"
            )
            # Free-text description of the errors found in this question.
            wrong_locations: str = Field(
                ..., description="Error descriptions for the question"
            )

        chain = prompt | llm.with_structured_output(BatchGrammarResult)
        # Create prompts for each question in the batch
        prompts = [{"data": question} for question in batch]
        logger.info(f"prompt {prompts}")
        # Process all questions in parallel using batch
        results = chain.batch(prompts)
        # Extract and combine results
        processed_results = []
        for result in results:
            # NOTE(review): .dict() is the pydantic v1 API; v2 renamed it to
            # model_dump() — confirm which pydantic version is pinned.
            result = result.dict()
            processed_results.append({
                **result["output"],
                "wrong_locations": result["wrong_locations"]
            })
        return processed_results

    # Preprocess questions to include only relevant fields
    preprocessed_questions = []
    for qa_dict in questions:
        processed_dict = {}
        if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
            processed_dict["Question"] = qa_dict["Question"]
        for option in ["Answer Option A", "Answer Option B", "Answer Option C", "Answer Option D"]:
            if option in qa_dict and not pd.isna(qa_dict[option]):
                processed_dict[option] = qa_dict[option]
        # Keep original metadata so it survives into the output rows
        processed_dict["No."] = qa_dict.get("No.")
        processed_dict["Training content"] = qa_dict.get("Training content")
        processed_dict["Answer"] = qa_dict.get("Answer")
        preprocessed_questions.append(processed_dict)
    # Process questions in batches
    results = []
    # Ceiling division: number of batches needed to cover all questions.
    total_batches = (len(preprocessed_questions) + batch_size - 1) // batch_size
    logger.info(f"Processing {len(preprocessed_questions)} questions in {total_batches} batches")
    for i in range(0, len(preprocessed_questions), batch_size):
        batch = preprocessed_questions[i:i + batch_size]
        batch_num = (i // batch_size) + 1
        logger.info(f"Processing batch {batch_num}/{total_batches} with {len(batch)} questions")
        batch_results = process_batch(batch)
        results.extend(batch_results)
    return results
def process_grammar_check(
    input_file: str,
    output_file: str,
    limit: Optional[int] = None,
    batch_size: int = 30,
) -> str:
    """
    Process an Excel file with questions and answers, check grammar, and save
    the corrected data.

    Args:
        input_file: Path to the input Excel file (data read from "Sheet1")
        output_file: Path to save the output Excel file
        limit: Limit the number of records to process. If None, process all
            records.
        batch_size: Number of questions sent to the LLM per batch
            (generalized from a previously hard-coded value of 30).

    Returns:
        Path to the output file
    """
    # Read the input file
    df = pd.read_excel(input_file, sheet_name="Sheet1")
    records = df.to_dict(orient="records")
    if limit is not None:
        records = records[:limit]
    # Process the records in batches
    processed_records = check_grammar_questions_batch(records, batch_size=batch_size)
    # Create a DataFrame from the processed data and write to Excel
    output_df = pd.DataFrame(processed_records)
    output_df.to_excel(output_file, index=False)
    return output_file