aibalance / api.py
Alexis Fabre
API AIBalance legere
39acd95
"""
AIBalance API — energy-frugal model recommendation per use case.

Lightweight version: loads pre-computed files (no 4 GB parquet).
Prerequisite: run export_for_api.py once locally.
Endpoints:
    POST /recommend  {prompt, mode, top_n}
    GET  /health
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
import pandas as pd, re, time, os, pickle, json, joblib
# Flask application object; the route decorators further down register
# their endpoints on it.
app = Flask(__name__)
# Apply flask-cors to the whole app (no options passed, so library defaults).
CORS(app)
# ── Configuration ──
# Name of the sentence-embedding model used to vectorize prompts.
EMBEDDING_MODEL = "Lajavaness/bilingual-embedding-small"
# Quality tolerance: candidates whose quality is within DELTA of the best
# one are kept as "equivalent".
DELTA = 0.03
# Reference energy that a model's `wh_per_1k_tok` value is compared against
# when computing its gain versus the median.
MEDIAN_ENERGY = 0.66
# Location of the serialized classifier loaded at startup.
MODEL_PATH = "outputs_v3/api_clf.joblib"

# ── Runtime state, populated at startup ──
embedder = None
clf = None
reco_table = {}
profil_map = {}

# Exact category label -> task family. Labels missing from this table are
# resolved by keyword matching in get_task_family().
FAMILY_MAP = {
    'Natural Science & Formal Science & Technology': 'code',
    'Education': 'education',
    'Arts & Culture': 'creation',
    'Arts': 'creation',
    'Entertainment & Travel & Hobby': 'creation',
    'Food & Drink & Cooking': 'creation',
}
def get_task_family(cat):
    """Map a category label to a coarse task family.

    An exact match in FAMILY_MAP wins. Otherwise the label is scanned for
    keywords: 'Business', 'Law' or 'Politic' give 'analysis', 'Education'
    gives 'education', and anything else gives 'default'.
    """
    # The keyword fallback is computed before the lookup, mirroring the
    # evaluation order of a `dict.get(key, default)` call.
    if any(keyword in cat for keyword in ('Business', 'Law', 'Politic')):
        fallback = 'analysis'
    elif 'Education' in cat:
        fallback = 'education'
    else:
        fallback = 'default'
    return FAMILY_MAP.get(cat, fallback)
def extract_complexity(prompt):
    """Compute cheap surface features of a prompt.

    The prompt is coerced with ``str()`` first, so non-string input is
    accepted.

    Returns:
        A dict with three keys:
        ``n_words`` (whitespace-separated token count),
        ``is_code`` (1 if a programming-related keyword is found,
        case-insensitively, else 0) and
        ``n_questions`` (number of '?' characters).
    """
    text = str(prompt)
    code_hint = re.search(
        r'\b(code|python|java|sql|html|css|script|function|class|def |import)\b',
        text,
        flags=re.I,
    )
    features = {}
    features['n_words'] = len(text.split())
    features['is_code'] = 1 if code_hint else 0
    features['n_questions'] = text.count('?')
    return features
def boot():
    """Load every runtime artifact needed to serve requests.

    Loads, in order: the sentence-embedding model (on CPU), the classifier
    stored at MODEL_PATH, the pre-computed recommendation table and the
    model -> profile map. The module-level ``embedder``, ``clf``,
    ``reco_table`` and ``profil_map`` are populated as a side effect, so
    calling ``boot()`` is enough to make the API usable.

    Returns:
        The loaded classifier (the same object stored in the module-level
        ``clf``), so existing ``clf = boot()`` callers keep working.

    Raises:
        Nothing is caught here: any error from the underlying loaders
        (for example a missing artifact file) propagates to the caller.
    """
    global embedder, clf, reco_table, profil_map
    print("Boot API (mode leger)...")
    t0 = time.time()

    # Embedder.
    device = 'cpu'
    print(f" Device : {device}")
    # NOTE(security): trust_remote_code=True lets the model repository run
    # its own Python code at load time; only use with a trusted model id.
    embedder = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True, device=device)

    # Classifier: stored in the module-level global as well as returned, so
    # the API does not depend on the caller remembering to assign it.
    clf = joblib.load(MODEL_PATH)
    print(f" Classifieur charge depuis {MODEL_PATH}")

    # Pre-computed recommendation table.
    # NOTE(security): pickle.load can execute arbitrary code; this file must
    # only ever come from a trusted local export.
    with open('outputs_v3/api_reco_table.pkl', 'rb') as f:
        reco_table.update(pickle.load(f))
    print(f" Reco table : {len(reco_table)} categories")

    # Model -> profile map. The encoding is explicit so the read does not
    # depend on the platform's locale default.
    with open('outputs_v3/api_profil_map.json', 'r', encoding='utf-8') as f:
        profil_map.update(json.load(f))

    print(f"Boot OK ({time.time() - t0:.1f}s)")
    return clf
def classify_prompt(prompt):
    """Predict the category label of a single prompt.

    The prompt is encoded by the module-level ``embedder`` (NumPy output,
    normalized embeddings requested) and the resulting batch of one vector
    is passed to the module-level ``clf``. Both globals must have been
    initialised before this is called.

    Returns:
        The first element of ``clf.predict`` for the one-item batch.
    """
    batch = [prompt]
    vectors = embedder.encode(
        batch,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    predictions = clf.predict(vectors)
    return predictions[0]
def _min_max_scale(values):
    """Rescale a numeric Series to [0, 1]; return the scalar 0.5 when it has no spread."""
    lowest, highest = values.min(), values.max()
    if highest > lowest:
        return (values - lowest) / (highest - lowest)
    return 0.5


def _find_candidates(category, family):
    """Return the table for *category*, else for the first known category of the same family, else None."""
    if category in reco_table:
        return reco_table[category]
    for known in reco_table:
        if get_task_family(known) == family:
            return reco_table[known]
    return None


def recommander(prompt, top_n=3, mode='balanced'):
    """Recommend models for a prompt, trading quality against energy use.

    Steps: classify the prompt, fetch the candidate table of its category
    (or of the first known category in the same task family), score each
    candidate as ``alpha * quality + beta * sobriety`` on min-max scaled
    columns, keep the candidates whose quality is within DELTA of the best,
    optionally tighten that set for complex prompts, and return the
    ``top_n`` best by score.

    Args:
        prompt: Text of the user request.
        top_n: Maximum number of recommendations to return.
        mode: 'performance', 'green' or 'balanced'; any other value is
            scored with the 'balanced' weights.

    Returns:
        On success, a dict with the keys ``category``, ``family``,
        ``complexity``, ``n_words``, ``mode``, ``alpha``, ``beta``,
        ``n_candidates``, ``n_equivalent``, ``delta`` and
        ``recommendations`` (a list of per-model dicts).
        When no candidate table is available, a dict with only ``error``
        and ``category``.
    """
    MODES = {'performance': (0.8, 0.2), 'green': (0.2, 0.8), 'balanced': (0.5, 0.5)}
    alpha, beta = MODES.get(mode, (0.5, 0.5))

    category = classify_prompt(prompt)
    family = get_task_family(category)
    cx = extract_complexity(prompt)
    qual_col = 'combined_quality'

    source = _find_candidates(category, family)
    if source is None or len(source) == 0:
        return {'error': 'Pas de candidats pour cette categorie', 'category': category}
    # Work on a copy so the shared pre-computed table is never modified.
    candidates = source.copy()

    # Mode score: higher quality and lower energy are both better, hence
    # the inverted energy scale.
    q_norm = _min_max_scale(candidates[qual_col])
    s_norm = 1 - _min_max_scale(candidates['wh_per_1k_tok'])
    candidates['mode_score'] = alpha * q_norm + beta * s_norm

    # DELTA filter: keep candidates whose quality is close to the best one.
    best_q = candidates[qual_col].max()
    equivalent = candidates[candidates[qual_col] >= best_q - DELTA]

    # Complexity filter: for demanding prompts, prefer candidates at or
    # above the median win rate, unless that would leave nothing.
    is_complex = cx['n_words'] > 100 or bool(cx['is_code']) or cx['n_questions'] > 2
    if is_complex:
        wr_col = 'win_rate_cat' if 'win_rate_cat' in equivalent.columns else 'win_rate'
        median_wr = candidates[wr_col].median()
        complex_ok = equivalent[equivalent[wr_col] >= median_wr]
        if len(complex_ok) > 0:
            equivalent = complex_ok

    top = equivalent.nlargest(top_n, 'mode_score')

    results = []
    for _, row in top.iterrows():
        wr = row.get('win_rate_cat', row.get('win_rate', 0))
        cq = row.get(qual_col, wr)
        # Cast to a built-in float, like wr and cq below, so the payload is
        # JSON-serializable whatever numeric type the column holds.
        energy = float(row['wh_per_1k_tok'])
        gain_vs_median = (1 - energy / MEDIAN_ENERGY) * 100 if MEDIAN_ENERGY > energy else 0
        results.append({
            'model': row['model'],
            'energy_wh': round(energy, 3),
            'win_rate': round(float(wr) * 100, 1),
            'quality': round(float(cq), 3),
            'profil': profil_map.get(row['model'], '?'),
            'gain_vs_median': round(gain_vs_median, 0),
        })

    return {
        'category': category,
        'family': family,
        'complexity': 'complexe' if is_complex else 'simple',
        'n_words': cx['n_words'],
        'mode': mode,
        'alpha': alpha,
        'beta': beta,
        'n_candidates': len(candidates),
        'n_equivalent': len(equivalent),
        'delta': DELTA,
        'recommendations': results,
    }
@app.route('/recommend', methods=['POST'])
def api_recommend():
    """POST /recommend — return model recommendations for a prompt.

    Expected JSON body: an object with ``prompt`` (string, required),
    ``mode`` (string, defaults to 'balanced') and ``top_n`` (integer,
    defaults to 3).

    Returns:
        The JSON-encoded result of ``recommander``, or a 400 response
        carrying an ``error`` field when the body is not a JSON object,
        the prompt is missing, blank or not a string, or ``top_n`` cannot
        be converted to an integer.
    """
    data = request.get_json()
    # A valid JSON body can still be null, a list or a scalar; none of
    # those support the .get() lookups below.
    if not isinstance(data, dict):
        return jsonify({'error': 'Corps JSON invalide'}), 400

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({'error': 'Prompt vide'}), 400

    # Unknown mode strings are scored as 'balanced' by recommander; treat
    # non-string values (which may be unhashable) the same way.
    mode = data.get('mode', 'balanced')
    if not isinstance(mode, str):
        mode = 'balanced'

    try:
        top_n = int(data.get('top_n', 3))
    except (TypeError, ValueError, OverflowError):
        return jsonify({'error': 'top_n doit etre un entier'}), 400

    result = recommander(prompt, top_n=top_n, mode=mode)
    return jsonify(result)
@app.route('/health', methods=['GET'])
def health():
    """GET /health — report that the service is up, with table statistics.

    ``categories`` is the number of entries in ``reco_table`` and
    ``models`` is the summed length of all its tables.
    """
    n_categories = len(reco_table)
    n_models = 0
    for table in reco_table.values():
        n_models += len(table)
    payload = {
        'status': 'ok',
        'categories': n_categories,
        'models': n_models,
    }
    return jsonify(payload)
if __name__ == '__main__':
    # boot() returns the loaded classifier; bind it to the module-level
    # `clf` that classify_prompt() reads.
    clf = boot()
    # Listening port comes from the PORT environment variable, 7860 otherwise.
    port = int(os.getenv('PORT', 7860))
    print(f"\nAPI prete : http://localhost:{port}")
    for usage_line in (" POST /recommend {prompt, mode, top_n}", " GET /health"):
        print(usage_line)
    # Listen on all interfaces, with the Flask debugger disabled.
    app.run(host='0.0.0.0', port=port, debug=False)