ABAO77 committed on
Commit
98ba910
·
verified ·
1 Parent(s): 953da1c

Upload 2 files

Browse files
Files changed (2) hide show
  1. grammar_checker.py +161 -25
  2. requirements.txt +1 -1
grammar_checker.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from loguru import logger
3
  import json
4
- import io
5
  import tempfile
6
  from typing import List, Dict, Any, Annotated, Optional
7
  from langchain_openai import AzureChatOpenAI
@@ -13,23 +12,35 @@ from rich.table import Table
13
  from rich.box import ROUNDED
14
  import re
15
  import pandas as pd
16
-
17
-
18
- # Get Azure OpenAI credentials from environment variables
19
- AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
20
- AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
21
- AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
22
- AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
23
- llm = AzureChatOpenAI(
24
  temperature=0,
25
- api_key=AZURE_OPENAI_API_KEY,
26
- azure_endpoint=AZURE_OPENAI_ENDPOINT,
27
- azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
28
- api_version=AZURE_OPENAI_API_VERSION,
29
  )
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Constants for text splitting
31
  CHUNK_SIZE = 1000 # Approximate characters per page
32
- CHUNK_OVERLAP = 250 # Overlap between chunks to maintain context
33
 
34
  # Common tech terms and proper nouns that should not be flagged as errors
35
  DEFAULT_PROPER_NOUNS = """
@@ -63,11 +74,13 @@ def check_grammar_question(data: Dict[str, Any]) -> Dict[str, str]:
63
  [("system", system_message), ("user", input_message)]
64
  )
65
 
66
- class GrammarResult(TypedDict):
67
- output: Annotated[
68
- Dict[str, str], ..., "A dictionary with same keys as the input dictionary."
69
- ]
70
- wrong_locations: Annotated[Optional[str], ..., "point out errors briefly. Leave blank if there are no errors."]
 
 
71
 
72
  chain = prompt | llm.with_structured_output(GrammarResult)
73
  result = chain.invoke({"data": data})
@@ -88,13 +101,18 @@ def check_grammar_qa(
88
  Dictionary with corrected text for each field
89
  """
90
  corrected_dict = {}
91
-
92
  # Only process the Question and Answer Options A-D
93
  if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
94
  corrected_dict["Question"] = qa_dict["Question"]
95
-
96
  # Process answer options
97
- for option in ["Answer Option A", "Answer Option B", "Answer Option C", "Answer Option D"]:
 
 
 
 
 
98
  if option in qa_dict and not pd.isna(qa_dict[option]):
99
  corrected_dict[option] = qa_dict[option]
100
 
