| """ |
| AIBalance API — recommandation sobre par usage |
| Version legere : charge des fichiers pre-calcules (pas de parquet 4 Go). |
| Pre-requis : lancer export_for_api.py une fois en local. |
| Endpoint : POST /recommend {prompt, mode, top_n} |
| """ |
|
|
| from flask import Flask, request, jsonify |
| from flask_cors import CORS |
| from sentence_transformers import SentenceTransformer |
| import pandas as pd, re, time, os, pickle, json, joblib |
|
|
# Flask application; CORS is enabled on it so browsers on other origins can call the API.
app = Flask(import_name=__name__)
CORS(app=app)
|
|
| |
# --- Configuration ---------------------------------------------------------
# Sentence-embedding model used to vectorise prompts before classification.
EMBEDDING_MODEL = "Lajavaness/bilingual-embedding-small"
# Quality tolerance: recommander() treats candidates whose quality is within
# DELTA of the best one as equivalent.
DELTA = 0.03
# Reference value compared with each candidate's 'wh_per_1k_tok' to report
# the percentage of energy saved ('gain_vs_median').
MEDIAN_ENERGY = 0.66
# joblib-serialised prompt classifier produced by export_for_api.py.
MODEL_PATH = "outputs_v3/api_clf.joblib"

# --- Runtime state (filled at start-up, see boot() and __main__) ------------
embedder = clf = None
reco_table, profil_map = {}, {}

# Exact category label -> task family. Labels absent from this mapping are
# resolved by the substring heuristics in get_task_family().
FAMILY_MAP = {
    'Natural Science & Formal Science & Technology': 'code',
    'Education': 'education',
    'Arts & Culture': 'creation',
    'Arts': 'creation',
    'Entertainment & Travel & Hobby': 'creation',
    'Food & Drink & Cooking': 'creation',
}
|
|
|
|
def get_task_family(cat):
    """Map a category label to a coarse task family.

    Resolution order:
      1. exact lookup in FAMILY_MAP;
      2. 'analysis' if the label contains 'Business', 'Law' or 'Politic';
      3. 'education' if it contains 'Education';
      4. 'default' otherwise.
    """
    if cat in FAMILY_MAP:
        return FAMILY_MAP[cat]
    if any(marker in cat for marker in ('Business', 'Law', 'Politic')):
        return 'analysis'
    if 'Education' in cat:
        return 'education'
    return 'default'
|
|
|
|
def extract_complexity(prompt):
    """Compute cheap lexical complexity signals for a prompt.

    The prompt is coerced with str(), so non-string input (e.g. None) is
    accepted.

    Returns:
        dict with 'n_words' (whitespace-separated tokens), 'is_code'
        (1 if a programming keyword appears, case-insensitively, else 0)
        and 'n_questions' (number of '?' characters).
    """
    text = str(prompt)
    looks_like_code = re.search(
        r'\b(code|python|java|sql|html|css|script|function|class|def |import)\b',
        text,
        flags=re.IGNORECASE,
    )
    return dict(
        n_words=len(text.split()),
        is_code=1 if looks_like_code else 0,
        n_questions=text.count('?'),
    )
|
|
|
|
def boot():
    """Load the embedder, the classifier and the precomputed lookup tables.

    Populates the module-level ``embedder``, ``clf``, ``reco_table`` and
    ``profil_map``. The classifier is also returned, so the existing
    ``clf = boot()`` call site keeps working.

    Files are fully loaded before the shared tables are replaced, so a failed
    load never leaves them half-updated. Calling boot() again replaces (rather
    than merges into) the previous tables.

    Returns:
        The classifier object loaded with joblib from MODEL_PATH.
    """
    global embedder, clf

    print("Boot API (mode leger)...")
    t0 = time.perf_counter()

    device = 'cpu'
    print(f" Device : {device}")
    # NOTE: trust_remote_code=True lets the model repository run its own code.
    embedder = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True, device=device)

    # NOTE: joblib/pickle execute arbitrary code while loading; these files
    # must only ever come from our own export_for_api.py run.
    clf_obj = joblib.load(MODEL_PATH)
    print(f" Classifieur charge depuis {MODEL_PATH}")

    with open('outputs_v3/api_reco_table.pkl', 'rb') as f:
        loaded_reco = pickle.load(f)
    # Mutate in place (keeps the same dict object) but drop stale categories.
    reco_table.clear()
    reco_table.update(loaded_reco)
    print(f" Reco table : {len(reco_table)} categories")

    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open('outputs_v3/api_profil_map.json', 'r', encoding='utf-8') as f:
        loaded_profils = json.load(f)
    profil_map.clear()
    profil_map.update(loaded_profils)

    # The original declared `global clf` without assigning it; set it here so
    # the API works even when boot() is called without capturing its result.
    clf = clf_obj

    print(f"Boot OK ({time.perf_counter() - t0:.1f}s)")
    return clf_obj
