"""AIBalance API — frugal model recommendation by usage.

Lightweight version: loads pre-computed files instead of the 4 GB parquet.
Prerequisite: run export_for_api.py once locally so that the files under
``outputs_v3/`` exist.

Endpoints:
    POST /recommend  {prompt, mode, top_n}
    GET  /health
"""
import json
import os
import pickle
import re
import time

import joblib
import pandas as pd
from flask import Flask, jsonify, request
from flask_cors import CORS
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app)

# ── Config ──
EMBEDDING_MODEL = "Lajavaness/bilingual-embedding-small"
# Candidates whose quality is within DELTA of the best quality score are
# considered equivalent; the mode score then ranks them.
DELTA = 0.03
# Reference value of 'wh_per_1k_tok' used to report 'gain_vs_median' (in %).
MEDIAN_ENERGY = 0.66
MODEL_PATH = "outputs_v3/api_clf.joblib"

# (alpha, beta) = (quality weight, energy-sobriety weight) for each mode.
# Unknown modes fall back to the balanced weights.
_MODES = {'performance': (0.8, 0.2), 'green': (0.2, 0.8), 'balanced': (0.5, 0.5)}

# Code-related keywords (case-insensitive); compiled once instead of per call.
_CODE_RE = re.compile(
    r'\b(code|python|java|sql|html|css|script|function|class|def |import)\b',
    re.I,
)

# ── Globals (populated by boot()) ──
embedder = None
clf = None
reco_table = {}   # predicted category -> DataFrame of candidate models
profil_map = {}   # model name -> profile label shown as 'profil'

FAMILY_MAP = {
    'Natural Science & Formal Science & Technology': 'code',
    'Education': 'education',
    'Arts & Culture': 'creation',
    'Arts': 'creation',
    'Entertainment & Travel & Hobby': 'creation',
    'Food & Drink & Cooking': 'creation',
}


def get_task_family(cat):
    """Map a category label to a coarse task family.

    Exact matches in FAMILY_MAP win. Otherwise keyword substrings decide:
    'Business' / 'Law' / 'Politic' -> 'analysis', 'Education' -> 'education',
    anything else -> 'default'.
    """
    if cat in FAMILY_MAP:
        return FAMILY_MAP[cat]
    if 'Business' in cat or 'Law' in cat or 'Politic' in cat:
        return 'analysis'
    if 'Education' in cat:
        return 'education'
    return 'default'


def extract_complexity(prompt):
    """Return cheap complexity features of *prompt*.

    Keys:
        'n_words': number of whitespace-separated tokens.
        'is_code': 1 if a code-related keyword appears, else 0.
        'n_questions': number of '?' characters.
    """
    text = str(prompt)
    return {
        'n_words': len(text.split()),
        'is_code': int(bool(_CODE_RE.search(text))),
        'n_questions': text.count('?'),
    }


def boot():
    """Load the pre-computed files, the sentence embedder and the classifier.

    Populates the module-level ``embedder``, ``clf``, ``reco_table`` and
    ``profil_map`` globals. The loaded classifier is also returned, for
    backward compatibility with ``clf = boot()``.
    """
    global embedder, clf
    print("Boot API (mode leger)...")
    t0 = time.time()

    # Embedder (CPU only).
    device = 'cpu'
    print(f" Device : {device}")
    # NOTE(review): trust_remote_code=True runs code shipped with the model
    # repository; keep EMBEDDING_MODEL pinned to a trusted source.
    embedder = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True, device=device)

    # Classifier. Assign the global here too, so boot() alone is enough.
    clf_obj = joblib.load(MODEL_PATH)
    clf = clf_obj
    print(f" Classifieur charge depuis {MODEL_PATH}")

    # Pre-computed recommendation table.
    # pickle.load can execute arbitrary code: only load the locally
    # generated export, never a user-supplied file.
    with open('outputs_v3/api_reco_table.pkl', 'rb') as f:
        reco_table.update(pickle.load(f))
    print(f" Reco table : {len(reco_table)} categories")

    # Model -> profile map. Explicit UTF-8 avoids platform-dependent decoding.
    with open('outputs_v3/api_profil_map.json', 'r', encoding='utf-8') as f:
        profil_map.update(json.load(f))

    print(f"Boot OK ({time.time() - t0:.1f}s)")
    return clf_obj


def classify_prompt(prompt):
    """Predict the category of *prompt* with the loaded embedder + classifier.

    Raises:
        RuntimeError: if boot() has not loaded the models yet.
    """
    if embedder is None or clf is None:
        raise RuntimeError("API not booted: call boot() first")
    emb = embedder.encode([prompt], convert_to_numpy=True, normalize_embeddings=True)
    return clf.predict(emb)[0]


def _select_candidates(category, family):
    """Return ``(candidates, quality_column)`` for a predicted category.

    Uses a copy of the category's own table when present. Otherwise it uses
    the first table whose category maps to the same task family. If neither
    exists, it returns an empty DataFrame.
    """
    if category in reco_table:
        return reco_table[category].copy(), 'combined_quality'
    for orig in reco_table:
        if get_task_family(orig) == family:
            return reco_table[orig].copy(), 'combined_quality'
    return pd.DataFrame(), 'win_rate'


