File size: 23,296 Bytes
094f1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
98ba910
 
 
 
 
 
 
 
094f1b1
98ba910
 
 
 
094f1b1
98ba910
 
 
 
 
 
 
 
 
 
 
 
094f1b1
 
98ba910
094f1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98ba910
 
 
 
 
 
 
094f1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98ba910
094f1b1
 
 
98ba910
094f1b1
98ba910
 
 
 
 
 
094f1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98ba910
 
 
094f1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98ba910
094f1b1
 
 
98ba910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
094f1b1
98ba910
 
 
094f1b1
98ba910
 
 
 
094f1b1
98ba910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
import os
from loguru import logger
import json
import tempfile
from typing import List, Dict, Any, Annotated, Optional
from langchain_openai import AzureChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from models import Grammar, Error
import docx
from rich.table import Table
from rich.box import ROUNDED
import re
import pandas as pd
import asyncio
from concurrent.futures import ThreadPoolExecutor
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI

# Shared chat model used by every LLM chain in this module
# (`prompt | llm` and `prompt | llm.with_structured_output(...)`).
# temperature=0 keeps corrections as repeatable as the provider allows.
# NOTE(review): no API key is passed here; presumably the client reads its
# credentials from the environment — confirm for each deployment.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-001",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # Further ChatGoogleGenerativeAI options can be added here.
)
# Disabled alternative backend: Azure OpenAI configured from environment
# variables (AzureChatOpenAI is still imported at the top of the file).
# To switch, uncomment the block below and remove the Gemini client above.
# AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
# AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
# AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
# AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
# llm = AzureChatOpenAI(
#     temperature=0,
#     api_key=AZURE_OPENAI_API_KEY,
#     azure_endpoint=AZURE_OPENAI_ENDPOINT,
#     azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
#     api_version=AZURE_OPENAI_API_VERSION,
# )
# Constants for text splitting.
# Maximum characters per chunk: SentenceBasedTextSplitter packs whole sentences
# up to this length; a single longer sentence becomes its own oversized chunk.
CHUNK_SIZE = 1000
# Requested overlap between chunks (0 = none). NOTE: SentenceBasedTextSplitter's
# split_text does not apply any overlap, so this value currently has no effect.
CHUNK_OVERLAP = 0

# Common tech terms and proper nouns that should not be flagged as errors.
# Used as the default `proper_nouns` argument and interpolated verbatim into the
# grammar prompt built by create_grammar_prompt.
DEFAULT_PROPER_NOUNS = """
API, APIs, HTML, CSS, JavaScript, TypeScript, Python, Java, C++, SQL, NoSQL, 
MongoDB, PostgreSQL, MySQL, Redis, Docker, Kubernetes, AWS, Azure, GCP, 
HTTP, HTTPS, REST, GraphQL, JSON, XML, YAML, React, Angular, Vue, Node.js, 
Express, Flask, Django, Spring, TensorFlow, PyTorch, Scikit-learn, npm, pip, 
GitHub, GitLab, Bitbucket, Jira, Confluence, Slack, OAuth, JWT, SSL, TLS
"""

from typing import TypedDict, Dict


# Structured-output schema for check_grammar_question, defined once at module
# level instead of on every call. The class name is kept as "GrammarResult"
# and it deliberately has no class docstring: LangChain may forward the name
# and docstring to the model as the tool definition, which would change the
# prompt.
class GrammarResult(BaseModel):
    output: Dict[str, str] = Field(
        ..., description="A dictionary with same keys as the input dictionary."
    )
    wrong_locations: Optional[str] = Field(
        None, description="point out errors briefly. Leave blank if there are no errors."
    )


def check_grammar_question(data: Dict[str, Any]) -> GrammarResult:
    """
    Grammar-check the fields of one question/answer dict with the shared ``llm``.

    Args:
        data: Mapping of field name (e.g. "Question", "Answer Option A") to
            the text to check. It is rendered into the user prompt as-is.

    Returns:
        A ``GrammarResult``. ``output`` holds the (possibly corrected) text
        under the same keys as ``data``. ``wrong_locations`` is a short
        description of the errors found, or ``None``.
        NOTE(review): depending on the provider, structured output may come
        back as ``None`` when the model produces no tool call — confirm.
    """
    system_message = """
    You are a spellchecker for a question and answer pair. Related to IT and programming.
    You will be given a question and answer pair.
    You will need to check the grammar of the question and answer pair.
    You will need to return the corrected question and answer pair in a dictionary. If any of the fields are not errors, you should return the original value.

    Output should be a dictionary with same keys as the input dictionary.
    """
    input_message = """
    Here are input dictionary:
    {data}
    """
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", input_message)]
    )

    chain = prompt | llm.with_structured_output(GrammarResult)
    return chain.invoke({"data": data})


