Spaces:
Running
Running
| """ | |
| AIFinder Flask API | |
| Serves the trained sklearn ensemble via the AIFinder inference class. | |
| """ | |
| import os | |
| import re | |
| import joblib | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestClassifier | |
| from flask import Flask, jsonify, request, send_from_directory, render_template | |
| from flask_cors import CORS | |
| from flask_limiter import Limiter | |
| from flask_limiter.util import get_remote_address | |
| from config import MODEL_DIR | |
| from inference import AIFinder | |
# --- Flask application and shared, process-wide state -----------------------
app = Flask(__name__)
CORS(app)  # cross-origin requests are allowed on every route
# Rate limiter keyed on the client's remote address; no default limits are set here.
limiter = Limiter(key_func=get_remote_address, app=app)

# Models: the base ensemble, plus an optional model retrained from user corrections.
finder: AIFinder | None = None
community_finder: AIFinder | None = None
# When True (and community_finder is loaded) requests are served by the community model.
using_community = False

# Ranked providers returned by v1_classify when the client does not pass top_n.
DEFAULT_TOP_N = 4

# Community model artefacts and the persisted list of user corrections.
COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
corrections: list[dict] = []
def load_models():
    """Load the base model, stored corrections and (optionally) the community model.

    Populates the module globals ``finder``, ``corrections`` and
    ``community_finder``. COMMUNITY_DIR is created if it does not exist so that
    later corrections can be written into it.

    A community model is only attempted when ``rf_4provider.joblib`` exists in
    COMMUNITY_DIR. Loading it is best-effort: on failure the app keeps serving
    the base model, but the exception is logged instead of being discarded.
    """
    global finder, community_finder, corrections
    finder = AIFinder(model_dir=MODEL_DIR)
    os.makedirs(COMMUNITY_DIR, exist_ok=True)
    if os.path.exists(CORRECTIONS_FILE):
        # NOTE(review): joblib.load unpickles. This is acceptable only because the
        # file is written by this app itself; never point it at untrusted data.
        corrections = joblib.load(CORRECTIONS_FILE)
    if os.path.exists(os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")):
        try:
            community_finder = AIFinder(model_dir=COMMUNITY_DIR)
        except Exception:
            # Deliberate fallback to the base model, but leave a trace for debugging.
            app.logger.exception(
                "Failed to load community model from %s", COMMUNITY_DIR
            )
            community_finder = None
def _active_finder():
    """Return the model that should serve the current request.

    The community model wins only when it is both switched on and loaded.
    Otherwise the base ``finder`` is returned, which is still ``None`` if
    ``load_models`` has not run.
    """
    if using_community and community_finder:
        return community_finder
    return finder
def _strip_think_tags(text):
    """Remove ``<think>…</think>`` / ``<thinking>…</thinking>`` spans, then trim.

    The match is non-greedy and crosses newlines, so each opening tag pairs with
    the nearest following closing tag (either spelling). An unclosed tag is left
    untouched. Leading and trailing whitespace is stripped from the result.
    """
    reasoning_span = re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL)
    visible = reasoning_span.sub("", text)
    return visible.strip()
def index():
    """Render the front-end page from the app's ``index.html`` template."""
    page = "index.html"
    return render_template(page)
def v1_classify():
    """Classify which AI provider most likely produced the submitted text.

    Expects a JSON object body ``{"text": <str>, "top_n": <optional int>}``.
    ``<think>``/``<thinking>`` spans are removed before classification.
    ``top_n`` falls back to DEFAULT_TOP_N when it is missing, not an int, a
    bool, or below 1. It is then clamped to the number of known providers.

    Returns:
        200 with ``provider``, ``confidence`` (percent, 2 decimals) and
        ``top_providers``; 400 for a malformed body or text shorter than 20
        characters after stripping; 503 when no model is loaded (load_models()
        is only invoked from the ``__main__`` block).
    """
    data = request.get_json(silent=True)
    # A JSON array or string would pass a bare ``"text" in data`` check and then
    # fail on indexing, so require a JSON object explicitly.
    if not isinstance(data, dict) or "text" not in data:
        return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400
    raw_text = data["text"]
    if not isinstance(raw_text, str):
        return jsonify({"error": "'text' must be a string."}), 400

    af = _active_finder()
    if af is None:
        return jsonify({"error": "Model not loaded."}), 503

    # Validate before clamping: min() on a non-numeric value would raise.
    # bool is an int subclass but is not a meaningful count.
    requested = data.get("top_n", DEFAULT_TOP_N)
    if isinstance(requested, bool) or not isinstance(requested, int) or requested < 1:
        requested = DEFAULT_TOP_N
    top_n = min(requested, len(af.le.classes_))

    text = _strip_think_tags(raw_text)
    if len(text) < 20:
        return jsonify(
            {
                "error": "Text too short (minimum 20 characters after stripping think tags)."
            }
        ), 400

    proba = af.predict_proba(text)
    sorted_providers = sorted(proba.items(), key=lambda x: x[1], reverse=True)[:top_n]
    top_providers = [
        {"name": name, "confidence": round(float(conf * 100), 2)}
        for name, conf in sorted_providers
    ]
    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