|
|
|
|
def classify_prompt(prompt):
    """Predict the category label of *prompt*.

    The prompt is embedded with the module-level ``embedder`` (normalised
    embeddings, as a numpy array) and passed to ``clf.predict``; the first
    prediction is returned. Both globals must be set beforehand, otherwise
    the call fails with AttributeError on None.
    """
    vectors = embedder.encode([prompt], normalize_embeddings=True, convert_to_numpy=True)
    predictions = clf.predict(vectors)
    return predictions[0]
|
|
|
|
def _lookup_candidates(category, family):
    """Return ``(candidates, quality_column)`` for a predicted category.

    Uses the category's own table when present; otherwise the first table
    (in ``reco_table`` iteration order) whose category has the same task
    family. Both cases rank on 'combined_quality'. When nothing matches, an
    empty DataFrame is returned together with 'win_rate'. The returned table
    is a copy, so the caller may add columns to it.
    """
    if category in reco_table:
        return reco_table[category].copy(), 'combined_quality'
    for other_category, table in reco_table.items():
        if get_task_family(other_category) == family:
            return table.copy(), 'combined_quality'
    return pd.DataFrame(), 'win_rate'


def _format_recommendation(row, qual_col):
    """Convert one candidate row into a JSON-serialisable recommendation dict."""
    win_rate = row.get('win_rate_cat', row.get('win_rate', 0))
    quality = row.get(qual_col, win_rate)
    # Cast numpy scalars to plain floats: jsonify cannot encode e.g. numpy.int64.
    energy = float(row['wh_per_1k_tok'])
    # Percentage of energy saved relative to MEDIAN_ENERGY; 0 when not below it.
    gain_vs_median = (1 - energy / MEDIAN_ENERGY) * 100 if MEDIAN_ENERGY > energy else 0.0
    return {
        'model': row['model'],
        'energy_wh': round(energy, 3),
        'win_rate': round(float(win_rate) * 100, 1),
        'quality': round(float(quality), 3),
        'profil': profil_map.get(row['model'], '?'),
        'gain_vs_median': round(gain_vs_median, 0),
    }


