Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- grammar_checker.py +161 -25
- requirements.txt +1 -1
grammar_checker.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
from loguru import logger
|
| 3 |
import json
|
| 4 |
-
import io
|
| 5 |
import tempfile
|
| 6 |
from typing import List, Dict, Any, Annotated, Optional
|
| 7 |
from langchain_openai import AzureChatOpenAI
|
|
@@ -13,23 +12,35 @@ from rich.table import Table
|
|
| 13 |
from rich.box import ROUNDED
|
| 14 |
import re
|
| 15 |
import pandas as pd
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
temperature=0,
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# Constants for text splitting
|
| 31 |
CHUNK_SIZE = 1000 # Approximate characters per page
|
| 32 |
-
CHUNK_OVERLAP =
|
| 33 |
|
| 34 |
# Common tech terms and proper nouns that should not be flagged as errors
|
| 35 |
DEFAULT_PROPER_NOUNS = """
|
|
@@ -63,11 +74,13 @@ def check_grammar_question(data: Dict[str, Any]) -> Dict[str, str]:
|
|
| 63 |
[("system", system_message), ("user", input_message)]
|
| 64 |
)
|
| 65 |
|
| 66 |
-
class GrammarResult(
|
| 67 |
-
output:
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
wrong_locations:
|
|
|
|
|
|
|
| 71 |
|
| 72 |
chain = prompt | llm.with_structured_output(GrammarResult)
|
| 73 |
result = chain.invoke({"data": data})
|
|
@@ -88,13 +101,18 @@ def check_grammar_qa(
|
|
| 88 |
Dictionary with corrected text for each field
|
| 89 |
"""
|
| 90 |
corrected_dict = {}
|
| 91 |
-
|
| 92 |
# Only process the Question and Answer Options A-D
|
| 93 |
if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
|
| 94 |
corrected_dict["Question"] = qa_dict["Question"]
|
| 95 |
-
|
| 96 |
# Process answer options
|
| 97 |
-
for option in [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
if option in qa_dict and not pd.isna(qa_dict[option]):
|
| 99 |
corrected_dict[option] = qa_dict[option]
|
| 100 |
|
|
@@ -374,7 +392,9 @@ def check_grammar(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> Gramma
|
|
| 374 |
chunks = split_text(text)
|
| 375 |
|
| 376 |
# Initialize LangChain with Azure OpenAI
|
| 377 |
-
logger.debug(
|
|
|
|
|
|
|
| 378 |
|
| 379 |
# Create system message for JSON format
|
| 380 |
system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
|
|
@@ -521,11 +541,127 @@ def display_results(response: Grammar, path: str = "", repo_link: str = "") -> i
|
|
| 521 |
total_errors += 1
|
| 522 |
|
| 523 |
if errors:
|
| 524 |
-
|
| 525 |
else:
|
| 526 |
no_errors_msg = f"No {category} errors found."
|
| 527 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
|
|
|
|
|
|
|
|
|
| 529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
from loguru import logger
|
| 3 |
import json
|
|
|
|
| 4 |
import tempfile
|
| 5 |
from typing import List, Dict, Any, Annotated, Optional
|
| 6 |
from langchain_openai import AzureChatOpenAI
|
|
|
|
| 12 |
from rich.box import ROUNDED
import re
import pandas as pd
import asyncio
from concurrent.futures import ThreadPoolExecutor
from pydantic import BaseModel, Field
# NOTE: ChatGoogleGenerativeAI was previously imported twice (once via the
# private `.chat_models` path and once via the package root); keep only the
# public package-root import.
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 20 |
+
|
| 21 |
+
# Gemini chat model shared by every grammar-checking chain in this module.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-001",
    temperature=0,  # deterministic output so corrections are reproducible
    max_tokens=None,  # no cap on response length
    timeout=None,
    max_retries=2,
    # other params...
)
# Previous Azure OpenAI configuration, kept for reference after the
# migration to Google Gemini above.
# AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
# AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
# AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
# AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
# llm = AzureChatOpenAI(
#     temperature=0,
#     api_key=AZURE_OPENAI_API_KEY,
#     azure_endpoint=AZURE_OPENAI_ENDPOINT,
#     azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
#     api_version=AZURE_OPENAI_API_VERSION,
# )
# Constants for text splitting
CHUNK_SIZE = 1000  # Approximate characters per page
CHUNK_OVERLAP = 0  # No overlap between chunks (each chunk is independent)
|
| 44 |
|
| 45 |
# Common tech terms and proper nouns that should not be flagged as errors
|
| 46 |
DEFAULT_PROPER_NOUNS = """
|
|
|
|
| 74 |
[("system", system_message), ("user", input_message)]
|
| 75 |
)
|
| 76 |
|
| 77 |
+
class GrammarResult(BaseModel):
    """Structured output schema for a single grammar-check LLM call."""

    # Corrected text, keyed exactly like the input mapping.
    output: Dict[str, str] = Field(..., description="A dictionary with same keys as the input dictionary.")
    # Short human-readable error summary; None when the text is clean.
    wrong_locations: Optional[str] = Field(None, description="point out errors briefly. Leave blank if there are no errors.")
|
| 84 |
|
| 85 |
chain = prompt | llm.with_structured_output(GrammarResult)
|
| 86 |
result = chain.invoke({"data": data})
|
|
|
|
| 101 |
Dictionary with corrected text for each field
|
| 102 |
"""
|
| 103 |
corrected_dict = {}
|
| 104 |
+
|
| 105 |
# Only process the Question and Answer Options A-D
|
| 106 |
if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
|
| 107 |
corrected_dict["Question"] = qa_dict["Question"]
|
| 108 |
+
|
| 109 |
# Process answer options
|
| 110 |
+
for option in [
|
| 111 |
+
"Answer Option A",
|
| 112 |
+
"Answer Option B",
|
| 113 |
+
"Answer Option C",
|
| 114 |
+
"Answer Option D",
|
| 115 |
+
]:
|
| 116 |
if option in qa_dict and not pd.isna(qa_dict[option]):
|
| 117 |
corrected_dict[option] = qa_dict[option]
|
| 118 |
|
|
|
|
| 392 |
chunks = split_text(text)
|
| 393 |
|
| 394 |
# Initialize LangChain with Azure OpenAI
|
| 395 |
+
# logger.debug(
|
| 396 |
+
# f"Using Azure OpenAI with deployment: {AZURE_OPENAI_DEPLOYMENT_NAME}"
|
| 397 |
+
# )
|
| 398 |
|
| 399 |
# Create system message for JSON format
|
| 400 |
system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
|
|
|
|
| 541 |
total_errors += 1
|
| 542 |
|
| 543 |
if errors:
|
| 544 |
+
print(table)
|
| 545 |
else:
|
| 546 |
no_errors_msg = f"No {category} errors found."
|
| 547 |
|
| 548 |
+
return total_errors
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def check_grammar_questions_batch(questions: List[Dict[str, Any]], batch_size: int = 5) -> List[Dict[str, Any]]:
    """
    Process multiple questions in batches for grammar checking.

    Args:
        questions: List of question dictionaries to process
        batch_size: Number of questions to process in each batch

    Returns:
        List of processed question dictionaries with grammar corrections,
        each augmented with a "wrong_locations" error summary field.
    """
    # Each element of a batch is sent to the model as its OWN prompt via
    # chain.batch(), so the instructions describe a single question/answer
    # pair and a single output dictionary (the previous wording asked for a
    # list of dictionaries, contradicting the structured-output schema).
    system_message = """
    You are a spellchecker for questions and answers related to IT and programming.
    You will be given a question and answer pair.
    Check the grammar of the question and each answer option.
    Return a dictionary with the same structure as the input, but with corrected text.
    If any fields have no errors, return the original value.
    """

    def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Run one structured-output LLM call per question, in parallel.
        input_message = """
        Here is a question to check:
        {data}
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("user", input_message)
        ])

        class BatchGrammarResult(BaseModel):
            # Corrected text, keyed like the input question dict.
            output: Dict[str, Any] = Field(
                ..., description="Dictionary with corrected text"
            )
            # Optional for consistency with GrammarResult: the model may
            # legitimately report no errors.
            wrong_locations: Optional[str] = Field(
                None, description="Error descriptions for the question"
            )

        chain = prompt | llm.with_structured_output(BatchGrammarResult)

        # Create prompts for each question in the batch
        prompts = [{"data": question} for question in batch]
        # Full prompt payloads are verbose — keep them at debug level.
        logger.debug(f"prompt {prompts}")
        # Process all questions in parallel using batch
        results = chain.batch(prompts)

        # Extract and combine results
        processed_results = []
        for result in results:
            # pydantic v2 renamed .dict() to .model_dump(); keep a v1 fallback.
            payload = result.model_dump() if hasattr(result, "model_dump") else result.dict()
            processed_results.append({
                **payload["output"],
                "wrong_locations": payload["wrong_locations"]
            })

        return processed_results

    # Preprocess questions to include only relevant fields
    preprocessed_questions = []
    for qa_dict in questions:
        processed_dict = {}

        if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
            processed_dict["Question"] = qa_dict["Question"]

        for option in ["Answer Option A", "Answer Option B", "Answer Option C", "Answer Option D"]:
            if option in qa_dict and not pd.isna(qa_dict[option]):
                processed_dict[option] = qa_dict[option]

        # Keep original metadata (passed through to the output untouched)
        processed_dict["No."] = qa_dict.get("No.")
        processed_dict["Training content"] = qa_dict.get("Training content")
        processed_dict["Answer"] = qa_dict.get("Answer")

        preprocessed_questions.append(processed_dict)

    # Process questions in batches
    results = []
    total_batches = (len(preprocessed_questions) + batch_size - 1) // batch_size
    logger.info(f"Processing {len(preprocessed_questions)} questions in {total_batches} batches")

    for i in range(0, len(preprocessed_questions), batch_size):
        batch = preprocessed_questions[i:i + batch_size]
        batch_num = (i // batch_size) + 1
        logger.info(f"Processing batch {batch_num}/{total_batches} with {len(batch)} questions")
        batch_results = process_batch(batch)
        results.extend(batch_results)

    return results
|
| 639 |
+
|
| 640 |
|
| 641 |
+
def process_grammar_check(
    input_file: str,
    output_file: str,
    limit: Optional[int] = None,
    batch_size: int = 30,
) -> str:
    """
    Process an Excel file with questions and answers, check grammar, and save the corrected data.

    Args:
        input_file (str): Path to the input Excel file (expects a "Sheet1" sheet)
        output_file (str): Path to save the output Excel file
        limit (int, optional): Limit the number of records to process. If None, process all records.
        batch_size (int): Number of questions sent to the LLM per batch
            (previously hard-coded to 30; same default, now configurable).

    Returns:
        str: Path to the output file
    """
    # Read the input file
    df = pd.read_excel(input_file, sheet_name="Sheet1")
    records = df.to_dict(orient="records")

    if limit is not None:
        records = records[:limit]

    # Process the records in batches
    processed_records = check_grammar_questions_batch(records, batch_size=batch_size)

    # Create a DataFrame from the processed data and write to Excel
    output_df = pd.DataFrame(processed_records)
    output_df.to_excel(output_file, index=False)

    return output_file
|
requirements.txt
CHANGED
|
@@ -10,4 +10,4 @@ langchain_text_splitters
|
|
| 10 |
rich
|
| 11 |
python-docx
|
| 12 |
python-multipart
|
| 13 |
-
|
|
|
|
| 10 |
rich
|
| 11 |
python-docx
|
| 12 |
python-multipart
|
| 13 |
+
langchain-google-genai
|