def correct():
    """Record a user correction and retrain the community model on all corrections.

    Expects a JSON object body ``{"text": <str>, "correct_provider": <str>}``,
    where ``correct_provider`` must be one of the base model's known classes.

    Each call fits a fresh RandomForest on every stored correction plus the new
    one. Features come from the base model's pipeline and labels from its label
    encoder. The call then writes the model artefacts and the corrections list
    into COMMUNITY_DIR and reloads ``community_finder`` from there. The new
    correction is only added to the in-memory list after training and
    persistence succeed, so a failed request leaves no unsaved entry behind.

    Returns:
        200 ``{"status": "ok", "loss": 0.0, "corrections": <count>}``, where
        ``loss`` is a constant placeholder; 400 for an invalid body, unknown
        provider or too-short text; 503 when the base model is not loaded.
    """
    global community_finder
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "text" not in data or "correct_provider" not in data:
        return jsonify({"error": "Need 'text' and 'correct_provider'."}), 400
    if finder is None:
        return jsonify({"error": "Model not loaded."}), 503
    provider = data["correct_provider"]
    if provider not in list(finder.le.classes_):
        return jsonify({"error": f"Unknown provider: {provider}"}), 400
    if not isinstance(data["text"], str):
        return jsonify({"error": "'text' must be a string."}), 400
    text = _strip_think_tags(data["text"])
    # Same minimum as v1_classify, so only classifiable texts enter training.
    if len(text) < 20:
        return jsonify(
            {
                "error": "Text too short (minimum 20 characters after stripping think tags)."
            }
        ), 400

    # NOTE(review): the module state and files below are shared across requests;
    # concurrent corrections on a threaded server can interleave.
    updated = [*corrections, {"text": text, "provider": provider}]
    texts = [c["text"] for c in updated]
    labels = [c["provider"] for c in updated]
    X = finder.pipeline.transform(texts)
    y = finder.le.transform(labels)
    # NOTE(review): the forest only sees labels present in the corrections, so
    # rf.classes_ may be a subset of finder.le.classes_. Confirm that AIFinder
    # maps predict_proba columns through rf.classes_ rather than by position.
    rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    rf.fit(X, y)
    joblib.dump([rf], os.path.join(COMMUNITY_DIR, "rf_4provider.joblib"))
    joblib.dump(finder.pipeline, os.path.join(COMMUNITY_DIR, "pipeline_4provider.joblib"))
    joblib.dump(finder.le, os.path.join(COMMUNITY_DIR, "enc_4provider.joblib"))
    joblib.dump(updated, CORRECTIONS_FILE)
    # Commit to the shared list only once everything above has succeeded.
    corrections.append(updated[-1])
    community_finder = AIFinder(model_dir=COMMUNITY_DIR)
    return jsonify({"status": "ok", "loss": 0.0, "corrections": len(corrections)})
def save_model():
    """Report the download filename of the community model.

    Nothing is written here, because ``correct`` already persists the community
    artefacts. The returned name carries the ``community_`` prefix that
    ``download_model`` maps onto COMMUNITY_DIR.
    """
    if community_finder is not None:
        return jsonify({"status": "ok", "filename": "community_rf_4provider.joblib"})
    return jsonify({"error": "No community model trained yet."}), 400
def toggle_community():
    """Switch serving between the base model and the community model.

    Optional JSON object body: ``{"enabled": <value>}``. Without ``enabled`` the
    current setting is flipped. The value is coerced with ``bool()``, so any
    truthy JSON value enables it (including the string ``"false"``).

    Enabling is accepted even when no community model is loaded; in that case
    ``_active_finder`` keeps serving the base model. The ``available`` field
    in the response tells the caller whether a community model exists.
    """
    global using_community
    data = request.get_json(silent=True)
    # Treat a missing or non-object body (e.g. a JSON array) as "no options"
    # instead of crashing on .get().
    if not isinstance(data, dict):
        data = {}
    using_community = bool(data.get("enabled", not using_community))
    return jsonify({"using_community": using_community, "available": community_finder is not None})
def download_model(filename):
    """Serve a model artefact file.

    Names prefixed with ``community_`` are served from COMMUNITY_DIR with the
    prefix removed (``community_rf_4provider.joblib`` is served as
    COMMUNITY_DIR/rf_4provider.joblib). Any other name is served from
    MODEL_DIR.

    NOTE(review): ``community_corrections.joblib`` also resolves here and
    exposes the raw user-submitted correction texts. Confirm this is intended.
    """
    prefix = "community_"
    if not filename.startswith(prefix):
        return send_from_directory(MODEL_DIR, filename)
    return send_from_directory(COMMUNITY_DIR, filename[len(prefix):])
def status():
    """Summarise which model is serving and the community-model state."""
    af = _active_finder()
    known = list(af.le.classes_) if af else []
    payload = {
        "loaded": af is not None,
        "device": "cpu",
        "providers": known,
        "num_providers": len(known),
        "using_community": using_community,
        "community_available": community_finder is not None,
        "corrections_count": len(corrections),
    }
    return jsonify(payload)
def providers():
    """List the provider classes of the base model (empty before loading)."""
    names = []
    if finder:
        names = list(finder.le.classes_)
    return jsonify({"providers": names})
if __name__ == "__main__":
    # Models are only loaded when this file is run directly.
    print("Loading models...")
    load_models()
    names = finder.le.classes_
    print(f"Ready on cpu — {len(names)} providers: {', '.join(names)}")
    # Listen on all interfaces. Port 7860 is presumably the Hugging Face Spaces
    # default (the page residue mentions "Spaces"), so confirm before changing it.
    app.run(host="0.0.0.0", port=7860)