ABAO77 committed on
Commit
094f1b1
·
verified ·
1 Parent(s): 067ae78

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +16 -0
  2. app.py +217 -0
  3. grammar_checker.py +555 -0
  4. models.py +16 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.11-slim

# Run as a non-root user (uid 1000), as required by HF Spaces.
RUN useradd -m -u 1000 user
USER user
# Make user-installed pip entry points (e.g. uvicorn) resolvable.
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy requirements first so dependency install is cached across code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
# Port 7860 is the conventional HF Spaces port.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import sys
4
+ import argparse
5
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException, BackgroundTasks
6
+ from fastapi.responses import JSONResponse, FileResponse
7
+ from fastapi.staticfiles import StaticFiles
8
+ from fastapi.templating import Jinja2Templates
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel
11
+ from typing import Optional, List, Dict, Any
12
+ import pandas as pd
13
+ import tempfile
14
+ import shutil
15
+ from models import Grammar, Error
16
+ from grammar_checker import (
17
+ check_grammar,
18
+ check_grammar_from_file,
19
+ check_grammar_qa,
20
+ display_results,
21
+ DEFAULT_PROPER_NOUNS,
22
+ )
23
+
24
# Configure logging
# NOTE(review): DEBUG is verbose for a deployed service — confirm intent.
logging.basicConfig(level=logging.DEBUG)

# Create FastAPI app
# Swagger UI is served at the root path ("/").
app = FastAPI(
    title="Grammar Checker API",
    description="API for checking grammar in text, files, and quiz questions",
    docs_url="/",
)

# Configure CORS
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Define allowed file extensions
# Extensions accepted by /check-file; compared lowercase in allowed_file().
ALLOWED_EXTENSIONS = {".txt", ".docx", ".xlsx"}
45
+
46
+
47
class GrammarCheckResponse(BaseModel):
    """Response payload for the quiz grammar-check endpoint."""

    # Server-side path of the corrected Excel file.
    output_file: str
    # Human-readable status message.
    message: str
    # Corrected rows, one dict per processed record.
    records: List[Dict[str, Any]] = []
51
+
52
+
53
@app.post("/check-grammar-quiz/", response_model=GrammarCheckResponse)
async def check_grammar_quiz(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    limit: Optional[int] = None,
):
    """
    Process an Excel file with questions and answers, check grammar, and
    return the corrected data.

    Bug fix: this handler was previously named ``check_grammar``, which
    shadowed the ``check_grammar`` function imported from ``grammar_checker``
    at module level — /check-text then invoked this coroutine with the wrong
    arguments instead of the real checker. Renaming the handler restores the
    import; the HTTP route itself is unchanged.

    Args:
        background_tasks: FastAPI task queue used to delete the temp
            directory after the response is sent.
        file: The input Excel file.
        limit: Limit the number of records to process. If None, process all
            records.

    Returns:
        GrammarCheckResponse with the output file path and processed records.

    Raises:
        HTTPException 500: when saving or processing the upload fails.
    """
    # Create temp directory to store files
    temp_dir = tempfile.mkdtemp()
    input_path = os.path.join(temp_dir, file.filename)
    output_filename = f"corrected_{file.filename}"
    output_path = os.path.join(temp_dir, output_filename)

    # Save uploaded file
    try:
        with open(input_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        # Schedule cleanup here too — the original leaked temp_dir when
        # saving the upload failed.
        background_tasks.add_task(cleanup_temp_files, temp_dir)
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}")

    # Process the file
    try:
        result_file, processed_records = process_grammar_check(
            input_path, output_path, limit
        )
        background_tasks.add_task(cleanup_temp_files, temp_dir)
        return GrammarCheckResponse(
            output_file=result_file,
            message="Grammar check completed successfully",
            records=processed_records,
        )
    except Exception as e:
        background_tasks.add_task(cleanup_temp_files, temp_dir)
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
94
+
95
+
96
def cleanup_temp_files(temp_dir: str) -> None:
    """Best-effort removal of a temporary directory.

    Runs as a FastAPI background task after the response has been sent, so a
    failure here (e.g. the directory already gone) could not be reported to
    the client anyway — ignore errors instead of raising in the background.

    Args:
        temp_dir: Path of the temporary directory to delete.
    """
    shutil.rmtree(temp_dir, ignore_errors=True)
