swe-cefr-sp / web_app /app.py
fffffwl's picture
Initial HF Space for Swedish CEFR web app
0b8530c
"""
CEFR Sentence Level Assessment Web Application
Flask-based web interface for assessing Swedish text at sentence level
"""
import os
from collections import Counter
from pathlib import Path

from flask import Flask, render_template, request, jsonify

from model import CEFRModel, assess_text
# Initialize Flask app
app = Flask(__name__)
# Prefer a secret supplied through the environment so deployments do not
# depend on the value committed to source control; the original literal is
# kept as the fallback so existing setups keep working unchanged.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'cefr-assessment-app')

# Initialize model
# The model is loaded once, at import time, and shared by every request
# handler in this module. MODEL_PATH overrides the default checkpoint path.
print("Loading CEFR assessment model...")
model_path = os.environ.get('MODEL_PATH', 'runs/metric-proto-k3/metric_proto.pt')
model = CEFRModel(model_path=model_path)
print(f"Model loaded successfully! Using device: {model.device}")
# Display metadata for each CEFR level: the colour used when rendering the
# level in HTML and its human-readable label. Rows are (level, colour, label),
# listed from lowest to highest level.
CEFR_STYLES = {
    level: {'color': color, 'name': label}
    for level, color, label in (
        ('A1', '#E74C3C', 'A1 - Beginner'),
        ('A2', '#E67E22', 'A2 - Elementary'),
        ('B1', '#F39C12', 'B1 - Intermediate'),
        ('B2', '#27AE60', 'B2 - Upper Intermediate'),
        ('C1', '#3498DB', 'C1 - Advanced'),
        ('C2', '#9B59B6', 'C2 - Proficient'),
    )
}
@app.route('/')
def index():
    """Serve the home page, which holds the text input form."""
    home_template = 'index.html'
    return render_template(home_template)
@app.route('/assess', methods=['POST'])
def assess():
    """Assess submitted text and return the results as JSON.

    Expects a JSON object body with a string field ``text``.

    Returns:
        On success, a JSON object with ``results`` (the value returned by
        ``assess_text``), ``cefr_styles`` (the ``CEFR_STYLES`` table) and
        ``stats`` (from ``compute_stats``).
        On failure, a JSON object with an ``error`` message and status 400
        (missing, malformed, empty or oversized input, or no sentences
        found) or 500 (unexpected error during assessment).
    """
    try:
        # silent=True yields None instead of raising when the body is missing
        # or is not valid JSON, so a bad request is answered with 400 rather
        # than falling through to the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        text = data.get('text', '')
        if not isinstance(text, str):
            # e.g. {"text": null} or {"text": 42}; .strip() would fail on these.
            return jsonify({'error': 'Text must be a string'}), 400

        text = text.strip()
        if not text:
            return jsonify({'error': 'Please enter some text to assess'}), 400

        # Limit text length
        if len(text) > 50000:  # ~50KB limit
            return jsonify({'error': 'Text is too long. Please limit to 50,000 characters.'}), 400

        # Assess text
        results = assess_text(text, model)
        if not results:
            return jsonify({'error': 'No valid sentences found in the text'}), 400

        # Prepare response
        response = {
            'results': results,
            'cefr_styles': CEFR_STYLES,
            'stats': compute_stats(results)
        }
        return jsonify(response)
    except Exception as e:
        # Top-level boundary for the request: report the failure to the
        # client as JSON instead of letting it propagate.
        print(f"Error in assessment: {str(e)}")
        return jsonify({'error': f'An error occurred during assessment: {str(e)}'}), 500
