# aes-bert / app.py
# Hugging Face Space file (author: Hak978, commit 66c7c07, verified)
# from flask import Flask, render_template, request
# from transformers import BertForSequenceClassification, BertTokenizer
# from language_tool_python import LanguageTool
# from spellchecker import SpellChecker
# from collections import Counter
# import string
# import torch
# import numpy as np
# import os
# from pathlib import Path
# app = Flask(__name__)
# # Configure cache directories
# cache_base = os.getenv('XDG_CACHE_HOME', '/tmp/cache')
# huggingface_cache = os.path.join(cache_base, 'huggingface')
# languagetool_cache = os.path.join(cache_base, 'languagetool')
# # Ensure directories exist
# Path(huggingface_cache).mkdir(parents=True, exist_ok=True)
# Path(languagetool_cache).mkdir(parents=True, exist_ok=True)
# # Initialize LanguageTool
# try:
# grammar_tool = LanguageTool(
# 'en-US',
# remote_server='https://api.languagetool.org'
# )
# print("LanguageTool initialized successfully")
# except Exception as e:
# print(f"Error initializing LanguageTool: {e}")
# grammar_tool = None
# # Initialize SpellChecker
# spell = SpellChecker()
# # Load Hugging Face models
# MODEL_NAME = "Hak978/aes-bert-models"
# try:
# # Load models using subfolder parameter
# model_website1 = BertForSequenceClassification.from_pretrained(
# MODEL_NAME,
# subfolder="essay_scoring_model_regression_20240228_123826",
# cache_dir=huggingface_cache
# )
# model_website2 = BertForSequenceClassification.from_pretrained(
# MODEL_NAME,
# subfolder="essay_scoring_model_regression_20240229_133324",
# cache_dir=huggingface_cache
# )
# # Load tokenizer
# tokenizer = BertTokenizer.from_pretrained(
# 'bert-base-uncased',
# cache_dir=huggingface_cache
# )
# print("Models loaded successfully")
# except Exception as e:
# print(f"Error loading models: {e}")
# model_website1 = model_website2 = tokenizer = None
# def check_spelling(text):
# words = text.split()
# misspelled = spell.unknown(words)
# return list(misspelled)
# def check_grammar(text):
# if grammar_tool is None:
# return []
# matches = grammar_tool.check(text)
# return [{'message': match.message, 'replacements': match.replacements} for match in matches]
# def count_words(text):
# words = text.split()
# return len(words)
# def calculate_sentence_lengths(text):
# sentences = text.split('.')
# lengths = [len(sentence.split()) for sentence in sentences if sentence.strip()]
# return {
# 'average': np.mean(lengths) if lengths else 0,
# 'min': min(lengths) if lengths else 0,
# 'max': max(lengths) if lengths else 0
# }
# def calculate_vocabulary_diversity(text):
# words = text.lower().split()
# unique_words = set(words)
# return len(unique_words) / len(words) if words else 0
# def count_punctuation(text):
# return sum(1 for char in text if char in string.punctuation)
# def predict_score(text, model, tokenizer):
# if model is None or tokenizer is None:
# return 0.0
# inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
# with torch.no_grad():
# outputs = model(**inputs)
# predictions = outputs.logits
# return predictions.item()
# @app.route('/')
# def home():
# return render_template('index.html')
# @app.route('/health')
# def health_check():
# return {'status': 'healthy'}, 200
# @app.route('/analyze', methods=['POST'])
# def analyze():
# if request.method == 'POST':
# essay_text = request.form['essay']
# feedback = {
# 'word_count': count_words(essay_text),
# 'spelling_errors': check_spelling(essay_text),
# 'grammar_errors': check_grammar(essay_text),
# }
# if model_website1 and model_website2 and tokenizer:
# score1 = predict_score(essay_text, model_website1, tokenizer)
# score2 = predict_score(essay_text, model_website2, tokenizer)
# feedback.update({
# 'score1': round(score1, 2),
# 'score2': round(score2, 2),
# 'average_score': round((score1 + score2) / 2, 2)
# })
# sentence_stats = calculate_sentence_lengths(essay_text)
# feedback.update({
# 'avg_sentence_length': round(sentence_stats['average'], 2),
# 'min_sentence_length': int(sentence_stats['min']),
# 'max_sentence_length': int(sentence_stats['max']),
# 'vocabulary_diversity': round(calculate_vocabulary_diversity(essay_text) * 100, 2),
# 'punctuation_count': count_punctuation(essay_text)
# })
# return render_template('result.html', feedback=feedback)
# if __name__ == '__main__':
# port = int(os.environ.get('PORT', 7860)) # Hugging Face uses 7860
# app.run(host='0.0.0.0', port=port)
from flask import Flask, render_template, request
from transformers import BertForSequenceClassification, BertTokenizer
from language_tool_python import LanguageTool
from spellchecker import SpellChecker
from collections import Counter
import string
import torch
import numpy as np
import os
from pathlib import Path
# Templates live next to app.py rather than in a templates/ directory.
app = Flask(__name__, template_folder='.')
# Configure cache directories
# Hugging Face Spaces containers are only guaranteed writable under /tmp,
# so default the cache root there unless XDG_CACHE_HOME overrides it.
cache_base = os.getenv('XDG_CACHE_HOME', '/tmp/cache')
huggingface_cache = os.path.join(cache_base, 'huggingface')
languagetool_cache = os.path.join(cache_base, 'languagetool')
# Ensure directories exist
Path(huggingface_cache).mkdir(parents=True, exist_ok=True)
Path(languagetool_cache).mkdir(parents=True, exist_ok=True)
# Initialize LanguageTool
# Uses the public remote API (no local Java server). On any failure the
# tool is set to None so downstream checks degrade gracefully instead of
# crashing the app at import time.
try:
    grammar_tool = LanguageTool(
        'en-US',
        remote_server='https://api.languagetool.org'
    )
    print("LanguageTool initialized successfully")
