from flask import Flask, request, jsonify
from flask_cors import CORS
import joblib
import pandas as pd
import numpy as np
from feature_extractor_web import extract_features_web
import logging
import os
# Serialized model artifacts are read from the "models" directory that sits
# next to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")

app = Flask(__name__)

# CORS allow-list: the comma-separated ALLOWED_ORIGINS environment variable
# replaces the default list below. Entries are trimmed, a trailing "/" is
# removed, and blank entries are dropped.
default_origins = "https://alalayfe.vercel.app,https://www.alalayfe.vercel.app,http://localhost:3000"
_configured_origins = os.getenv('ALLOWED_ORIGINS', default_origins)
ALLOWED_ORIGINS = [
    entry.strip().rstrip("/")
    for entry in _configured_origins.split(',')
    if entry.strip()
]
CORS(app, origins=ALLOWED_ORIGINS)

logger = logging.getLogger(__name__)
@app.route('/')
def home():
    """Return a small service banner that lists the available endpoints."""
    endpoints = ['/health', '/api/predict', '/api/predict/batch']
    banner = {
        'service': 'Alalay Readability API',
        'status': 'running',
        'endpoints': endpoints,
    }
    return jsonify(banner), 200
# Load all the saved components.
# Every component name is bound up front so that a failed load leaves them
# as None instead of undefined (which would raise NameError on later access).
# NOTE: joblib.load relies on pickle and can execute arbitrary code while
# loading, so only trusted artifacts belong in MODELS_DIR.
print("Loading model components...")
model = label_encoder = grade_mapping = thresholds = feature_info = None
try:
    model = joblib.load(os.path.join(MODELS_DIR, "readability_model.pkl"))
    label_encoder = joblib.load(os.path.join(MODELS_DIR, "label_encoder.pkl"))
    grade_mapping = joblib.load(os.path.join(MODELS_DIR, "grade_mapping.pkl"))
    thresholds = joblib.load(os.path.join(MODELS_DIR, "thresholds.pkl"))
    feature_info = joblib.load(os.path.join(MODELS_DIR, "feature_info.pkl"))
    print("All components loaded successfully!")
    print(f" Model type: {type(model.named_steps['classifier']).__name__}")
    print(f" Classes: {label_encoder.classes_}")
except Exception as e:
    # Startup boundary: keep the app importable so the health endpoint can
    # report 'degraded', but record the full traceback in the server logs.
    logger.exception("Error loading models: %s", e)
    # Discard partially loaded state; the endpoints treat `model is None`
    # as "not loaded".
    model = label_encoder = grade_mapping = thresholds = feature_info = None
def build_features_df(features: dict) -> pd.DataFrame:
    """Return a one-row DataFrame whose columns follow the training feature order.

    The column order comes from ``feature_info['all_features']`` and falls
    back to the keys of ``features`` when that entry is absent. A column
    missing from ``features`` is filled with ``'Other'`` if it is listed in
    ``feature_info['categorical_cols']`` and with ``0.0`` otherwise.
    """
    column_order = feature_info.get('all_features', list(features.keys()))
    categorical = set(feature_info.get('categorical_cols', []))

    def fill_value(name):
        # Prefer the extracted value; otherwise use the per-kind default.
        if name in features:
            return features[name]
        return 'Other' if name in categorical else 0.0

    record = {name: fill_value(name) for name in column_order}
    return pd.DataFrame([record], columns=column_order)
def pick_class_with_thresholds(probabilities: np.ndarray) -> int:
    """Return the index of the class to predict for one probability vector.

    A class qualifies when its probability reaches its entry in
    ``thresholds`` (0.5 when the class has no entry). The most probable
    qualifying class wins; when no class qualifies, the plain argmax of
    ``probabilities`` is returned.
    """
    fallback = int(np.argmax(probabilities))

    passing = []
    for idx, name in enumerate(label_encoder.classes_):
        if probabilities[idx] >= thresholds.get(name, 0.5):
            passing.append(idx)

    if passing:
        return max(passing, key=lambda idx: probabilities[idx])
    return fallback
# MongoDB connection (optional for now)
# try:
# MONGO_URI = os.getenv('MONGO_URI', 'mongodb://localhost:27017/')
# client = MongoClient(MONGO_URI)
# db = client['readability_db']
# # Test connection
# client.admin.command('ping')
# print("MongoDB connected")
# except:
# print("MongoDB not available - continuing without database")
# db = None
@app.route('/health', methods=['GET'])
@app.route('/api/health', methods=['GET'])
def health():
    """Report whether the model is loaded, its classifier type and its classes.

    Always answers 200; ``status`` is 'healthy' when the model is available
    and 'degraded' otherwise.
    """
    if model is None:
        status, model_name, classes = 'degraded', None, []
    else:
        status = 'healthy'
        model_name = type(model.named_steps['classifier']).__name__
        classes = label_encoder.classes_.tolist()

    return jsonify({
        'status': status,
        'model': model_name,
        'classes': classes
    }), 200