@app.route('/api/predict', methods=['POST'])
def api_predict():
    """API endpoint for batch predictions.

    Expects a JSON object body with a ``sentences`` field holding a non-empty
    list of at most 100 strings.

    Returns:
        On success, a JSON object with ``predictions`` (one entry per input
        sentence with ``sentence``, ``level`` and ``confidence``) and
        ``count`` (number of predictions).
        On failure, a JSON object with an ``error`` message and status 400
        (invalid request) or 500 (unexpected error during prediction).
    """
    try:
        # silent=True yields None instead of raising when the body is missing
        # or is not valid JSON, so a bad request is answered with 400 rather
        # than falling through to the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        sentences = data.get('sentences', [])
        if not sentences:
            return jsonify({'error': 'No sentences provided'}), 400
        if not isinstance(sentences, list):
            return jsonify({'error': 'Sentences must be a list'}), 400

        # Limit batch size
        if len(sentences) > 100:
            return jsonify({'error': 'Batch size limited to 100 sentences'}), 400

        # Checked after the size limit so at most 100 elements are inspected.
        if not all(isinstance(sent, str) for sent in sentences):
            return jsonify({'error': 'Each sentence must be a string'}), 400

        # Predict
        predictions = model.predict_batch(sentences)

        # Format response: each prediction is unpacked as (level, confidence)
        # and paired with the sentence at the same position.
        results = [
            {'sentence': sent, 'level': level, 'confidence': confidence}
            for sent, (level, confidence) in zip(sentences, predictions)
        ]
        return jsonify({
            'predictions': results,
            'count': len(results)
        })
    except Exception as e:
        # Top-level boundary for the request: report the failure to the
        # client as JSON instead of letting it propagate.
        print(f"Error in API prediction: {str(e)}")
        return jsonify({'error': str(e)}), 500
def compute_stats(results: list) -> dict:
    """Compute summary statistics for a list of assessment results.

    Args:
        results: Per-sentence result dicts; each must provide a ``'level'``
            key and a numeric ``'confidence'`` key.

    Returns:
        An empty dict when ``results`` is empty. Otherwise a dict with:
        ``total_sentences`` (number of results), ``level_distribution``
        (level -> count, in order of first appearance), ``avg_confidence``
        (mean of the confidences) and ``most_common_level`` (a dict with the
        ``level``, its ``count`` and its ``percentage`` of all results,
        rounded to one decimal).

    Raises:
        KeyError: If a result dict lacks ``'level'`` or ``'confidence'``.
    """
    if not results:
        return {}

    total = len(results)

    # Count levels
    level_counts = Counter(item['level'] for item in results)

    # Average confidence
    avg_confidence = sum(item['confidence'] for item in results) / total

    # Most common level. results is non-empty here, so level_counts always has
    # at least one entry; on a tie max() keeps the level counted first.
    most_common = max(level_counts, key=level_counts.get)
    most_common_count = level_counts[most_common]
    most_common_pct = (most_common_count / total) * 100

    return {
        'total_sentences': total,
        # Plain dict so callers get the same type as before.
        'level_distribution': dict(level_counts),
        'avg_confidence': avg_confidence,
        'most_common_level': {
            'level': most_common,
            'count': most_common_count,
            'percentage': round(most_common_pct, 1)
        }
    }
@app.context_processor
def utility_processor():
    """Make the ``round`` and ``len`` builtins available in Jinja templates."""
    template_helpers = {'round': round, 'len': len}
    return template_helpers
if __name__ == '__main__':
    # Create uploads directory
    os.makedirs('uploads', exist_ok=True)

    # Debug mode enables the interactive debugger, which lets anyone who can
    # reach the server execute code. The server binds to all interfaces
    # below, so debug is opt-in (FLASK_DEBUG=1) instead of always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in (
        '1', 'true', 'yes', 'on'
    )
    # PORT lets the hosting environment choose the port; 5000 stays the default.
    port = int(os.environ.get('PORT', '5000'))

    print("Starting CEFR Assessment Web App...")
    print(f"\nModel path: {model_path}")
    print(f"Model device: {model.device}")
    print("\nStarting Flask server...")

    # Run app
    app.run(
        debug=debug_enabled,
        host='0.0.0.0',
        port=port,
        threaded=True
    )