File size: 7,147 Bytes
b052258
 
 
 
 
 
 
 
 
 
 
 
 
69fcbd3
 
 
b052258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from flask import Flask, request, jsonify
from flask_cors import CORS
import joblib
import pandas as pd
import numpy as np
from feature_extractor_web import extract_features_web
import logging
import os

# Model artifacts live in a "models" directory next to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")

app = Flask(__name__)

# CORS allow-list: a comma-separated list taken from the ALLOWED_ORIGINS
# environment variable, falling back to the defaults below. Each entry is
# whitespace-trimmed, blank entries are dropped, and trailing slashes removed.
default_origins = "https://alalayfe.vercel.app,https://www.alalayfe.vercel.app,http://localhost:3000"
ALLOWED_ORIGINS = [
    entry.rstrip("/")
    for entry in map(str.strip, os.getenv('ALLOWED_ORIGINS', default_origins).split(','))
    if entry
]
CORS(app, origins=ALLOWED_ORIGINS)
logger = logging.getLogger(__name__)

@app.route('/')
def home():
    """Return a small JSON description of the service and its endpoints."""
    service_info = {
        'service': 'Alalay Readability API',
        'status': 'running',
        'endpoints': ['/health', '/api/predict', '/api/predict/batch'],
    }
    return jsonify(service_info), 200

# Load all the saved components.
#
# Every component is bound to None up front so the names always exist at
# module level. If loading fails part-way through, all of them are reset to
# None together; request handlers can then rely on `model is None` meaning
# "nothing usable was loaded" instead of hitting a NameError on a component
# that was never assigned.
model = None
label_encoder = None
grade_mapping = None
thresholds = None
feature_info = None

print("Loading model components...")
try:
    # NOTE(review): joblib.load unpickles arbitrary Python objects; these
    # files are read from the local models directory and must be trusted.
    model = joblib.load(os.path.join(MODELS_DIR, "readability_model.pkl"))
    label_encoder = joblib.load(os.path.join(MODELS_DIR, "label_encoder.pkl"))
    grade_mapping = joblib.load(os.path.join(MODELS_DIR, "grade_mapping.pkl"))
    thresholds = joblib.load(os.path.join(MODELS_DIR, "thresholds.pkl"))
    feature_info = joblib.load(os.path.join(MODELS_DIR, "feature_info.pkl"))
    print("All components loaded successfully!")
    print(f"   Model type: {type(model.named_steps['classifier']).__name__}")
    print(f"   Classes: {label_encoder.classes_}")
except Exception as e:
    print(f"Error loading models: {e}")
    # Keep the full traceback in the server logs for diagnosis.
    logger.exception("Failed to load model components from %s", MODELS_DIR)
    # Discard any partially loaded state so the service is consistently
    # "not loaded" rather than half-initialised.
    model = label_encoder = grade_mapping = thresholds = feature_info = None


def build_features_df(features: dict) -> pd.DataFrame:
    """Return a single-row DataFrame laid out in the column order the model expects.

    The column list comes from ``feature_info['all_features']`` and falls back
    to the keys of ``features`` when that entry is absent. Columns missing from
    ``features`` are filled with ``'Other'`` if they are listed in
    ``feature_info['categorical_cols']`` and with ``0.0`` otherwise. Keys of
    ``features`` that are not in the column list are ignored.
    """
    ordered_cols = feature_info.get('all_features', list(features.keys()))
    categorical = set(feature_info.get('categorical_cols', []))

    def _placeholder(col):
        # Default used when the extractor did not produce this column.
        return 'Other' if col in categorical else 0.0

    record = {
        col: features[col] if col in features else _placeholder(col)
        for col in ordered_cols
    }
    return pd.DataFrame([record], columns=ordered_cols)


def pick_class_with_thresholds(probabilities: np.ndarray) -> int:
    """Return the index of the class to predict for one probability vector.

    A class is eligible when its probability is at least its entry in
    ``thresholds`` (0.5 when the class has no entry). The most probable
    eligible class wins, with the earliest index winning ties. When no class
    is eligible the plain argmax of ``probabilities`` is returned.
    """
    fallback = int(np.argmax(probabilities))

    winner = None
    for idx, label in enumerate(label_encoder.classes_):
        score = probabilities[idx]
        if not score >= thresholds.get(label, 0.5):
            continue
        # Strict '>' keeps the first class seen when probabilities tie.
        if winner is None or score > probabilities[winner]:
            winner = idx

    return fallback if winner is None else winner

# MongoDB connection (optional for now)
# try:
#     MONGO_URI = os.getenv('MONGO_URI', 'mongodb://localhost:27017/')
#     client = MongoClient(MONGO_URI)
#     db = client['readability_db']
#     # Test connection
#     client.admin.command('ping')
#     print("MongoDB connected")
# except:
#     print("MongoDB not available - continuing without database")
#     db = None

@app.route('/health', methods=['GET'])
@app.route('/api/health', methods=['GET'])
def health():
    """Report whether the model is loaded, its classifier type and its classes.

    Always answers with HTTP 200; ``status`` is ``'degraded'`` when ``model``
    is None and ``'healthy'`` otherwise.
    """
    if model is None:
        status, model_name, classes = 'degraded', None, []
    else:
        status = 'healthy'
        model_name = type(model.named_steps['classifier']).__name__
        classes = label_encoder.classes_.tolist()

    return jsonify({
        'status': status,
        'model': model_name,
        'classes': classes
    }), 200