def recommander(prompt, top_n=3, mode='balanced'):
    """Recommend energy-efficient models whose quality is close to the best.

    Steps:
      1. classify the prompt into a category and map it to a task family;
      2. fetch that category's candidate models (or a same-family fallback);
      3. score each candidate as ``alpha * quality + beta * sobriety``, both
         terms min-max normalised over the candidates (0.5 when constant);
      4. keep candidates within DELTA of the best quality and, for complex
         prompts, with a win rate at least the candidates' median (when that
         still leaves someone);
      5. return the ``top_n`` best-scored ones.

    Args:
        prompt: user prompt text.
        top_n: maximum number of recommendations.
        mode: 'performance', 'green' or 'balanced'. Any other value uses the
            balanced weights and is reported back as 'balanced'.

    Returns:
        A JSON-serialisable dict describing the classification and the
        recommendations, or ``{'error', 'category'}`` when no candidate
        models are available.
    """
    MODES = {'performance': (0.8, 0.2), 'green': (0.2, 0.8), 'balanced': (0.5, 0.5)}
    # Normalise unknown modes so the reported mode matches alpha/beta.
    if not isinstance(mode, str) or mode not in MODES:
        mode = 'balanced'
    alpha, beta = MODES[mode]

    category = classify_prompt(prompt)
    family = get_task_family(category)
    cx = extract_complexity(prompt)

    candidates, qual_col = _lookup_candidates(category, family)
    if len(candidates) == 0:
        return {'error': 'Pas de candidats pour cette categorie', 'category': category}

    q = candidates[qual_col]
    q_norm = (q - q.min()) / (q.max() - q.min()) if q.max() > q.min() else 0.5
    e = candidates['wh_per_1k_tok']
    # Sobriety: lower energy -> higher score.
    s_norm = 1 - (e - e.min()) / (e.max() - e.min()) if e.max() > e.min() else 0.5
    candidates['mode_score'] = alpha * q_norm + beta * s_norm

    best_q = candidates[qual_col].max()
    equivalent = candidates[candidates[qual_col] >= best_q - DELTA]

    is_complex = bool(cx['n_words'] > 100 or cx['is_code'] or cx['n_questions'] > 2)
    if is_complex:
        wr_col = 'win_rate_cat' if 'win_rate_cat' in candidates.columns else 'win_rate'
        # Skip the filter (instead of raising KeyError) if no win-rate column exists.
        if wr_col in candidates.columns:
            median_wr = candidates[wr_col].median()
            complex_ok = equivalent[equivalent[wr_col] >= median_wr]
            if len(complex_ok) > 0:
                equivalent = complex_ok

    top = equivalent.nlargest(top_n, 'mode_score')
    results = [_format_recommendation(row, qual_col) for _, row in top.iterrows()]

    return {
        'category': category,
        'family': family,
        'complexity': 'complexe' if is_complex else 'simple',
        'n_words': cx['n_words'],
        'mode': mode,
        'alpha': alpha,
        'beta': beta,
        'n_candidates': len(candidates),
        'n_equivalent': len(equivalent),
        'delta': DELTA,
        'recommendations': results,
    }
|
|
|
|
@app.route('/recommend', methods=['POST'])
def api_recommend():
    """POST /recommend with JSON body ``{prompt, mode?, top_n?}``.

    Responses:
        200 with the recommander() result (which may itself carry an 'error'
            key when no candidate models exist for the category);
        400 when the body is not a JSON object or a field has the wrong type;
        503 when the embedder/classifier have not been loaded by boot().
    """
    # boot() only runs under __main__; if the app is served some other way,
    # classify_prompt() would crash on None, so answer 503 instead.
    if embedder is None or clf is None:
        return jsonify({'error': 'API non initialisee'}), 503

    # silent=True: a missing or malformed JSON body gives None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Corps JSON invalide'}), 400

    prompt = data.get('prompt', '')
    mode = data.get('mode', 'balanced')
    top_n = data.get('top_n', 3)

    if not isinstance(prompt, str):
        return jsonify({'error': 'Prompt invalide'}), 400
    if not prompt.strip():
        return jsonify({'error': 'Prompt vide'}), 400
    if not isinstance(mode, str):
        return jsonify({'error': 'Mode invalide'}), 400
    # bool is a subclass of int, so it must be rejected explicitly.
    if isinstance(top_n, bool) or not isinstance(top_n, int) or top_n < 1:
        return jsonify({'error': 'top_n doit etre un entier >= 1'}), 400

    result = recommander(prompt, top_n=top_n, mode=mode)
    return jsonify(result)
|
|
|
|
@app.route('/health', methods=['GET'])
def health():
    """GET /health: liveness probe summarising the loaded recommendation data.

    'categories' is the number of per-category tables; 'models' is the total
    row count across them (a model present in several categories is counted
    once per category).
    """
    row_count = sum(map(len, reco_table.values()))
    return jsonify(
        status='ok',
        categories=len(reco_table),
        models=row_count,
    )
|
|
|
|
if __name__ == '__main__':
    # Load every model/table before accepting requests; boot() returns the classifier.
    clf = boot()
    port = int(os.environ.get('PORT', 7860))
    banner = (
        f"\nAPI prete : http://localhost:{port}",
        " POST /recommend {prompt, mode, top_n}",
        " GET /health",
    )
    for line in banner:
        print(line)
    app.run(host='0.0.0.0', port=port, debug=False)
|
|