@app.route('/api/predict', methods=['POST'])
def predict():
    """Predict the readability class of a single text.

    Expects a JSON object body with a string ``text`` field that is at
    least 10 characters long after surrounding whitespace is stripped.

    Returns:
        200 with ``success``, a preview of the text (first 200 characters),
        the predicted class, its grade level, per-class confidences and the
        extracted features.
        400 when the body is not a JSON object or ``text`` is missing,
        not a string, or too short.
        503 when the model is not loaded.
        500 when feature extraction or inference fails.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded. Check server logs.'}), 503

    # silent=True returns None for a missing, non-JSON or malformed body, so
    # client mistakes are answered with 400 instead of surfacing as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    text = data.get('text')
    if text is None:
        text = ''
    if not isinstance(text, str):
        return jsonify({'error': 'Text must be a string'}), 400
    text = text.strip()
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    if len(text) < 10:
        return jsonify({'error': 'Text must be at least 10 characters'}), 400

    try:
        # Extract features
        features = extract_features_web(text)
        if not features:
            logger.error("Feature extraction returned no features")
            return jsonify({'error': 'Feature extraction failed. Check server logs.'}), 500

        # Convert to DataFrame using training-time feature order.
        features_df = build_features_df(features)

        # Get prediction
        probabilities = model.predict_proba(features_df)[0]

        # Apply threshold tuning with deterministic tie handling.
        final_prediction = pick_class_with_thresholds(probabilities)
        class_names = label_encoder.classes_
        predicted_class = class_names[final_prediction]
        grade_level = grade_mapping.get(predicted_class, predicted_class)

        # NumPy scalar types are included because the default JSON encoder
        # may reject them (e.g. np.int64, np.float32).
        numeric_types = (int, float, np.integer, np.floating)

        # Prepare response
        response = {
            'success': True,
            'text': text[:200] + '...' if len(text) > 200 else text,
            'prediction': {
                'predicted_class': predicted_class,
                'grade_level': grade_level,
                'confidences': {
                    class_name: float(probabilities[i])
                    for i, class_name in enumerate(class_names)
                }
            },
            'features': {
                k: float(v) if isinstance(v, numeric_types) else v
                for k, v in features.items()
            }
        }

        # # Store in MongoDB if available
        # if db:
        #     db.texts.insert_one({
        #         'text': text,
        #         'prediction': response['prediction'],
        #         'features': features,
        #         'timestamp': datetime.utcnow()
        #     })

        return jsonify(response), 200
    except Exception as e:
        # Request boundary: log the traceback so "check server logs" is
        # actionable, and keep the existing error payload for clients.
        logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500
@app.route('/api/predict/batch', methods=['POST'])
def batch_predict():
    """Predict the readability class of every text in a batch.

    Expects a JSON object body with a non-empty ``texts`` list of strings.

    Returns:
        200 with ``success``, ``count`` and one result per input text
        (a preview of the first 100 characters, the predicted class, its
        grade and per-class confidences).
        400 when the body is not a JSON object or ``texts`` is missing,
        empty, or not a list of strings.
        503 when the model is not loaded.
        500 when feature extraction or inference fails for any text; the
        feature-extraction error includes the index of the failing text.
    """
    # Same guard as the single-text endpoint: without it a failed model load
    # would surface as an opaque 500.
    if model is None:
        return jsonify({'error': 'Model not loaded. Check server logs.'}), 503

    # silent=True returns None for a missing, non-JSON or malformed body, so
    # client mistakes are answered with 400 instead of surfacing as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    texts = data.get('texts', [])
    if not texts:
        return jsonify({'error': 'No texts provided'}), 400
    # A bare string would otherwise be iterated character by character.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return jsonify({'error': 'Texts must be a list of strings'}), 400

    try:
        class_names = label_encoder.classes_
        results = []
        for index, text in enumerate(texts):
            features = extract_features_web(text)
            if not features:
                logger.error("Feature extraction returned no features for text %d", index)
                return jsonify({
                    'error': 'Feature extraction failed. Check server logs.',
                    'index': index
                }), 500

            features_df = build_features_df(features)
            probabilities = model.predict_proba(features_df)[0]
            prediction = pick_class_with_thresholds(probabilities)
            predicted_class = class_names[prediction]

            results.append({
                'text': text[:100] + '...' if len(text) > 100 else text,
                'prediction': {
                    'class': predicted_class,
                    'grade': grade_mapping.get(predicted_class, predicted_class),
                    'confidences': {
                        class_name: float(probabilities[i])
                        for i, class_name in enumerate(class_names)
                    }
                }
            })

        return jsonify({
            'success': True,
            'count': len(results),
            'results': results
        }), 200
    except Exception as e:
        # Request boundary: log the traceback and keep the existing error
        # payload for clients.
        logger.exception("Batch prediction failed")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Hugging Face Spaces uses port 7860; the PORT environment variable
    # overrides it. Binding to 0.0.0.0 listens on all interfaces.
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 7860)))