except Exception as e:
    print(f"Error initializing LanguageTool: {e}")
    grammar_tool = None
# Initialize SpellChecker
spell = SpellChecker()
# Load Hugging Face models
MODEL_NAME = "Hak978/aes-bert-models"
# Load two regression checkpoints from the Hub repo (each in its own
# subfolder) plus the base BERT tokenizer. On any failure all three
# handles are set to None and the routes fall back to heuristic scores.
# NOTE(review): model_website2 is loaded but not referenced by the active
# code below — confirm whether it is still needed.
try:
    # Load models using subfolder parameter
    model_website1 = BertForSequenceClassification.from_pretrained(
        MODEL_NAME,
        subfolder="essay_scoring_model_regression_20240228_123826",
        cache_dir=huggingface_cache
    )
    model_website2 = BertForSequenceClassification.from_pretrained(
        MODEL_NAME,
        subfolder="essay_scoring_model_regression_20240229_133324",
        cache_dir=huggingface_cache
    )
    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased',
        cache_dir=huggingface_cache
    )
    print("Models loaded successfully")
except Exception as e:
    print(f"Error loading models: {e}")
    model_website1 = model_website2 = tokenizer = None
def tokenize_text(text, tokenizer):
    """Encode *text* into fixed-length BERT model inputs.

    Pads/truncates to the 512-token BERT limit and returns a
    ``(input_ids, attention_mask)`` pair of PyTorch tensors.
    """
    encode_kwargs = dict(
        add_special_tokens=True,
        max_length=512,
        truncation=True,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    encoded = tokenizer.encode_plus(text, **encode_kwargs)
    return encoded['input_ids'], encoded['attention_mask']
def normalize_bert_score(raw_score, category, essay):
    """Map a sigmoid-activated model output onto the rubric band for *category*.

    The raw [0, 1] score is stretched to the category's min/max range;
    error-sensitive categories (grammar, lexical, holistic) are then
    reduced in proportion to the essay's combined grammar + spelling
    error density. Result is clamped to the band and rounded to 1 decimal.

    Raises KeyError if *category* is not a known rubric category.
    """
    rubric_bands = {
        'grammar': {'min': 1, 'max': 8, 'threshold': 0.8},
        'lexical': {'min': 1, 'max': 8, 'threshold': 0.8},
        'global_organization': {'min': 3, 'max': 8, 'threshold': 0.6},
        'local_organization': {'min': 3, 'max': 8, 'threshold': 0.6},
        'supporting_ideas': {'min': 3, 'max': 8, 'threshold': 0.6},
        'holistic': {'min': 1, 'max': 5, 'threshold': 0.9}
    }
    band = rubric_bands[category]
    lo, hi = band['min'], band['max']

    # Combined error density from the grammar and spelling checkers
    # (both may be None/unavailable, in which case they contribute 0).
    grammar_errors = len(grammar_tool.check(essay)) if grammar_tool else 0
    words = essay.split()
    spelling_errors = len(spell.unknown(words)) if spell else 0
    if words:
        density = (grammar_errors + spelling_errors) / len(words)
    else:
        density = 1  # empty essay: treat as maximally error-dense

    scaled = lo + raw_score * (hi - lo)
    if category in ('grammar', 'lexical', 'holistic'):
        # These categories take a direct hit from the error density.
        scaled = max(lo, scaled - density * 7)
    return round(max(lo, min(hi, scaled)), 1)
def get_predictions_website1(essays):
    """Score each essay in *essays* with the BERT regression model.

    Args:
        essays: list of essay strings.

    Returns:
        One list of six normalized rubric scores per essay, in order:
        grammar, lexical, global_organization, local_organization,
        supporting_ideas, holistic. Empty list when the model or
        tokenizer failed to load.
    """
    if not model_website1 or not tokenizer:
        return []
    # Batch-encode all essays into stacked (N, 512) tensors.
    input_ids = []
    attention_masks = []
    for essay in essays:
        ids, mask = tokenize_text(essay, tokenizer)
        input_ids.append(ids)
        attention_masks.append(mask)
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    model_website1.eval()
    with torch.no_grad():
        outputs = model_website1(input_ids, attention_mask=attention_masks)
    raw_predictions = outputs.logits.cpu().numpy()

    categories = ['grammar', 'lexical', 'global_organization',
                  'local_organization', 'supporting_ideas', 'holistic']
    normalized_predictions = []
    # BUG FIX: the error-density normalization previously always used
    # essays[0]; pair each prediction row with its own essay instead so
    # multi-essay batches are normalized against the correct text.
    for raw_pred, essay in zip(raw_predictions, essays):
        raw_scores = 1 / (1 + np.exp(-raw_pred))  # sigmoid -> [0, 1]
        normalized_predictions.append([
            normalize_bert_score(score, category, essay)
            for score, category in zip(raw_scores, categories)
        ])
    return normalized_predictions
def calculate_grammar_score(essay):
    """Heuristic grammar score in [2, 10] from LanguageTool matches.

    Each match is weighted by its category (unknown categories count
    1.5), scaled to a weighted error density per 100 words, and error
    categories occurring more than twice incur an extra 0.3 penalty per
    occurrence. Returns None when LanguageTool is unavailable.
    """
    if not grammar_tool:
        return None
    matches = grammar_tool.check(essay)
    category_weights = {
        'SPELLING': 2.0,
        'GRAMMAR': 2.5,
        'PUNCTUATION': 1.5,
        'TYPOGRAPHY': 1.0
    }
    weighted_errors = sum(
        category_weights.get(match.category, 1.5) for match in matches
    )
    word_count = len(essay.split())
    if word_count > 0:
        density = (weighted_errors / word_count) * 100
    else:
        density = 100  # empty essay counts as maximally error-dense
    score = 10 - density * 0.7
    # Penalize any error category that shows up more than twice.
    per_category = Counter(match.category for match in matches)
    score -= sum(n * 0.3 for n in per_category.values() if n > 2)
    return round(max(2, min(10, score)), 1)
def calculate_spelling_score(essay):
    """Heuristic spelling score in [2, 10] from the misspelled-word rate.

    Tokens are stripped of surrounding punctuation before checking. The
    error rate is scaled by 20, plus an extra 0.5 per misspelling beyond
    the first five.
    """
    tokens = [token.strip('.,!?()[]{}":;') for token in essay.split()]
    unknown = spell.unknown(tokens) if spell else []
    total = len(tokens)
    errors = len(unknown)
    rate = errors / total if total > 0 else 1  # empty text -> worst rate
    penalty = rate * 20
    if errors > 5:
        penalty += (errors - 5) * 0.5
    return round(max(2, min(10, 10 - penalty)), 1)
def calculate_word_diversity(essay):
    """Heuristic vocabulary-diversity score in [5, 10].

    Measures the unique/total ratio over content words (punctuation and a
    small stop-word set removed), penalizing heavily repeated words and
    misspellings. Returns a neutral 7.0 when there is nothing to measure.
    """
    depunctuated = essay.lower().translate(
        str.maketrans('', '', string.punctuation)
    )
    words = depunctuated.split()
    if not words:
        return 7.0
    unknown = spell.unknown(words) if spell else []
    spelling_penalty = len(unknown) / len(words) * 5
    stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
    content = [w for w in words if w not in stop_words]
    if not content:
        return 7.0
    counts = Counter(content)
    unique = len(counts)
    # Words used more than twice drag the score down.
    heavily_repeated = sum(1 for n in counts.values() if n > 2)
    diversity_ratio = unique / len(content)
    repetition_penalty = min(1.5, heavily_repeated / unique)
    score = 8 + 2 * diversity_ratio - repetition_penalty - spelling_penalty
    return round(max(5, min(10, score)), 1)
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the scoring page; on POST, analyze the submitted essay.

    Fills the template context with the BERT rubric scores ("website 1"),
    the heuristic grammar/spelling/diversity scores ("website 2"), and a
    combined essay quality score. All score keys default to None so the
    template can render a plain GET request.
    """
    score_keys = (
        'grammar_score', 'lexical_score', 'global_organization_score',
        'local_organization_score', 'supporting_ideas_score',
        'holistic_score', 'grammar_score2', 'spelling_score',
        'word_diversity_score', 'essay_quality_score',
    )
    context = {key: None for key in score_keys}
    context['essay'] = ''
    if request.method == 'POST':
        essay = request.form['essay']
        context['essay'] = essay
        # Website 1: BERT rubric predictions for the single submitted essay.
        predictions = get_predictions_website1([essay])
        if predictions and len(predictions[0]) >= 6:
            first = predictions[0]
            context['grammar_score'] = first[0]
            context['lexical_score'] = first[1]
            context['global_organization_score'] = first[2]
            context['local_organization_score'] = first[3]
            context['supporting_ideas_score'] = first[4]
            # Holistic is capped at the 5.0 rubric maximum.
            context['holistic_score'] = min(5.0, first[5])
        # Website 2: heuristic scores.
        context['grammar_score2'] = calculate_grammar_score(essay)
        context['spelling_score'] = calculate_spelling_score(essay)
        context['word_diversity_score'] = calculate_word_diversity(essay)
        # Blend holistic (double weight) with heuristic grammar into an
        # overall quality score, when both are available.
        if context['holistic_score'] and context['grammar_score2']:
            combined = (context['holistic_score'] * 2
                        + context['grammar_score2']) / 3
            context['essay_quality_score'] = round(combined, 1)
    return render_template('index.html', **context)
if __name__ == '__main__':
    # Hugging Face Spaces serves on port 7860 by default; allow an
    # override via the PORT environment variable.
    port = int(os.environ.get('PORT', 7860))
    # Bind to all interfaces so the container's port mapping works.
    app.run(host='0.0.0.0', port=port)