99
+
100
+
101
def process_grammar_check(input_file, output_file, limit=None):
    """
    Run the grammar checker over an Excel sheet of quiz questions.

    Args:
        input_file (str): Path to the input Excel file (sheet "Sheet1").
        output_file (str): Path where the corrected Excel file is written.
        limit (int, optional): Maximum number of rows to process; all rows
            when None.

    Returns:
        tuple: (Path to the output file, List of processed record dicts)
    """
    # Load the sheet and convert each row to a plain dict.
    rows = pd.read_excel(input_file, sheet_name="Sheet1").to_dict(orient="records")
    if limit is not None:
        rows = rows[: max(limit, 0)]

    option_keys = (
        "Answer Option A",
        "Answer Option B",
        "Answer Option C",
        "Answer Option D",
    )

    processed = []
    for row in rows:
        checked = check_grammar_qa(row)
        corrected = row.copy()
        # The question is required in the checker output; options may be absent.
        corrected["Question"] = checked["output"]["Question"]
        for key in option_keys:
            corrected[key] = checked["output"].get(key, None)
        corrected["wrong_locations"] = checked["wrong_locations"]
        processed.append(corrected)

    # Write the corrected rows back out as a new Excel file.
    pd.DataFrame(processed).to_excel(output_file, index=False)

    return output_file, processed
147
+
148
+
149
def allowed_file(filename: str) -> bool:
    """
    Return True when *filename*'s extension is one of ALLOWED_EXTENSIONS.

    The comparison is case-insensitive, so ".TXT" is accepted as ".txt".

    Args:
        filename: The name of the uploaded file.

    Returns:
        True if the file extension is allowed, False otherwise.
    """
    _, extension = os.path.splitext(filename)
    return extension.lower() in ALLOWED_EXTENSIONS
160
+
161
+
162
class TextRequest(BaseModel):
    """Request body for the /check-text endpoint."""

    # Raw text to be grammar-checked.
    text: str
    # Proper nouns / tech terms the checker should preserve verbatim.
    proper_nouns: Optional[str] = DEFAULT_PROPER_NOUNS