def _minmax(values):
    """Min-max scale *values* to [0, 1]; return the scalar 0.5 if constant."""
    lo, hi = values.min(), values.max()
    if hi > lo:
        return (values - lo) / (hi - lo)
    return 0.5


def _format_recommendation(row, qual_col):
    """Convert one candidate row into the JSON-ready recommendation dict."""
    wr = row.get('win_rate_cat', row.get('win_rate', 0))
    cq = row.get(qual_col, wr)
    # Cast to a builtin float so numpy scalar types never reach jsonify.
    energy = float(row['wh_per_1k_tok'])
    # Percentage saved relative to MEDIAN_ENERGY; 0 when not below it.
    gain_vs_median = (1 - energy / MEDIAN_ENERGY) * 100 if MEDIAN_ENERGY > energy else 0
    return {
        'model': row['model'],
        'energy_wh': round(energy, 3),
        'win_rate': round(float(wr) * 100, 1),
        'quality': round(float(cq), 3),
        'profil': profil_map.get(row['model'], '?'),
        'gain_vs_median': round(gain_vs_median, 0),
    }


def recommander(prompt, top_n=3, mode='balanced'):
    """Recommend up to *top_n* models for *prompt*.

    Pipeline:
        1. Classify the prompt and select the candidate table for its
           category, or a same-family fallback.
        2. Score each candidate as ``alpha * quality + beta * sobriety``.
           Both terms are min-max normalised, and (alpha, beta) come from
           *mode*.
        3. Keep candidates within DELTA of the best quality.
        4. For complex prompts, keep only those at or above the median win
           rate, if any remain.
        5. Return the *top_n* highest mode scores.

    Returns:
        A dict describing the classification and a 'recommendations' list.
        If no candidate table exists, returns ``{'error', 'category'}``.
    """
    alpha, beta = _MODES.get(mode, (0.5, 0.5))
    category = classify_prompt(prompt)
    family = get_task_family(category)
    cx = extract_complexity(prompt)

    candidates, qual_col = _select_candidates(category, family)
    if len(candidates) == 0:
        return {'error': 'Pas de candidats pour cette categorie', 'category': category}

    # Mode score: higher quality and lower energy are both better.
    q_norm = _minmax(candidates[qual_col])
    s_norm = 1 - _minmax(candidates['wh_per_1k_tok'])
    candidates['mode_score'] = alpha * q_norm + beta * s_norm

    # DELTA filter: keep models whose quality is close to the best one.
    best_q = candidates[qual_col].max()
    equivalent = candidates[candidates[qual_col] >= best_q - DELTA]

    # Complexity filter. Long, code-related or multi-question prompts only
    # keep models with an above-median win rate (if any survive).
    is_complex = bool(cx['n_words'] > 100 or cx['is_code'] or cx['n_questions'] > 2)
    if is_complex:
        wr_col = 'win_rate_cat' if 'win_rate_cat' in equivalent.columns else 'win_rate'
        median_wr = candidates[wr_col].median()
        complex_ok = equivalent[equivalent[wr_col] >= median_wr]
        if len(complex_ok) > 0:
            equivalent = complex_ok

    top = equivalent.nlargest(top_n, 'mode_score')
    results = [_format_recommendation(row, qual_col) for _, row in top.iterrows()]

    return {
        'category': category,
        'family': family,
        'complexity': 'complexe' if is_complex else 'simple',
        'n_words': cx['n_words'],
        'mode': mode,
        'alpha': alpha,
        'beta': beta,
        'n_candidates': len(candidates),
        'n_equivalent': len(equivalent),
        'delta': DELTA,
        'recommendations': results,
    }


@app.route('/recommend', methods=['POST'])
def api_recommend():
    """POST /recommend: validate the JSON body and return recommendations.

    Body: {"prompt": str, "mode": str (default 'balanced'),
    "top_n": positive int (default 3)}. Malformed input returns HTTP 400
    instead of an unhandled 500.
    """
    # silent=True returns None for a missing or malformed JSON body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Corps JSON invalide'}), 400

    prompt = data.get('prompt', '')
    mode = data.get('mode', 'balanced')
    top_n = data.get('top_n', 3)

    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({'error': 'Prompt vide'}), 400
    if not isinstance(mode, str):
        return jsonify({'error': 'mode invalide'}), 400
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(top_n, bool) or not isinstance(top_n, int) or top_n < 1:
        return jsonify({'error': 'top_n invalide'}), 400

    result = recommander(prompt, top_n=top_n, mode=mode)
    return jsonify(result)


@app.route('/health', methods=['GET'])
def health():
    """GET /health: report the loaded category count and total model rows."""
    return jsonify({
        'status': 'ok',
        'categories': len(reco_table),
        'models': sum(len(t) for t in reco_table.values()),
    })


if __name__ == '__main__':
    clf = boot()
    port = int(os.environ.get('PORT', 7860))
    print(f"\nAPI prete : http://localhost:{port}")
    print(" POST /recommend {prompt, mode, top_n}")
    print(" GET /health")
    app.run(host='0.0.0.0', port=port, debug=False)