@@ -374,7 +392,9 @@ def check_grammar(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> Gramma
374
  chunks = split_text(text)
375
 
376
  # Initialize LangChain with Azure OpenAI
377
- logger.debug(f"Using Azure OpenAI with deployment: {AZURE_OPENAI_DEPLOYMENT_NAME}")
 
 
378
 
379
  # Create system message for JSON format
380
  system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
@@ -521,11 +541,127 @@ def display_results(response: Grammar, path: str = "", repo_link: str = "") -> i
521
  total_errors += 1
522
 
523
  if errors:
524
- print(table)
525
  else:
526
  no_errors_msg = f"No {category} errors found."
527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
 
 
 
529
 
 
 
 
 
530
 
531
- return total_errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from loguru import logger
3
  import json
 
4
  import tempfile
5
  from typing import List, Dict, Any, Annotated, Optional
6
  from langchain_openai import AzureChatOpenAI
 
12
  from rich.box import ROUNDED
13
  import re
14
  import pandas as pd
15
+ import asyncio
16
+ from concurrent.futures import ThreadPoolExecutor
17
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
18
+ from pydantic import BaseModel, Field
19
+ from langchain_google_genai import ChatGoogleGenerativeAI
20
+
21
# Shared chat model used by every grammar-checking chain in this module.
# temperature=0 keeps corrections deterministic; up to 2 retries cover
# transient API failures. max_tokens/timeout are left at provider defaults.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-001",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
29
+ # Get Azure OpenAI credentials from environment variables
30
+ # AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
31
+ # AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
32
+ # AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
33
+ # AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
34
+ # llm = AzureChatOpenAI(
35
+ # temperature=0,
36
+ # api_key=AZURE_OPENAI_API_KEY,
37
+ # azure_endpoint=AZURE_OPENAI_ENDPOINT,
38
+ # azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
39
+ # api_version=AZURE_OPENAI_API_VERSION,
40
+ # )
41
  # Constants for text splitting
42
  CHUNK_SIZE = 1000 # Approximate characters per page
43
+ CHUNK_OVERLAP = 0 # Overlap between chunks to maintain context
44
 
45
  # Common tech terms and proper nouns that should not be flagged as errors
46
  DEFAULT_PROPER_NOUNS = """
 
74
  [("system", system_message), ("user", input_message)]
75
  )
76
 
77
+ class GrammarResult(BaseModel):
78
+ output: Dict[str, str] = Field(
79
+ ..., description="A dictionary with same keys as the input dictionary."
80
+ )
81
+ wrong_locations: Optional[str] = Field(
82
+ None, description="point out errors briefly. Leave blank if there are no errors."
83
+ )
84
 
85
  chain = prompt | llm.with_structured_output(GrammarResult)
86
  result = chain.invoke({"data": data})
 
101
  Dictionary with corrected text for each field
102
  """
103
  corrected_dict = {}
104
+
105
  # Only process the Question and Answer Options A-D
106
  if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
107
  corrected_dict["Question"] = qa_dict["Question"]
108
+
109
  # Process answer options
110
+ for option in [
111
+ "Answer Option A",
112
+ "Answer Option B",
113
+ "Answer Option C",
114
+ "Answer Option D",
115
+ ]:
116
  if option in qa_dict and not pd.isna(qa_dict[option]):
117
  corrected_dict[option] = qa_dict[option]
118
 
 
392
  chunks = split_text(text)
393
 
394
  # Initialize LangChain with Azure OpenAI
395
+ # logger.debug(
396
+ # f"Using Azure OpenAI with deployment: {AZURE_OPENAI_DEPLOYMENT_NAME}"
397
+ # )
398
 
399
  # Create system message for JSON format
400
  system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
 
541
  total_errors += 1
542
 
543
  if errors:
544
+ print(table)
545
  else:
546
  no_errors_msg = f"No {category} errors found."
547
 
548
+ return total_errors
549
+
550
+
551
def check_grammar_questions_batch(
    questions: List[Dict[str, Any]], batch_size: int = 5
) -> List[Dict[str, Any]]:
    """Grammar-check a list of Q&A dictionaries in fixed-size batches.

    Each input record is first reduced to the fields that should be checked
    (Question / Answer Option A-D) plus pass-through metadata ("No.",
    "Training content", "Answer"), then sent to the LLM in batches via
    ``chain.batch`` so the requests within a batch run concurrently.

    Args:
        questions: Question dictionaries to process.
        batch_size: Number of questions per LLM batch call. Must be >= 1.

    Returns:
        One dictionary per input question containing the corrected fields
        plus a ``wrong_locations`` entry describing any detected errors.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    system_message = """
    You are a spellchecker for a batch of questions and answers related to IT and programming.
    You will be given multiple question and answer pairs.
    Check the grammar of each question and answer pair.
    Return a list of dictionaries with the same structure as the input, but with corrected text.
    If any fields have no errors, return the original value.
    """
    input_message = """
    Here is a question to check:
    {data}
    """
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", input_message)]
    )

    class BatchGrammarResult(BaseModel):
        # Structured-output schema the LLM must follow for each question.
        output: Dict[str, Any] = Field(
            ..., description="Dictionary with corrected text"
        )
        wrong_locations: str = Field(
            ..., description="Error descriptions for the question"
        )

    # Build the chain once; the original recreated the prompt, the schema
    # class, and the chain on every per-batch call.
    chain = prompt | llm.with_structured_output(BatchGrammarResult)

    def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run one batch through the chain and flatten the structured results."""
        prompts = [{"data": question} for question in batch]
        # Full prompts are verbose; log them at DEBUG instead of INFO.
        logger.debug(f"prompt {prompts}")
        results = chain.batch(prompts)

        processed_results = []
        for result in results:
            # Prefer pydantic v2 model_dump(); fall back to v1 dict().
            payload = (
                result.model_dump()
                if hasattr(result, "model_dump")
                else result.dict()
            )
            processed_results.append(
                {**payload["output"], "wrong_locations": payload["wrong_locations"]}
            )
        return processed_results

    # Reduce each record to the checkable fields plus pass-through metadata.
    option_keys = [
        "Answer Option A",
        "Answer Option B",
        "Answer Option C",
        "Answer Option D",
    ]
    preprocessed_questions = []
    for qa_dict in questions:
        processed_dict = {}

        if "Question" in qa_dict and not pd.isna(qa_dict["Question"]):
            processed_dict["Question"] = qa_dict["Question"]

        for option in option_keys:
            if option in qa_dict and not pd.isna(qa_dict[option]):
                processed_dict[option] = qa_dict[option]

        # Keep original metadata so it survives the round trip unchanged.
        processed_dict["No."] = qa_dict.get("No.")
        processed_dict["Training content"] = qa_dict.get("Training content")
        processed_dict["Answer"] = qa_dict.get("Answer")

        preprocessed_questions.append(processed_dict)

    # Process questions in batches of batch_size.
    results = []
    total_batches = (len(preprocessed_questions) + batch_size - 1) // batch_size
    logger.info(
        f"Processing {len(preprocessed_questions)} questions in {total_batches} batches"
    )

    for i in range(0, len(preprocessed_questions), batch_size):
        batch = preprocessed_questions[i : i + batch_size]
        batch_num = (i // batch_size) + 1
        logger.info(
            f"Processing batch {batch_num}/{total_batches} with {len(batch)} questions"
        )
        results.extend(process_batch(batch))

    return results
639
+
640
 
641
def process_grammar_check(
    input_file: str,
    output_file: str,
    limit: Optional[int] = None,
    batch_size: int = 30,
    sheet_name: str = "Sheet1",
) -> str:
    """Read Q&A records from an Excel file, grammar-check them, and save the result.

    Args:
        input_file: Path to the input Excel workbook.
        output_file: Path where the corrected workbook is written.
        limit: If given, only the first ``limit`` records are processed.
        batch_size: Questions per LLM batch call (previously hard-coded to 30).
        sheet_name: Worksheet to read from ``input_file`` (previously
            hard-coded to "Sheet1").

    Returns:
        The ``output_file`` path, for convenient chaining.
    """
    # Read the input worksheet as a list of row dictionaries.
    df = pd.read_excel(input_file, sheet_name=sheet_name)
    records = df.to_dict(orient="records")

    if limit is not None:
        records = records[:limit]

    # Grammar-check in batches; each record gains a "wrong_locations" field.
    processed_records = check_grammar_questions_batch(records, batch_size=batch_size)

    # Write the corrected records back out as an Excel file.
    pd.DataFrame(processed_records).to_excel(output_file, index=False)

    return output_file
requirements.txt CHANGED
@@ -10,4 +10,4 @@ langchain_text_splitters
10
  rich
11
  python-docx
12
  python-multipart
13
- openpyxl
 
10
  rich
11
  python-docx
12
  python-multipart
13
+ langchain-google-genai