Update app.py
Browse files
app.py
CHANGED
|
@@ -6,15 +6,18 @@ from sentence_transformers import SentenceTransformer, util
|
|
| 6 |
from PIL import Image
|
| 7 |
from typing import List
|
| 8 |
import torch
|
| 9 |
-
from transformers import BertTokenizer, BertModel
|
| 10 |
import torch.nn.functional as F
|
| 11 |
-
import language_tool_python # Import LanguageTool for grammar checking
|
| 12 |
|
| 13 |
# Load pre-trained models
|
| 14 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 15 |
bert_model = BertModel.from_pretrained('bert-base-uncased')
|
| 16 |
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Initialize Groq client
|
| 19 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
| 20 |
|
|
@@ -24,9 +27,6 @@ system_prompt = {
|
|
| 24 |
"content": "You are a useful assistant. You reply with efficient answers."
|
| 25 |
}
|
| 26 |
|
| 27 |
-
# Initialize grammar checker
|
| 28 |
-
tool = language_tool_python.LanguageTool('en-US')
|
| 29 |
-
|
| 30 |
async def chat_groq(message, history):
|
| 31 |
messages = [system_prompt]
|
| 32 |
for msg in history:
|
|
@@ -103,13 +103,22 @@ def calculate_sentence_similarity(text1, text2):
|
|
| 103 |
embedding2 = sentence_model.encode(text2, convert_to_tensor=True)
|
| 104 |
return util.pytorch_cos_sim(embedding1, embedding2).item()
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
# Apply a penalty based on the number of grammar errors
|
| 112 |
-
penalty =
|
| 113 |
return penalty
|
| 114 |
|
| 115 |
def compare_answers(student_answer, teacher_answer):
|
|
@@ -120,7 +129,7 @@ def compare_answers(student_answer, teacher_answer):
|
|
| 120 |
semantic_similarity = (0.75 * bert_similarity + 0.25 * sentence_similarity)
|
| 121 |
|
| 122 |
# Apply grammar penalty
|
| 123 |
-
grammar_penalty =
|
| 124 |
final_similarity = semantic_similarity * grammar_penalty
|
| 125 |
|
| 126 |
return final_similarity
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
from typing import List
|
| 8 |
import torch
|
| 9 |
+
from transformers import BertTokenizer, BertModel, T5ForConditionalGeneration, T5Tokenizer
|
| 10 |
import torch.nn.functional as F
|
|
|
|
| 11 |
|
| 12 |
# Load pre-trained models
|
| 13 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 14 |
bert_model = BertModel.from_pretrained('bert-base-uncased')
|
| 15 |
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 16 |
|
| 17 |
+
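# Load the pre-trained T5 model and tokenizer used for grammar error detection.
# NOTE(review): 't5-base' is the generic checkpoint; confirm it produces useful output for the "grammar:" prompt.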
|
| 18 |
+
grammar_model = T5ForConditionalGeneration.from_pretrained('t5-base')
|
| 19 |
+
grammar_tokenizer = T5Tokenizer.from_pretrained('t5-base')
|
| 20 |
+
|
| 21 |
# Initialize Groq client
|
| 22 |
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
| 23 |
|
|
|
|
| 27 |
"content": "You are a useful assistant. You reply with efficient answers."
|
| 28 |
}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
async def chat_groq(message, history):
|
| 31 |
messages = [system_prompt]
|
| 32 |
for msg in history:
|
|
|
|
| 103 |
embedding2 = sentence_model.encode(text2, convert_to_tensor=True)
|
| 104 |
return util.pytorch_cos_sim(embedding1, embedding2).item()
|
| 105 |
|
| 106 |
+
# Grammar detection and penalization using T5 model
|
| 107 |
+
def detect_grammar_errors(text):
    """Estimate the number of grammar errors in *text* using the T5 model.

    The text is prefixed with ``"grammar: "``, passed through
    ``grammar_model`` with beam search, and the decoded output is scanned
    for the literal substring ``"error"``.

    Args:
        text: The answer text to analyse.

    Returns:
        int: Number of ``"error"`` occurrences in the generated text, or
        ``0`` when *text* is empty or contains only whitespace.
    """
    # Nothing to analyse: skip the expensive generation pass entirely.
    if not text or not str(text).strip():
        return 0

    input_text = f"grammar: {text}"
    # Truncate to 512 tokens so very long answers stay within the cap.
    inputs = grammar_tokenizer.encode(
        input_text, return_tensors='pt', max_length=512, truncation=True
    )
    outputs = grammar_model.generate(
        inputs, max_length=512, num_beams=4, early_stopping=True
    )
    grammar_analysis = grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # NOTE(review): counting the substring 'error' is a placeholder criterion
    # (case-sensitive, also matches e.g. 'errors'); replace it with a real
    # signal once the model's output format for this prompt is confirmed.
    error_count = grammar_analysis.count('error')
    return error_count
|
| 116 |
+
|
| 117 |
+
def penalize_for_grammar(student_answer):
    """Return the score multiplier for the grammar quality of an answer.

    Each error reported by ``detect_grammar_errors`` lowers the multiplier
    by 0.05, starting from 1; the result never drops below 0.5 (i.e. at
    most a 50% penalty).

    Args:
        student_answer: The student's answer text.

    Returns:
        float: Multiplier to apply to the similarity score.
    """
    deduction = 0.05 * detect_grammar_errors(student_answer)
    multiplier = 1 - deduction
    # Floor the multiplier so grammar alone can cost at most half the score.
    return multiplier if multiplier > 0.5 else 0.5
|
| 123 |
|
| 124 |
def compare_answers(student_answer, teacher_answer):
|
|
|
|
| 129 |
semantic_similarity = (0.75 * bert_similarity + 0.25 * sentence_similarity)
|
| 130 |
|
| 131 |
# Apply grammar penalty
|
| 132 |
+
grammar_penalty = penalize_for_grammar(student_answer)
|
| 133 |
final_similarity = semantic_similarity * grammar_penalty
|
| 134 |
|
| 135 |
return final_similarity
|