@app.route('/api/predict', methods=['POST'])
def predict():
    """Predict the readability class of a single text.

    Expects a JSON object with a ``text`` string of at least 10 characters
    (after trimming whitespace).

    Responses:
        200: ``success``, the (possibly truncated) text, the prediction
             (``predicted_class``, ``grade_level``, per-class ``confidences``)
             and the extracted ``features``.
        400: the body is not a JSON object, or ``text`` is missing, not a
             string, empty, or shorter than 10 characters.
        500: feature extraction returned nothing, or an unexpected error.
        503: the model is not loaded.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded. Check server logs.'}), 503
    try:
        # silent=True yields None for a missing/invalid JSON body instead of
        # raising, so malformed requests are answered with 400 rather than
        # being swallowed by the generic handler below and reported as 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        text = data.get('text')
        if text is None:
            text = ''
        if not isinstance(text, str):
            return jsonify({'error': 'Field "text" must be a string'}), 400
        text = text.strip()

        if not text:
            return jsonify({'error': 'No text provided'}), 400

        if len(text) < 10:
            return jsonify({'error': 'Text must be at least 10 characters'}), 400

        # Extract features
        features = extract_features_web(text)
        if not features:
            return jsonify({'error': 'Feature extraction failed. Check server logs.'}), 500

        # Convert to DataFrame using training-time feature order.
        features_df = build_features_df(features)

        # Get prediction
        probabilities = model.predict_proba(features_df)[0]

        # Apply threshold tuning with deterministic tie handling.
        final_prediction = pick_class_with_thresholds(probabilities)

        predicted_class = label_encoder.classes_[final_prediction]
        grade_level = grade_mapping.get(predicted_class, predicted_class)

        # numpy scalars are not JSON serializable; unwrap them to plain
        # Python values before building the response.
        if isinstance(predicted_class, np.generic):
            predicted_class = predicted_class.item()
        if isinstance(grade_level, np.generic):
            grade_level = grade_level.item()

        # Numeric feature values (including numpy scalars) are sent as floats.
        numeric_types = (int, float, np.integer, np.floating, np.bool_)

        # Prepare response
        response = {
            'success': True,
            'text': text[:200] + '...' if len(text) > 200 else text,
            'prediction': {
                'predicted_class': predicted_class,
                'grade_level': grade_level,
                'confidences': {
                    # str() keeps the keys JSON-safe if classes are numpy scalars.
                    str(class_name): float(probability)
                    for class_name, probability in zip(label_encoder.classes_, probabilities)
                }
            },
            'features': {k: float(v) if isinstance(v, numeric_types) else v
                        for k, v in features.items()}
        }

        # # Store in MongoDB if available
        # if db:
        #     db.texts.insert_one({
        #         'text': text,
        #         'prediction': response['prediction'],
        #         'features': features,
        #         'timestamp': datetime.utcnow()
        #     })

        return jsonify(response), 200

    except Exception as e:
        # Record the traceback server-side; the client only gets the message.
        logger.exception("Prediction request failed")
        return jsonify({'error': str(e)}), 500

@app.route('/api/predict/batch', methods=['POST'])
def batch_predict():
    """Predict the readability class of several texts in one request.

    Expects a JSON object with a ``texts`` list of non-blank strings.

    Responses:
        200: ``success``, ``count`` and one ``results`` entry per input text,
             each with the (possibly truncated) text and its prediction
             (``class``, ``grade``, per-class ``confidences``).
        400: the body is not a JSON object, or ``texts`` is missing, empty,
             not a list of strings, or contains a blank string.
        500: feature extraction returned nothing for a text, or an
             unexpected error.
        503: the model is not loaded.
    """
    # Same guard as the single-text endpoint: without a model nothing below
    # can work, and 503 describes the situation better than a generic 500.
    if model is None:
        return jsonify({'error': 'Model not loaded. Check server logs.'}), 503
    try:
        # silent=True yields None for a missing/invalid JSON body instead of
        # raising, so malformed requests are answered with 400 rather than 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        texts = data.get('texts', [])

        if not texts:
            return jsonify({'error': 'No texts provided'}), 400

        # A bare string would otherwise be iterated character by character.
        if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
            return jsonify({'error': 'Field "texts" must be a list of strings'}), 400

        # Class labels do not change between texts; str() keeps the
        # confidence keys JSON-safe if the labels are numpy scalars.
        class_names = [str(class_name) for class_name in label_encoder.classes_]

        results = []
        for index, text in enumerate(texts):
            if not text.strip():
                return jsonify({'error': f'Text at index {index} is empty'}), 400

            features = extract_features_web(text)
            if not features:
                return jsonify({
                    'error': f'Feature extraction failed for text at index {index}. Check server logs.'
                }), 500

            features_df = build_features_df(features)
            probabilities = model.predict_proba(features_df)[0]
            prediction = pick_class_with_thresholds(probabilities)

            predicted_class = label_encoder.classes_[prediction]
            grade = grade_mapping.get(predicted_class, predicted_class)

            # numpy scalars are not JSON serializable; unwrap them.
            if isinstance(predicted_class, np.generic):
                predicted_class = predicted_class.item()
            if isinstance(grade, np.generic):
                grade = grade.item()

            results.append({
                'text': text[:100] + '...' if len(text) > 100 else text,
                'prediction': {
                    'class': predicted_class,
                    'grade': grade,
                    'confidences': {
                        class_name: float(probability)
                        for class_name, probability in zip(class_names, probabilities)
                    }
                }
            })

        return jsonify({
            'success': True,
            'count': len(results),
            'results': results
        }), 200

    except Exception as e:
        # Record the traceback server-side; the client only gets the message.
        logger.exception("Batch prediction request failed")
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Hugging Face Spaces uses port 7860; the PORT environment variable
    # overrides it when set.
    listen_port = int(os.environ.get('PORT', 7860))
    # Bind to all interfaces.
    app.run(port=listen_port, host='0.0.0.0')