165
+
166
+
167
@app.post("/check-text")
async def check_text(body: TextRequest):
    """
    Check grammar of raw text posted as JSON.

    Returns:
        The Grammar result serialized to a plain dict.

    Raises:
        HTTPException 400: when the text field is empty.
        HTTPException 500: on any unexpected processing error.
    """
    try:
        if not body.text:
            raise HTTPException(status_code=400, detail="No text provided")

        # Check grammar using LangChain
        result = check_grammar(body.text, body.proper_nouns)

        # Convert Pydantic model to dict for JSON response
        return result.dict()

    except HTTPException:
        # Bug fix: HTTPException subclasses Exception, so the generic handler
        # below was converting the deliberate 400 above into a 500. Re-raise
        # it untouched.
        raise
    except Exception as e:
        logging.error(f"Error processing text: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
183
+
184
+
185
@app.post("/check-file")
async def check_file(
    file: UploadFile, proper_nouns: Optional[str] = Form(DEFAULT_PROPER_NOUNS)
):
    """
    Check grammar of an uploaded file (.txt/.docx/.xlsx per ALLOWED_EXTENSIONS).

    Returns:
        The Grammar result serialized to a plain dict.

    Raises:
        HTTPException 400: no file selected or unsupported extension.
        HTTPException 500: on any unexpected processing error.
    """
    try:
        # Check if a valid file was uploaded
        if file.filename == "":
            raise HTTPException(status_code=400, detail="No file selected")

        # Check if the file has an allowed extension
        if not allowed_file(file.filename):
            raise HTTPException(
                status_code=400,
                detail=f"File type not supported. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}",
            )

        # Process the file
        file_content = await file.read()
        result = check_grammar_from_file(file_content, file.filename, proper_nouns)

        # Convert Pydantic model to dict for JSON response
        return result.dict()

    except HTTPException:
        # Bug fix: HTTPException subclasses Exception; without this re-raise
        # the 400 responses above were being converted into 500s below.
        raise
    except Exception as e:
        logging.error(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
212
+
213
+
214
# Entry point for running the API directly (e.g. `python app.py`).
# The Docker image instead launches uvicorn on port 7860 via its CMD.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)
grammar_checker.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from loguru import logger
3
+ import json
4
+ import io
5
+ import tempfile
6
+ from typing import List, Dict, Any, Annotated, Optional
7
+ from langchain_openai import AzureChatOpenAI
8
+ from langchain.prompts import ChatPromptTemplate
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from models import Grammar, Error
11
+ import docx
12
+ from rich.console import Console
13
+ from rich.table import Table
14
+ from rich.box import ROUNDED
15
+ import re
16
+ import pandas as pd
17
+
18
# Configure logging
# Loguru file sink; rotates once the log file reaches 500 MB.
logger.add("grammar_checker.log", rotation="500 MB", level="DEBUG")

# Create console for rich output
console = Console()

# Get Azure OpenAI credentials from environment variables
# NOTE(review): all four may be None when the env vars are unset; the client
# below is constructed at import time, so a misconfigured environment may
# only surface on the first request — confirm desired failure mode.
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_VERSION = os.environ.get("API_VERSION")
# Shared LLM client used by every checker function in this module.
llm = AzureChatOpenAI(
    temperature=0,  # deterministic output for grammar checking
    api_key=AZURE_OPENAI_API_KEY,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
    api_version=AZURE_OPENAI_API_VERSION,
    model="gpt-4o",
)
# Constants for text splitting
CHUNK_SIZE = 1000  # Approximate characters per page
CHUNK_OVERLAP = 250  # Overlap between chunks to maintain context

# Common tech terms and proper nouns that should not be flagged as errors
DEFAULT_PROPER_NOUNS = """
API, APIs, HTML, CSS, JavaScript, TypeScript, Python, Java, C++, SQL, NoSQL,
MongoDB, PostgreSQL, MySQL, Redis, Docker, Kubernetes, AWS, Azure, GCP,
HTTP, HTTPS, REST, GraphQL, JSON, XML, YAML, React, Angular, Vue, Node.js,
Express, Flask, Django, Spring, TensorFlow, PyTorch, Scikit-learn, npm, pip,
GitHub, GitLab, Bitbucket, Jira, Confluence, Slack, OAuth, JWT, SSL, TLS
"""
49
+
50
+ from typing import TypedDict, Dict
51
+
52
+
53
def check_grammar_question(data: Dict[str, Any]) -> Dict[str, str]:
    """
    Run a structured LLM grammar check over a question/answer dictionary.

    Args:
        data: Mapping of field name -> text (e.g. "Question", answer options).

    Returns:
        Dict with "output" (same keys as *data*, values corrected) and
        "wrong_locations" (brief description of errors, or blank).
        NOTE(review): the declared return type Dict[str, str] does not match
        the structured result, whose "output" value is itself a dict — confirm.
    """
    system_message = """
    You are a spellchecker for a question and answer pair. Related to IT and programming.
    You will be given a question and answer pair.
    You will need to check the grammar of the question and answer pair.
    You will need to return the corrected question and answer pair in a dictionary. If any of the fields are not errors, you should return the original value.

    Output should be a dictionary with same keys as the input dictionary.
    """
    input_message = """
    Here are input dictionary:
    {data}
    """
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", input_message)]
    )

    # Schema enforced on the LLM output: both keys must be present.
    class GrammarResult(TypedDict):
        output: Annotated[
            Dict[str, str], ..., "A dictionary with same keys as the input dictionary."
        ]
        wrong_locations: Annotated[Optional[str], ..., "point out errors briefly. Leave blank if there are no errors."]

    # Pipe the prompt into the LLM with schema-constrained structured output.
    chain = prompt | llm.with_structured_output(GrammarResult)
    result = chain.invoke({"data": data})
    return result
82
+
83
+
84
def check_grammar_qa(
    qa_dict: Dict[str, Any], proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Dict[str, str]:
    """
    Grammar-check the question and answer-option fields of a quiz record.

    Only "Question" and "Answer Option A"–"D" are forwarded to the checker;
    fields that are missing or NaN (empty Excel cells) are skipped.

    Args:
        qa_dict: Row dictionary containing question and answer options.
        proper_nouns: A string of proper nouns to preserve.

    Returns:
        The structured checker result for the selected fields.
    """
    fields_to_check = [
        "Question",
        "Answer Option A",
        "Answer Option B",
        "Answer Option C",
        "Answer Option D",
    ]
    payload = {
        field: qa_dict[field]
        for field in fields_to_check
        if field in qa_dict and not pd.isna(qa_dict[field])
    }
    return check_grammar_question(payload)
109
+
110
+
111
def extract_text_from_docx(file_content: bytes) -> str:
    """
    Extract text from a .docx file.

    The bytes are written to a temporary file because python-docx opens a
    document from a path or file object.

    Args:
        file_content: The bytes content of the .docx file

    Returns:
        The extracted text, paragraphs joined with newlines.

    Raises:
        Exception: if the document cannot be written or parsed.
    """
    try:
        # Create a temporary file to save the content
        with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as temp_file:
            temp_file.write(file_content)
            temp_file_path = temp_file.name

        try:
            # Open the temporary docx file and extract text
            doc = docx.Document(temp_file_path)
            full_text = [para.text for para in doc.paragraphs]
        finally:
            # Bug fix: always remove the temp file — the original only
            # deleted it on the success path, leaking a file whenever
            # docx.Document raised.
            os.unlink(temp_file_path)

        return "\n".join(full_text)
    except Exception as e:
        logger.error(f"Error extracting text from docx: {str(e)}")
        raise Exception(f"Failed to extract text from docx: {str(e)}")
140
+
141
+
142
def extract_text_from_file(file_content: bytes, file_extension: str) -> str:
    """
    Dispatch text extraction based on the file extension.

    Args:
        file_content: Raw bytes of the uploaded file.
        file_extension: Extension including the dot (".txt", ".docx"),
            matched case-insensitively.

    Returns:
        The extracted text as a string.

    Raises:
        ValueError: for any extension other than .txt/.docx.
    """
    extension = file_extension.lower()
    if extension == ".txt":
        # Plain text: decode directly, replacing undecodable bytes.
        return file_content.decode("utf-8", errors="replace")
    if extension == ".docx":
        # Word documents go through the python-docx based extractor.
        return extract_text_from_docx(file_content)
    raise ValueError(f"Unsupported file extension: {file_extension}")
161
+
162
+
163
class SentenceBasedTextSplitter(RecursiveCharacterTextSplitter):
    """Text splitter that packs whole sentences into chunks of ~chunk_size chars.

    Unlike the parent class, it never cuts inside a sentence.
    NOTE(review): chunk_overlap is forwarded to the parent but this override
    applies no overlap between chunks — confirm that is intended.
    """

    def __init__(self, chunk_size: int, chunk_overlap: int = 0):
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.chunk_size = chunk_size

    def split_text(self, text: str):
        # Split after sentence-ending punctuation followed by whitespace.
        sentence_endings = re.compile(r"(?<=[.!?])\s+")
        sentences = sentence_endings.split(text)

        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= self.chunk_size:
                current_chunk += sentence + " "
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                # A single sentence longer than chunk_size becomes its own
                # (oversized) chunk on the next flush.
                current_chunk = sentence + " "
        # Ensure the last chunk includes the remaining sentence if it exists
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks
187
+
188
+
189
def split_text(text: str) -> List[str]:
    """
    Split *text* into sentence-aligned chunks for LLM processing.

    Args:
        text: The full text to split.

    Returns:
        A list of chunk strings of roughly CHUNK_SIZE characters each.
    """
    splitter = SentenceBasedTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    pieces = splitter.split_text(text)
    logger.debug(f"Split text into {len(pieces)} chunks")
    return pieces
211
+
212
+
213
def create_grammar_prompt(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> str:
    """
    Create a grammar checking prompt for the given text with proper nouns.

    Args:
        text: The text to check for grammar issues
        proper_nouns: A string of proper nouns to preserve

    Returns:
        A formatted prompt string
    """
    # The prompt below is interpolated with the proper-noun list and the
    # chunk to check; downstream JSON parsing relies on the model following
    # these instructions, so do not edit the wording casually.
    return f"""
    Rewrite the provided text to be clear and grammatically correct while preserving technical accuracy. Focus on:

    1. Correcting spelling, punctuation, and grammar errors
    2. Maintaining technical terminology and code snippets
    3. Ensuring consistent tense, voice, and formatting
    4. Clarifying function descriptions, parameters, and return values
    5. Proper use of capitalization, acronyms, and abbreviations
    6. Improving clarity and conciseness
    7. Respect markdown and code formatting such as underscores, asterisks, backticks, code blocks, and links
    8. Ensure proper nouns and acronyms are correctly spelled and capitalized

    Here's a list of proper nouns and technical terms you should preserve:
    {proper_nouns}

    Preserve code-specific formatting and syntax. Prioritize original text if unsure about technical terms.

    Make sure when you show the before vs after text, include a larger phrase or sentence for context.

    In the response:
    - For 'spelling', 'punctuation', and 'grammar' keys: Provide only changed items with original text, corrected text, and explanation.

    Ensure that the original text is actually referenced from the given text below:

    {text}
    """
250
+
251
+
252
def process_api_response(content: str) -> Dict[str, List[Dict[str, str]]]:
    """
    Pull the JSON object out of a raw LLM response string.

    The model is instructed to answer with pure JSON but may wrap it in extra
    text, so everything between the first '{' and the last '}' is parsed.

    Args:
        content: The API response content.

    Returns:
        Dict with "spelling"/"punctuation"/"grammar" lists; empty lists when
        the extracted JSON body cannot be decoded.

    Raises:
        ValueError: when no JSON object delimiters are present at all.
    """
    start = content.find("{")
    end = content.rfind("}") + 1

    if start == -1 or end == 0:
        logger.error(f"Could not find JSON in response: {content}")
        raise ValueError("API response did not contain valid JSON")

    candidate = content[start:end]
    logger.debug(f"Extracted JSON: {candidate[:100]}...")

    try:
        return json.loads(candidate)
    except json.JSONDecodeError as je:
        logger.error(f"JSON decode error: {str(je)}")
        logger.error(f"JSON string was: {candidate}")
        # Fall back to an empty result instead of failing the whole request.
        return {"spelling": [], "punctuation": [], "grammar": []}
283
+
284
+
285
def merge_grammar_results(
    results: List[Dict[str, List[Dict[str, str]]]],
) -> Dict[str, List[Dict[str, str]]]:
    """
    Combine per-chunk grammar results into one result.

    Args:
        results: One result dict per processed chunk.

    Returns:
        A dict whose "spelling"/"punctuation"/"grammar" lists concatenate the
        corresponding lists of every input, preserving order.
    """
    categories = ("spelling", "punctuation", "grammar")
    return {
        category: [
            entry
            for chunk_result in results
            for entry in chunk_result.get(category, [])
        ]
        for category in categories
    }
305
+
306
+
307
def validate_corrections(
    result: Dict[str, List[Dict[str, str]]],
) -> Dict[str, List[Dict[str, str]]]:
    """
    Filter out no-op or malformed corrections.

    Drops entries whose "before" and "after" text is identical (or differs
    only in surrounding whitespace), and entries missing either key — the
    LLM occasionally emits incomplete objects, which previously raised
    KeyError here.

    Args:
        result: The (merged) grammar check result.

    Returns:
        Validated grammar check result with the same category keys.
    """
    validated = {"spelling": [], "punctuation": [], "grammar": []}

    for category in ["spelling", "punctuation", "grammar"]:
        for error in result.get(category, []):
            before = error.get("before")
            after = error.get("after")

            # Malformed entry from the model: skip instead of crashing.
            if before is None or after is None:
                continue

            # Skip if before and after are the same
            if before == after:
                continue

            # Skip if only whitespace changes
            if before.strip() == after.strip():
                continue

            validated[category].append(error)

    return validated
334
+
335
+
336
def apply_corrections(original_text: str, errors: List[Error]) -> str:
    """
    Apply every correction to *original_text* and return the result.

    Each error's first occurrence is located up front, then replacements are
    applied from the end of the string backwards so earlier edits cannot
    shift the positions of later ones. Errors whose "before" text is not
    found are silently skipped.

    Args:
        original_text: The original text with errors.
        errors: Error objects exposing .before and .after attributes.

    Returns:
        Fully corrected text.
    """
    corrected = original_text

    # Locate each correction in the text; drop the ones that don't match.
    located = [(corrected.find(err.before), err) for err in errors]
    located = [(pos, err) for pos, err in located if pos != -1]

    # Replace right-to-left so the recorded positions stay valid.
    for pos, err in sorted(located, key=lambda item: item[0], reverse=True):
        corrected = corrected[:pos] + err.after + corrected[pos + len(err.before):]

    return corrected
367
+
368
+
369
def check_grammar(text: str, proper_nouns: str = DEFAULT_PROPER_NOUNS) -> Grammar:
    """
    Check the grammar of the given text using LangChain and Azure OpenAI.

    Pipeline: split the text into sentence-aligned chunks, send each chunk to
    the LLM (batched when there is more than one), merge and validate the
    per-chunk JSON results, then apply every correction to produce a fully
    corrected copy of the input.

    Args:
        text: The text to check for grammar issues
        proper_nouns: A string of proper nouns to preserve

    Returns:
        Grammar object containing categorized errors and the corrected text

    Raises:
        Exception: wrapping any failure during splitting, the API calls, or
            result processing.
    """
    try:
        # Split text into chunks if it's too long
        chunks = split_text(text)

        # Initialize LangChain with Azure OpenAI

        logger.debug(
            f"Using Azure OpenAI with deployment: {AZURE_OPENAI_DEPLOYMENT_NAME}"
        )

        # Create system message for JSON format
        system_message = """You are a spellchecker database that outputs grammar errors and corrected text in JSON.
        The JSON object must use the schema that has 'spelling', 'punctuation', and 'grammar' keys, each with a list of objects containing 'before', 'after', and 'explanation'.
        It is strictly imperative that you return as JSON. DO NOT return any other characters other than valid JSON as your response."""

        # Create a prompt template and chain using the pipe syntax
        prompt = ChatPromptTemplate.from_messages(
            [("system", system_message), ("user", "{prompt}")]
        )
        chain = prompt | llm

        # Process each chunk in a batch
        logger.debug(f"Processing {len(chunks)} chunks in batch...")

        if len(chunks) == 1:
            # For single chunks, just use invoke directly
            prompt_text = create_grammar_prompt(chunks[0], proper_nouns)
            response = chain.invoke({"prompt": prompt_text})
            content = response.content
            logger.debug(f"Received single response from API: {content[:200]}...")
            result = process_api_response(content)
        else:
            # For multiple chunks, use batch processing
            prompt_batch = [
                {"prompt": create_grammar_prompt(chunk, proper_nouns)}
                for chunk in chunks
            ]
            responses = chain.batch(prompt_batch)
            logger.debug(f"Received {len(responses)} batch responses from API")

            # Process each response
            results = []
            for response in responses:
                content = response.content
                result = process_api_response(content)
                results.append(result)

            # Merge the results
            # NOTE(review): duplicate corrections reported by multiple chunks
            # are not de-duplicated here — confirm that is acceptable.
            result = merge_grammar_results(results)

        # Validate corrections to ensure they're meaningful
        validated_result = validate_corrections(result)

        # Create Error objects for each category
        spelling_errors = [Error(**err) for err in validated_result.get("spelling", [])]
        punctuation_errors = [
            Error(**err) for err in validated_result.get("punctuation", [])
        ]
        grammar_errors = [Error(**err) for err in validated_result.get("grammar", [])]

        # Apply corrections to get fully corrected text
        corrected_text = apply_corrections(
            text, spelling_errors + punctuation_errors + grammar_errors
        )

        # Return a Grammar object
        return Grammar(
            spelling=spelling_errors,
            punctuation=punctuation_errors,
            grammar=grammar_errors,
            file_path="",  # Will be updated for file uploads
            corrected_text=corrected_text,  # Add the corrected text
        )

    except Exception as e:
        logger.error(f"Error checking grammar: {str(e)}")
        raise Exception(f"Failed to analyze text: {str(e)}")
457
+
458
+
459
def check_grammar_from_file(
    file_content: bytes, filename: str, proper_nouns: str = DEFAULT_PROPER_NOUNS
) -> Grammar:
    """
    Grammar-check the text content of an uploaded file.

    Args:
        file_content: The bytes content of the file.
        filename: Name of the upload; its extension selects the text
            extractor and it is recorded on the returned result.
        proper_nouns: A string of proper nouns to preserve.

    Returns:
        Grammar object with categorized errors and file_path set to *filename*.

    Raises:
        Exception: wrapping any extraction or checking failure.
    """
    try:
        extension = os.path.splitext(filename)[1]
        extracted = extract_text_from_file(file_content, extension)

        # Check grammar on the extracted text and tag it with the filename.
        report = check_grammar(extracted, proper_nouns)
        report.file_path = filename
        return report
    except Exception as e:
        logger.error(f"Error checking grammar from file: {str(e)}")
        raise Exception(f"Failed to analyze file: {str(e)}")
487
+
488
+
489
def display_results(response: Grammar, path: str = "", repo_link: str = "") -> int:
    """
    Display the grammar check results using Rich.

    Args:
        response: The Grammar object with check results
        path: Path to the file that was checked
        repo_link: Optional repository link (for GitHub URLs)

    Returns:
        Total number of errors found
    """
    # Work out the label shown above the tables.
    if repo_link and response.file_path:
        # Link to the file on GitHub instead of showing a local path.
        relative_path = os.path.basename(response.file_path)
        path = f"{repo_link.rstrip('/')}/blob/main/{relative_path}"
    elif path:
        # Use the provided path
        pass
    elif response.file_path:
        # Use the file path from the response
        path = response.file_path
    else:
        # Default text
        path = "Text input"

    # Print the file path
    console.print(f"\n[bold cyan]File: {path}[/bold cyan]")

    total_errors = 0

    # Display each error category
    for category in ["spelling", "punctuation", "grammar"]:
        table = Table(title=f"{category.capitalize()} Corrections", box=ROUNDED)
        table.add_column("Original", justify="left", style="bold red")
        table.add_column("Corrected", justify="left", style="bold green")
        table.add_column("Explanation", justify="left", style="italic")

        errors = getattr(response, category)
        for error in errors:
            if error.before != error.after:
                table.add_row(error.before, error.after, error.explanation)
                table.add_row("", "", "")  # Add an empty row for spacing
                total_errors += 1

        if errors:
            console.print(table)
        else:
            no_errors_msg = f"No {category} errors found."
            console.print(f"[blue]{no_errors_msg}[/blue]")

    # Bug fix: the original closed "[bold <color>]" with "[/bold]", which is
    # not a matching tag and makes Rich raise a MarkupError. "[/]" closes the
    # most recently opened tag regardless of its contents.
    summary_color = "green" if total_errors == 0 else "red"
    console.print(f"[bold {summary_color}]Total errors found: {total_errors}[/]")

    # Display corrected text if available
    if response.corrected_text:
        console.print("\n[bold cyan]Fully Corrected Text:[/bold cyan]")
        console.print(response.corrected_text)

    # Bug fix: export_text() requires a Console created with record=True; the
    # module-level console is not recording, so exporting unconditionally
    # raised. Only write the transcript when recording is enabled.
    if console.record:
        with open("grammar_results.txt", "w") as f:
            f.write(console.export_text())

    return total_errors
models.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional
3
+
4
class Error(BaseModel):
    """Model for individual grammar errors."""

    # Original text fragment containing the error.
    before: str
    # Corrected replacement text.
    after: str
    # Short human-readable reason for the change.
    explanation: str
9
+
10
class Grammar(BaseModel):
    """Model for grammar check results."""

    # Corrections grouped by category.
    spelling: List[Error]
    punctuation: List[Error]
    grammar: List[Error]
    # Name/path of the checked file; None for raw-text input.
    file_path: Optional[str] = None
    # Full input text with all corrections applied.
    corrected_text: Optional[str] = None