""" AIFinder Flask API Serves the trained sklearn ensemble via the AIFinder inference class. """ import os import re import joblib import numpy as np from sklearn.ensemble import RandomForestClassifier from flask import Flask, jsonify, request, send_from_directory, render_template from flask_cors import CORS from flask_limiter import Limiter from flask_limiter.util import get_remote_address from config import MODEL_DIR from inference import AIFinder app = Flask(__name__) CORS(app) limiter = Limiter(get_remote_address, app=app) finder: AIFinder | None = None community_finder: AIFinder | None = None using_community = False DEFAULT_TOP_N = 4 COMMUNITY_DIR = os.path.join(MODEL_DIR, "community") CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib") corrections: list[dict] = [] def load_models(): global finder, community_finder, corrections finder = AIFinder(model_dir=MODEL_DIR) os.makedirs(COMMUNITY_DIR, exist_ok=True) if os.path.exists(CORRECTIONS_FILE): corrections = joblib.load(CORRECTIONS_FILE) if os.path.exists(os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")): try: community_finder = AIFinder(model_dir=COMMUNITY_DIR) except Exception: community_finder = None def _active_finder(): return community_finder if using_community and community_finder else finder def _strip_think_tags(text): text = re.sub(r".*?", "", text, flags=re.DOTALL) return text.strip() @app.route("/") def index(): return render_template("index.html") @app.route("/api/classify", methods=["POST"]) @app.route("/v1/classify", methods=["POST"]) @limiter.limit("60/minute") def v1_classify(): data = request.get_json(silent=True) if not data or "text" not in data: return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400 raw_text = data["text"] text = _strip_think_tags(raw_text) af = _active_finder() top_n = min(data.get("top_n", DEFAULT_TOP_N), len(af.le.classes_)) if not isinstance(top_n, int) or top_n < 1: top_n = DEFAULT_TOP_N if len(text) < 20: return jsonify( { "error": "Text too short (minimum 20 characters after stripping think tags)." } ), 400 proba = af.predict_proba(text) sorted_providers = sorted(proba.items(), key=lambda x: x[1], reverse=True)[:top_n] top_providers = [ {"name": name, "confidence": round(float(conf * 100), 2)} for name, conf in sorted_providers ] return jsonify( { "provider": top_providers[0]["name"], "confidence": top_providers[0]["confidence"], "top_providers": top_providers, } ) @app.route("/api/correct", methods=["POST"]) def correct(): global community_finder data = request.get_json(silent=True) if not data or "text" not in data or "correct_provider" not in data: return jsonify({"error": "Need 'text' and 'correct_provider'."}), 400 provider = data["correct_provider"] if provider not in list(finder.le.classes_): return jsonify({"error": f"Unknown provider: {provider}"}), 400 text = _strip_think_tags(data["text"]) corrections.append({"text": text, "provider": provider}) texts = [c["text"] for c in corrections] providers = [c["provider"] for c in corrections] X = finder.pipeline.transform(texts) y = finder.le.transform(providers) rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1) rf.fit(X, y) joblib.dump([rf], os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")) joblib.dump(finder.pipeline, os.path.join(COMMUNITY_DIR, "pipeline_4provider.joblib")) joblib.dump(finder.le, os.path.join(COMMUNITY_DIR, "enc_4provider.joblib")) joblib.dump(corrections, CORRECTIONS_FILE) community_finder = AIFinder(model_dir=COMMUNITY_DIR) return jsonify({"status": "ok", "loss": 0.0, "corrections": len(corrections)}) @app.route("/api/save", methods=["POST"]) def save_model(): if community_finder is None: return jsonify({"error": "No community model trained yet."}), 400 filename = "community_rf_4provider.joblib" return jsonify({"status": "ok", "filename": filename}) @app.route("/api/toggle_community", methods=["POST"]) def toggle_community(): global using_community data = request.get_json(silent=True) or {} using_community = bool(data.get("enabled", not using_community)) return jsonify({"using_community": using_community, "available": community_finder is not None}) @app.route("/models/") def download_model(filename): if filename.startswith("community_"): return send_from_directory(COMMUNITY_DIR, filename.replace("community_", "", 1)) return send_from_directory(MODEL_DIR, filename) @app.route("/api/status", methods=["GET"]) def status(): af = _active_finder() return jsonify( { "loaded": af is not None, "device": "cpu", "providers": list(af.le.classes_) if af else [], "num_providers": len(af.le.classes_) if af else 0, "using_community": using_community, "community_available": community_finder is not None, "corrections_count": len(corrections), } ) @app.route("/api/providers", methods=["GET"]) def providers(): return jsonify( { "providers": list(finder.le.classes_) if finder else [], } ) if __name__ == "__main__": print("Loading models...") load_models() print( f"Ready on cpu — {len(finder.le.classes_)} providers: " f"{', '.join(finder.le.classes_)}" ) app.run(host="0.0.0.0", port=7860)