def check_grammar_qa(
    qa_dict: Dict[str, Any], proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Dict[str, str]:
    """
    Grammar-check the question and answer-option fields of one QA record.

    Only "Question" and "Answer Option A" through "Answer Option D" are
    forwarded, each only when present and not missing according to
    ``pd.isna``. Every other key is left out of the request.

    Args:
        qa_dict: One QA record mapping column name to cell value.
        proper_nouns: Accepted for interface compatibility; currently unused.

    Returns:
        Whatever ``check_grammar_question`` returns for the selected fields.
    """
    text_fields = (
        "Question",
        "Answer Option A",
        "Answer Option B",
        "Answer Option C",
        "Answer Option D",
    )
    selected = {
        field: qa_dict[field]
        for field in text_fields
        if field in qa_dict and not pd.isna(qa_dict[field])
    }
    return check_grammar_question(selected)


def extract_text_from_docx(file_content: bytes) -> str:
    """
    Extract the paragraph text of a .docx file held in memory.

    Args:
        file_content: Raw bytes of the .docx file.

    Returns:
        The text of ``Document.paragraphs`` joined with newlines (text inside
        tables is not part of that collection).

    Raises:
        Exception: If python-docx cannot read the content. The message starts
            with "Failed to extract text from docx:" and the original error
            is chained as ``__cause__``.
    """
    import io

    try:
        # python-docx accepts a file-like object, so the bytes are parsed
        # directly. This avoids a temporary file that would be left behind
        # whenever parsing fails.
        doc = docx.Document(io.BytesIO(file_content))
        return "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        logger.error(f"Error extracting text from docx: {str(e)}")
        raise Exception(f"Failed to extract text from docx: {str(e)}") from e


def extract_text_from_file(file_content: bytes, file_extension: str) -> str:
    """
    Turn uploaded file bytes into text, dispatching on the file extension.

    The extension is compared case-insensitively. ``.txt`` content is decoded
    as UTF-8 with undecodable bytes replaced by U+FFFD. ``.docx`` content is
    handed to ``extract_text_from_docx``.

    Args:
        file_content: The raw bytes of the file.
        file_extension: The extension including the dot, e.g. ".txt".

    Returns:
        The extracted text.

    Raises:
        ValueError: For any other extension.
    """
    normalized = file_extension.lower()
    if normalized == ".docx":
        return extract_text_from_docx(file_content)
    if normalized == ".txt":
        return file_content.decode("utf-8", errors="replace")
    raise ValueError(f"Unsupported file extension: {file_extension}")


class SentenceBasedTextSplitter(RecursiveCharacterTextSplitter):
    """
    Text splitter that packs whole sentences into chunks of at most
    ``chunk_size`` characters.

    Sentence boundaries are runs of whitespace that follow ``.``, ``!`` or
    ``?``. Sentences inside a chunk are re-joined with a single space. A
    sentence longer than ``chunk_size`` is emitted as its own (oversized)
    chunk rather than being cut. ``chunk_overlap`` is passed to the base
    class but is not applied by ``split_text``.
    """

    # Whitespace preceded by sentence-ending punctuation; compiled once.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

    def __init__(self, chunk_size: int, chunk_overlap: int = 0):
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.chunk_size = chunk_size

    def split_text(self, text: str) -> List[str]:
        """
        Split ``text`` into sentence-aligned chunks.

        Returns an empty list for empty or whitespace-only text and never
        yields an empty chunk.
        """
        chunks: List[str] = []
        current = ""

        for sentence in self._SENTENCE_BOUNDARY.split(text):
            # Empty input and trailing whitespace produce empty pieces; left
            # in, they would become empty chunks sent to the LLM.
            if not sentence.strip():
                continue
            if len(current) + len(sentence) <= self.chunk_size:
                current += sentence + " "
            else:
                if current:
                    chunks.append(current.strip())
                current = sentence + " "

        # Flush the last partially filled chunk.
        if current:
            chunks.append(current.strip())

        return chunks


def split_text(text: str) -> List[str]:
    """
    Break ``text`` into sentence-aligned chunks sized for one LLM request.

    Uses ``SentenceBasedTextSplitter`` configured with ``CHUNK_SIZE`` and
    ``CHUNK_OVERLAP``, and logs the number of chunks at debug level.

    Args:
        text: The full text to split.

    Returns:
        The chunks, in input order.
    """
    pieces = SentenceBasedTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    ).split_text(text)
    logger.debug(f"Split text into {len(pieces)} chunks")
    return pieces


