File size: 2,383 Bytes
ad19081 5266460 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | # !usr/bin/env/python3
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any
from flask import Flask, jsonify, render_template, request
# Directory containing this file; all other paths are resolved relative to it.
BASE_DIR = Path(__file__).resolve().parent
# Project-local modules imported below (``inference``, ``helpers``) are
# expected to live here.
SRC_DIR = BASE_DIR / "src"
# Directory of saved artifacts read at runtime.
SAVED_DIR = BASE_DIR / "saved"

# Put ``src`` at the front of the import path (once) so the local imports
# below resolve regardless of the working directory.
_src_path = str(SRC_DIR)
if _src_path not in sys.path:
    sys.path.insert(0, _src_path)

from inference import Inference  # noqa: E402  (requires the sys.path entry above)
from helpers import load_json  # noqa: E402

app = Flask(__name__)
def _compute_model_metrics(metrics_payload: dict[str, Any]) -> dict[str, float]:
test_metrics = metrics_payload.get("test") or {}
tn = float(test_metrics.get("tn", 0.0))
fp = float(test_metrics.get("fp", 0.0))
tp = float(test_metrics.get("tp", 0.0))
fn = float(test_metrics.get("fn", 0.0))
specificity = tn / (tn + fp) if (tn + fp) else 0.0
sensitivity = tp / (tp + fn) if (tp + fn) else float(test_metrics.get("recall", 0.0))
youden_j = sensitivity + specificity - 1.0
return {
"f1": float(test_metrics.get("f1", 0.0)),
"youden_j": round(youden_j, 5),
"auc_roc": float(test_metrics.get("roc_auc", 0.0)),
}
@lru_cache(maxsize=1)
def get_metrics() -> dict[str, float]:
    """Return the summary metrics for the saved model, loading them once.

    Reads ``saved/model/metrics.json`` and reduces it with
    ``_compute_model_metrics``. Only successful results are memoised by
    ``lru_cache``, and every later call returns the same dict object, so
    callers must not mutate it.
    """
    metrics_file = SAVED_DIR / "model" / "metrics.json"
    raw_metrics = load_json(metrics_file)
    return _compute_model_metrics(raw_metrics)
@lru_cache(maxsize=1)
def get_service() -> Inference:
    """Return the shared ``Inference`` service, constructing it lazily.

    The instance is built on the first call (rooted at ``BASE_DIR``) and
    returned unchanged by every later call.
    """
    # NOTE(review): lru_cache does not serialise the first call; concurrent
    # first requests may each construct an Inference before one is cached.
    service = Inference(project_root=BASE_DIR)
    return service
def predict(text1: str, text2: str) -> dict[str, Any]:
    """Run the shared inference service on a pair of texts.

    Returns:
        The service's prediction converted with its ``to_dict`` method.
    """
    prediction = get_service().predict(text1, text2)
    return prediction.to_dict()
@app.route("/", methods=["GET"])
def home_route():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
def _clean_text_field(value: Any) -> str:
    """Return *value* stripped of surrounding whitespace, or ``""`` if it is not a string."""
    return value.strip() if isinstance(value, str) else ""


@app.route("/predict", methods=["POST"])
def predict_route():
    """Predict on a pair of texts posted as JSON.

    Expects a JSON object body with string fields ``text1`` and ``text2``.
    The Content-Type header is not checked (``force=True``).

    Responds with:
        - 200 and the prediction dict on success;
        - 400 if the body is not a JSON object, or if either field is
          missing, blank or not a string;
        - 500 if inference raises (the traceback is logged).
    """
    # silent=True yields None for malformed JSON instead of raising, so every
    # client error below is reported as JSON rather than an HTML error page.
    data = request.get_json(force=True, silent=True)
    # Valid JSON that is not an object (null, list, number) has no .get().
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Non-string values are treated as missing rather than crashing on .strip().
    text1 = _clean_text_field(data.get("text1"))
    text2 = _clean_text_field(data.get("text2"))
    if not text1 or not text2:
        return jsonify({"error": "Both text fields are required."}), 400

    try:
        result = predict(text1, text2)
    except Exception as exc:
        # Top-level request boundary: keep the traceback in the server log.
        app.logger.exception("Inference failed")
        return jsonify({"error": f"Inference failed: {exc}"}), 500
    return jsonify(result)
@app.route("/metrics", methods=["GET"])
def metrics_route():
    """Return the model's summary metrics as JSON.

    Responds 500 with an ``error`` message if the metrics cannot be loaded
    or computed; the traceback is written to the application log.
    """
    try:
        metrics = get_metrics()
    except Exception as exc:
        # Top-level request boundary: log the traceback instead of discarding
        # it, then report a JSON error to the client.
        app.logger.exception("Failed to load metrics")
        return jsonify({"error": f"Failed to load metrics: {exc}"}), 500
    return jsonify(metrics)
# Lightweight endpoint pinged by a cron job.
@app.route("/ping")
def ping():
    """Report that the server is up."""
    status_payload = {"status": "ok"}
    return status_payload, 200
if __name__ == "__main__":
    # Run the built-in server on port 7860, listening on every network interface.
    app.run(port=7860, host="0.0.0.0")
|