def create_grammar_prompt(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> str:
    """
    Build the user prompt asking the model to grammar-check one chunk of text.

    ``proper_nouns`` and ``text`` are interpolated verbatim (no escaping) into
    an f-string template. The template's leading indentation and blank or
    whitespace-only lines are part of the returned string, so edit the
    template carefully.

    NOTE(review): the model is asked to include "a larger phrase or sentence
    for context" in its before/after text. Corrections are later located by
    exact substring search in the original text, so the quoted ``before``
    text must appear verbatim in ``text``.

    Args:
        text: The text to check for grammar issues.
        proper_nouns: Terms the model is told to preserve.

    Returns:
        The formatted prompt string.
    """
    return f"""
    Rewrite the provided text to be clear and grammatically correct while preserving technical accuracy. Focus on:

    1. Correcting spelling, punctuation, and grammar errors
    2. Maintaining technical terminology and code snippets
    3. Ensuring consistent tense, voice, and formatting
    4. Clarifying function descriptions, parameters, and return values
    5. Proper use of capitalization, acronyms, and abbreviations
    6. Improving clarity and conciseness
    7. Respect markdown and code formatting such as underscores, asterisks, backticks, code blocks, and links
    8. Ensure proper nouns and acronyms are correctly spelled and capitalized
    
    Here's a list of proper nouns and technical terms you should preserve:
    {proper_nouns}

    Preserve code-specific formatting and syntax. Prioritize original text if unsure about technical terms.

    Make sure when you show the before vs after text, include a larger phrase or sentence for context.

    In the response:
    - For 'spelling', 'punctuation', and 'grammar' keys: Provide only changed items with original text, corrected text, and explanation.
    
    Ensure that the original text is actually referenced from the given text below:

    {text}
    """


def process_api_response(content: str) -> Dict[str, List[Dict[str, str]]]:
    """
    Extract and parse the JSON object embedded in a raw model reply.

    The candidate object is the span from the first "{" to the last "}", so
    surrounding prose or markdown code fences are tolerated.

    Args:
        content: Raw text returned by the model.

    Returns:
        The parsed object. If that span is not valid JSON, the error is
        logged and an empty skeleton with "spelling", "punctuation" and
        "grammar" lists is returned instead.

    Raises:
        ValueError: If the reply contains no "{" or no "}".
    """
    opening = content.find("{")
    closing = content.rfind("}")

    if opening < 0 or closing < 0:
        logger.error(f"Could not find JSON in response: {content}")
        raise ValueError("API response did not contain valid JSON")

    payload = content[opening : closing + 1]

    try:
        return json.loads(payload)
    except json.JSONDecodeError as je:
        logger.error(f"JSON decode error: {str(je)}")
        logger.error(f"JSON string was: {payload}")
        return {"spelling": [], "punctuation": [], "grammar": []}


def merge_grammar_results(
    results: List[Dict[str, List[Dict[str, str]]]],
) -> Dict[str, List[Dict[str, str]]]:
    """
    Concatenate per-chunk grammar results into one result.

    Only the "spelling", "punctuation" and "grammar" lists are merged, in the
    order the results are given. Any other keys are ignored.

    Args:
        results: Per-chunk grammar check results.

    Returns:
        A new dict with exactly the three category keys.
    """
    buckets = ("spelling", "punctuation", "grammar")
    combined = {bucket: [] for bucket in buckets}
    for partial in results:
        for bucket in buckets:
            combined[bucket].extend(partial.get(bucket, ()))
    return combined


def validate_corrections(
    result: Dict[str, List[Dict[str, str]]],
) -> Dict[str, List[Dict[str, str]]]:
    """
    Keep only the corrections that actually change the text.

    The LLM output is untrusted, so an entry is dropped (rather than aborting
    the whole check) when it:

    - is not a dict, or lacks a string ``before`` or ``after``;
    - has ``before`` equal to ``after`` once surrounding whitespace is ignored
      (identical text or a whitespace-only change).

    A category that is missing or ``null`` is treated as empty.

    Args:
        result: Parsed LLM result keyed by "spelling", "punctuation", "grammar".

    Returns:
        A new dict with those three keys, each holding the surviving entries
        in their original order.
    """
    categories = ("spelling", "punctuation", "grammar")
    validated = {category: [] for category in categories}

    for category in categories:
        for error in result.get(category) or []:
            if not isinstance(error, dict):
                continue
            before = error.get("before")
            after = error.get("after")
            if not isinstance(before, str) or not isinstance(after, str):
                continue
            # Covers both exact equality and whitespace-only differences.
            if before.strip() == after.strip():
                continue
            validated[category].append(error)

    return validated


def apply_corrections(original_text: str, errors: "List[Error]") -> str:
    """
    Apply grammar corrections to the original text.

    Each error's ``before`` snippet is located in ``original_text`` and
    replaced by its ``after`` text. A correction claims the first occurrence
    of its snippet that does not overlap a span already claimed by an earlier
    correction. As a result:

    - duplicate corrections (e.g. the same typo reported by two chunks) fix
      successive occurrences instead of rewriting the same span twice;
    - overlapping corrections cannot corrupt each other: the later one moves
      to its next non-overlapping occurrence or is skipped;
    - corrections whose ``before`` is empty or not found are ignored.

    Args:
        original_text: The original text with errors.
        errors: Objects exposing ``before`` and ``after`` strings, in
            priority order (earlier entries win overlaps).

    Returns:
        The corrected text.
    """
    # (start, end, error) spans, all in original_text coordinates.
    claimed = []
    for error in errors:
        needle = error.before
        if not needle:
            # str.find("") matches at 0 and would insert text at the start.
            continue
        start = original_text.find(needle)
        while start != -1:
            end = start + len(needle)
            if all(end <= lo or start >= hi for lo, hi, _ in claimed):
                claimed.append((start, end, error))
                break
            start = original_text.find(needle, start + 1)

    # Rebuild left to right from untouched gaps and replacements; spans are
    # disjoint, so no replacement can shift another's position.
    claimed.sort(key=lambda span: span[0])
    pieces = []
    cursor = 0
    for start, end, error in claimed:
        pieces.append(original_text[cursor:start])
        pieces.append(error.after)
        cursor = end
    pieces.append(original_text[cursor:])
    return "".join(pieces)


def check_grammar(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> Grammar:
    """
    Check the grammar of ``text`` with the module-level ``llm`` chat model.

    The process has five steps:

    1. Split the text into sentence-aligned chunks, dropping blank ones.
    2. Send every chunk in one ``chain.batch`` call.
    3. Parse the JSON replies and merge them.
    4. Drop no-op corrections.
    5. Apply the remaining corrections to the full original text.

    Args:
        text: The text to check for grammar issues.
        proper_nouns: Terms the model is asked to preserve.

    Returns:
        Grammar object with the categorized errors and the corrected text.
        ``file_path`` is left empty for file-based callers to fill in.

    Raises:
        Exception: Wrapping any failure (splitting, LLM call, parsing, model
            construction) as "Failed to analyze text: ..."; the original error
            is chained as ``__cause__``.
    """
    # Instructs the model to answer with bare JSON in the
    # spelling/punctuation/grammar schema parsed by process_api_response.
    system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
The JSON object must use the schema that has 'spelling', 'punctuation', and 'grammar' keys, each with a list of objects containing 'before', 'after', and 'explanation'.
It is strictly imperative that you return as JSON. DO NOT return any other characters other than valid JSON as your response."""

    try:
        # Blank chunks (e.g. from empty input) would only waste an LLM call.
        chunks = [chunk for chunk in split_text(text) if chunk.strip()]

        prompt = ChatPromptTemplate.from_messages(
            [("system", system_message), ("user", "{prompt}")]
        )
        chain = prompt | llm

        logger.debug(f"Processing {len(chunks)} chunks in batch...")
        responses = []
        if chunks:
            # batch() also handles a single chunk, so one code path serves
            # both the single- and multi-chunk cases.
            responses = chain.batch(
                [
                    {"prompt": create_grammar_prompt(chunk, proper_nouns)}
                    for chunk in chunks
                ]
            )
        logger.debug(f"Received {len(responses)} batch responses from API")

        results = [process_api_response(response.content) for response in responses]

        # Merge per-chunk results, then drop corrections that change nothing.
        validated_result = validate_corrections(merge_grammar_results(results))

        spelling_errors = [Error(**err) for err in validated_result.get("spelling", [])]
        punctuation_errors = [
            Error(**err) for err in validated_result.get("punctuation", [])
        ]
        grammar_errors = [Error(**err) for err in validated_result.get("grammar", [])]

        corrected_text = apply_corrections(
            text, spelling_errors + punctuation_errors + grammar_errors
        )

        return Grammar(
            spelling=spelling_errors,
            punctuation=punctuation_errors,
            grammar=grammar_errors,
            file_path="",  # Set by check_grammar_from_file for uploads
            corrected_text=corrected_text,
        )

    except Exception as e:
        logger.error(f"Error checking grammar: {str(e)}")
        raise Exception(f"Failed to analyze text: {str(e)}") from e


def check_grammar_from_file(
    file_content: bytes, filename: str, proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Grammar:
    """
    Check the grammar of an uploaded .txt or .docx file.

    Args:
        file_content: The bytes content of the file.
        filename: The uploaded file's name; its extension selects the
            extractor and the name is stored as ``file_path`` on the result.
        proper_nouns: Terms the model is asked to preserve.

    Returns:
        Grammar object containing categorized errors and the corrected text.

    Raises:
        Exception: Wrapping any failure (including an unsupported extension)
            as "Failed to analyze file: ..."; the original error is chained
            as ``__cause__``.
    """
    try:
        _, file_extension = os.path.splitext(filename)
        text = extract_text_from_file(file_content, file_extension)

        grammar_result = check_grammar(text, proper_nouns)

        # check_grammar leaves file_path empty; record the uploaded name.
        grammar_result.file_path = filename

        return grammar_result
    except Exception as e:
        logger.error(f"Error checking grammar from file: {str(e)}")
        raise Exception(f"Failed to analyze file: {str(e)}") from e


def display_results(response: Grammar, path: str = "", repo_link: str = "") -> int:
    """
    Render the grammar check results to the terminal with Rich.

    The function first prints the checked location. It then prints one table
    per category ("spelling", "punctuation", "grammar") listing the
    corrections whose ``before`` differs from ``after``. A category with no
    such corrections gets a "No <category> errors found." line instead.

    Args:
        response: The Grammar object with check results.
        path: Location label to show. Ignored when ``repo_link`` is given
            together with ``response.file_path``.
        repo_link: Optional repository URL; when set, a GitHub-style
            ``<repo>/blob/main/<file name>`` link is shown as the location.

    Returns:
        Total number of corrections displayed.
    """
    from rich.console import Console
    from rich.text import Text

    console = Console()

    def _cell(value):
        # Wrap strings in Text so Rich does not parse "[...]" (common in code
        # snippets, e.g. "arr[i]") as console markup.
        return Text(value) if isinstance(value, str) else value

    if repo_link and response.file_path:
        # Only the file name is kept, so the link assumes the file sits at
        # the repository root on the ``main`` branch.
        relative_path = os.path.basename(response.file_path)
        path = f"{repo_link.rstrip('/')}/blob/main/{relative_path}"
    elif not path:
        path = response.file_path or "Text input"

    console.print(path, markup=False)

    total_errors = 0

    for category in ["spelling", "punctuation", "grammar"]:
        table = Table(title=f"{category.capitalize()} Corrections", box=ROUNDED)
        table.add_column("Original", justify="left", style="bold red")
        table.add_column("Corrected", justify="left", style="bold green")
        table.add_column("Explanation", justify="left", style="italic")

        shown = 0
        for error in getattr(response, category):
            if error.before == error.after:
                continue
            table.add_row(
                _cell(error.before), _cell(error.after), _cell(error.explanation)
            )
            table.add_row("", "", "")  # Empty row for spacing
            shown += 1

        total_errors += shown
        # The builtin print() would only show the Table's repr; Console.print
        # renders it.
        if shown:
            console.print(table)
        else:
            console.print(f"No {category} errors found.")

    return total_errors


# Text fields of a QA record forwarded for correction (when not missing).
_BATCH_TEXT_FIELDS = (
    "Question",
    "Answer Option A",
    "Answer Option B",
    "Answer Option C",
    "Answer Option D",
)
# Metadata fields carried through unchanged; restored from the input row after
# the LLM call so the model can never alter or drop them.
_BATCH_METADATA_FIELDS = ("No.", "Training content", "Answer")


def _prepare_batch_question(qa_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return the non-missing text fields of ``qa_dict`` plus its metadata fields."""
    prepared = {
        field: qa_dict[field]
        for field in _BATCH_TEXT_FIELDS
        if field in qa_dict and not pd.isna(qa_dict[field])
    }
    for field in _BATCH_METADATA_FIELDS:
        prepared[field] = qa_dict.get(field)
    return prepared


def _merge_batch_result(question: Dict[str, Any], response: Any) -> Dict[str, Any]:
    """Combine one structured LLM response with its input row, keeping metadata intact."""
    if response is None or isinstance(response, Exception):
        # Keep the original row so one failure does not lose data, and flag it.
        reason = "no structured output returned" if response is None else str(response)
        logger.warning(f"Grammar check failed for question {question.get('No.')}: {reason}")
        return {**question, "wrong_locations": f"Grammar check failed: {reason}"}

    data = response.dict()
    # Start from the input row so fields the model omitted keep their original text.
    merged = {**question, **data["output"]}
    for field in _BATCH_METADATA_FIELDS:
        merged[field] = question[field]
    merged["wrong_locations"] = data["wrong_locations"]
    return merged


def check_grammar_questions_batch(questions: List[Dict[str, Any]], batch_size: int = 5) -> List[Dict[str, Any]]:
    """
    Grammar-check QA records with the shared ``llm``, ``batch_size`` at a time.

    Each record is reduced to its "Question"/"Answer Option A-D" fields (when
    not missing) plus the "No.", "Training content" and "Answer" metadata,
    and sent as one structured-output request. In the returned rows:

    - the model's corrected fields overlay the input fields;
    - the metadata is always restored from the input;
    - a "wrong_locations" column describes the errors found.

    A request that fails or returns no structured output keeps its original
    row, with "wrong_locations" explaining the failure, instead of aborting
    the run.

    Args:
        questions: QA records to process.
        batch_size: Number of records sent per ``chain.batch`` call (>= 1).

    Returns:
        One processed record per input record, in input order.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    system_message = """
    You are a spellchecker for a batch of questions and answers related to IT and programming.
    You will be given multiple question and answer pairs.
    Check the grammar of each question and answer pair.
    Return a list of dictionaries with the same structure as the input, but with corrected text.
    If any fields have no errors, return the original value.
    """
    # The indentation inside this literal is part of the prompt text; keep it.
    input_message = """
        Here is a question to check:
        {data}
        """
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", input_message)
    ])

    # No class docstring on purpose: it could be forwarded to the model as the
    # tool description and change the prompt.
    class BatchGrammarResult(BaseModel):
        output: Dict[str, Any] = Field(
            ..., description="Dictionary with corrected text"
        )
        wrong_locations: str = Field(
            ..., description="Error descriptions for the question"
        )

    # Built once per call rather than once per batch.
    chain = prompt | llm.with_structured_output(BatchGrammarResult)

    preprocessed_questions = [_prepare_batch_question(qa_dict) for qa_dict in questions]

    results = []
    total_batches = (len(preprocessed_questions) + batch_size - 1) // batch_size
    logger.info(f"Processing {len(preprocessed_questions)} questions in {total_batches} batches")

    for i in range(0, len(preprocessed_questions), batch_size):
        batch = preprocessed_questions[i:i + batch_size]
        batch_num = (i // batch_size) + 1
        logger.info(f"Processing batch {batch_num}/{total_batches} with {len(batch)} questions")

        prompts = [{"data": question} for question in batch]
        # Full question text is verbose; keep it out of INFO logs.
        logger.debug(f"prompt {prompts}")
        # return_exceptions=True: a single failed request must not abort the
        # whole batch (and with it the whole file).
        responses = chain.batch(prompts, return_exceptions=True)
        results.extend(
            _merge_batch_result(question, response)
            for question, response in zip(batch, responses)
        )

    return results


def process_grammar_check(
    input_file: str,
    output_file: str,
    limit: Optional[int] = None,
    sheet_name: str = "Sheet1",
    batch_size: int = 30,
) -> str:
    """
    Grammar-check the QA rows of an Excel sheet and write the results to Excel.

    Args:
        input_file: Path to the input Excel file.
        output_file: Path to write the output Excel file (no index column).
        limit: If given, only the first ``limit`` rows are processed;
            ``None`` processes all rows.
        sheet_name: Worksheet to read; passed to ``pd.read_excel``.
        batch_size: Rows per LLM batch; passed to
            ``check_grammar_questions_batch``.

    Returns:
        ``output_file``.
    """
    df = pd.read_excel(input_file, sheet_name=sheet_name)
    records = df.to_dict(orient="records")

    if limit is not None:
        records = records[:limit]

    processed_records = check_grammar_questions_batch(records, batch_size=batch_size)

    pd.DataFrame(processed_records).to_excel(output_file, index=False)

    return output_file