File size: 6,008 Bytes
39acd95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
AIBalance API — recommandation sobre par usage
Version légère : charge des fichiers pré-calculés (pas de parquet 4 Go).
Prérequis : lancer export_for_api.py une fois en local.
Endpoints : POST /recommend  {prompt, mode, top_n}
            GET  /health
"""

from flask import Flask, request, jsonify
from flask_cors import CORS
from sentence_transformers import SentenceTransformer
import pandas as pd, re, time, os, pickle, json, joblib

# Flask application; CORS(app) applies flask_cors with its default settings
# to every route. NOTE(review): flask_cors defaults accept any origin —
# confirm this is intended for the deployed API.
app = Flask(__name__)
CORS(app)

# ── Config ──
# sentence-transformers model id used to embed prompts (loaded in boot()).
EMBEDDING_MODEL = "Lajavaness/bilingual-embedding-small"
# Quality tolerance: in recommander, candidates whose quality is within DELTA
# of the best candidate's quality are treated as equivalent.
DELTA = 0.03
# Reference energy (same unit as the 'wh_per_1k_tok' column) used to compute
# 'gain_vs_median' in recommander; presumably the median across models —
# TODO confirm against export_for_api.py.
MEDIAN_ENERGY = 0.66
# joblib-serialised prompt-category classifier loaded by boot().
MODEL_PATH = "outputs_v3/api_clf.joblib"

# ── Globals ──
# Populated at startup by boot() and the __main__ entry point; the request
# handlers read them.
embedder = None
clf = None
reco_table = {}   # category -> candidate table consumed by recommander
profil_map = {}   # model name -> profile label (recommander falls back to '?')

# Exact category label -> task family. Checked before any substring rule.
FAMILY_MAP = {
    'Natural Science & Formal Science & Technology': 'code',
    'Education': 'education',
    'Arts & Culture': 'creation',
    'Arts': 'creation',
    'Entertainment & Travel & Hobby': 'creation',
    'Food & Drink & Cooking': 'creation',
}


def get_task_family(cat):
    """Map a predicted category label to a coarse task family.

    Resolution order:
      1. exact match in ``FAMILY_MAP``;
      2. 'analysis' if the label contains 'Business', 'Law' or 'Politic';
      3. 'education' if the label contains 'Education';
      4. 'default' otherwise.

    Args:
        cat: category label (string) as produced by the classifier.

    Returns:
        One of 'code', 'education', 'creation', 'analysis' or 'default'.
    """
    if cat in FAMILY_MAP:
        return FAMILY_MAP[cat]
    if any(marker in cat for marker in ('Business', 'Law', 'Politic')):
        return 'analysis'
    if 'Education' in cat:
        return 'education'
    return 'default'


# Case-insensitive keywords whose presence marks a prompt as code-related.
_CODE_HINT = re.compile(
    r'\b(code|python|java|sql|html|css|script|function|class|def |import)\b',
    re.I,
)


def extract_complexity(prompt):
    """Compute cheap lexical complexity features for a prompt.

    Args:
        prompt: the user prompt; converted with ``str()`` first.

    Returns:
        dict with:
          - 'n_words': number of whitespace-separated tokens;
          - 'is_code': 1 if a code-related keyword appears, else 0;
          - 'n_questions': number of '?' characters.
    """
    text = str(prompt)
    looks_like_code = _CODE_HINT.search(text) is not None
    features = dict(
        n_words=len(text.split()),
        is_code=1 if looks_like_code else 0,
        n_questions=text.count('?'),
    )
    return features


def boot():
    """Load the pre-computed files, the embedder and the classifier.

    Side effects: sets the module-level ``embedder`` and ``clf`` globals and
    fills ``reco_table`` / ``profil_map`` in place (``dict.update``).

    Returns:
        The classifier loaded from ``MODEL_PATH``. It is also stored in the
        ``clf`` global, so calling ``boot()`` without using its return value
        still leaves ``classify_prompt`` usable.
    """
    global embedder, clf, reco_table, profil_map

    print("Boot API (mode leger)...")
    t0 = time.time()

    # Embedder (device is hard-coded to CPU).
    device = 'cpu'
    print(f"  Device : {device}")
    embedder = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True, device=device)

    # Classifier: stored in the global as well as returned, so the global
    # declared above is actually populated by this function.
    # NOTE: joblib/pickle can execute arbitrary code on load — only point these
    # paths at artifacts produced locally by export_for_api.py.
    clf_obj = joblib.load(MODEL_PATH)
    clf = clf_obj
    print(f"  Classifieur charge depuis {MODEL_PATH}")

    # Pre-computed recommendation table: category -> candidate table.
    with open('outputs_v3/api_reco_table.pkl', 'rb') as f:
        reco_table.update(pickle.load(f))
    print(f"  Reco table : {len(reco_table)} categories")

    # Model -> profile label map. Read as UTF-8 explicitly so decoding does
    # not depend on the platform's locale encoding.
    with open('outputs_v3/api_profil_map.json', 'r', encoding='utf-8') as f:
        profil_map.update(json.load(f))

    print(f"Boot OK ({time.time() - t0:.1f}s)")
    return clf_obj


def classify_prompt(prompt):
    """Predict the task category of a prompt.

    The prompt is embedded with the global ``embedder`` (numpy output,
    normalised embeddings) and passed to the global ``clf``; the first (only)
    prediction is returned. Both globals must have been set up at startup.
    """
    vectors = embedder.encode(
        [prompt],
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    predictions = clf.predict(vectors)
    return predictions[0]


def _select_candidates(category, family):
    """Return ``(candidates, qual_col)`` for a predicted category.

    Uses the category's own table when it exists. Otherwise falls back to the
    first table (in ``reco_table`` iteration order) whose category maps to the
    same task family. Returns an empty DataFrame when nothing matches. The
    returned table is a copy, so callers may add columns to it.
    """
    if category in reco_table:
        return reco_table[category].copy(), 'combined_quality'
    for orig in reco_table:
        if get_task_family(orig) == family:
            return reco_table[orig].copy(), 'combined_quality'
    return pd.DataFrame(), 'win_rate'


def _format_recommendation(row, qual_col):
    """Turn one candidate row into the JSON-ready dict returned by the API."""
    wr = row.get('win_rate_cat', row.get('win_rate', 0))
    cq = row.get(qual_col, wr)
    energy = row['wh_per_1k_tok']
    # Percentage of energy saved vs. MEDIAN_ENERGY; 0 when not cheaper.
    gain_vs_median = (1 - energy / MEDIAN_ENERGY) * 100 if MEDIAN_ENERGY > energy else 0
    return {
        'model': row['model'],
        'energy_wh': round(energy, 3),
        'win_rate': round(float(wr) * 100, 1),
        'quality': round(float(cq), 3),
        'profil': profil_map.get(row['model'], '?'),
        'gain_vs_median': round(gain_vs_median, 0),
    }


def recommander(prompt, top_n=3, mode='balanced'):
    """Recommend up to ``top_n`` models for a prompt.

    Args:
        prompt: user prompt; classified into a category and scanned for
            complexity features.
        top_n: maximum number of recommendations returned.
        mode: 'performance' (quality weight 0.8, energy weight 0.2),
            'green' (0.2 / 0.8) or 'balanced' (0.5 / 0.5). Unknown values use
            the balanced weights but are echoed back unchanged.

    Returns:
        A dict describing the classification, the filtering counts and a
        'recommendations' list. When no candidate table can be found, returns
        ``{'error': ..., 'category': ...}`` instead.
    """
    MODES = {'performance': (0.8, 0.2), 'green': (0.2, 0.8), 'balanced': (0.5, 0.5)}
    alpha, beta = MODES.get(mode, (0.5, 0.5))

    category = classify_prompt(prompt)
    family = get_task_family(category)
    cx = extract_complexity(prompt)

    candidates, qual_col = _select_candidates(category, family)
    if len(candidates) == 0:
        return {'error': 'Pas de candidats pour cette categorie', 'category': category}

    # Mode score: min-max normalised quality (higher is better) blended with
    # min-max normalised energy frugality (lower energy is better, hence
    # 1 - ...). A constant column normalises to 0.5 to avoid dividing by zero.
    q = candidates[qual_col]
    q_norm = (q - q.min()) / (q.max() - q.min()) if q.max() > q.min() else 0.5
    e = candidates['wh_per_1k_tok']
    s_norm = 1 - (e - e.min()) / (e.max() - e.min()) if e.max() > e.min() else 0.5
    candidates['mode_score'] = alpha * q_norm + beta * s_norm

    # DELTA filter: keep models whose quality is within DELTA of the best.
    best_q = candidates[qual_col].max()
    equivalent = candidates[candidates[qual_col] >= best_q - DELTA]

    # Complexity filter: for complex prompts keep only models at or above the
    # median win rate of all candidates, unless that would leave nothing.
    is_complex = cx['n_words'] > 100 or cx['is_code'] or cx['n_questions'] > 2
    if is_complex:
        wr_col = 'win_rate_cat' if 'win_rate_cat' in equivalent.columns else 'win_rate'
        median_wr = candidates[wr_col].median()
        complex_ok = equivalent[equivalent[wr_col] >= median_wr]
        if len(complex_ok) > 0:
            equivalent = complex_ok

    top = equivalent.nlargest(top_n, 'mode_score')
    results = [_format_recommendation(row, qual_col) for _, row in top.iterrows()]

    return {
        'category': category,
        'family': family,
        'complexity': 'complexe' if is_complex else 'simple',
        'n_words': cx['n_words'],
        'mode': mode,
        'alpha': alpha,
        'beta': beta,
        'n_candidates': len(candidates),
        'n_equivalent': len(equivalent),
        'delta': DELTA,
        'recommendations': results,
    }


@app.route('/recommend', methods=['POST'])
def api_recommend():
    """POST /recommend with JSON body ``{prompt, mode, top_n}``.

    Returns the payload built by ``recommander``. Responds 400 when:
      - the body is not a JSON object;
      - ``prompt`` is not a string, or is empty or whitespace-only;
      - ``mode`` is not a string;
      - ``top_n`` cannot be converted to an int.
    """
    # silent=True: a missing or malformed JSON body yields None instead of an
    # exception, so it can be answered with a clean 400 instead of a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Corps JSON invalide'}), 400

    prompt = data.get('prompt', '')
    mode = data.get('mode', 'balanced')
    top_n = data.get('top_n', 3)

    if not isinstance(prompt, str):
        return jsonify({'error': 'Prompt invalide'}), 400
    if not prompt.strip():
        return jsonify({'error': 'Prompt vide'}), 400
    if not isinstance(mode, str):
        return jsonify({'error': 'mode invalide'}), 400
    try:
        # Accepts ints and numeric strings such as "3"; DataFrame.nlargest
        # needs an int.
        top_n = int(top_n)
    except (TypeError, ValueError):
        return jsonify({'error': 'top_n invalide'}), 400

    result = recommander(prompt, top_n=top_n, mode=mode)
    return jsonify(result)


@app.route('/health', methods=['GET'])
def health():
    """GET /health — report status and the size of the loaded reco table.

    'categories' is the number of categories in ``reco_table``. 'models' is
    the total row count summed over every category table, so a model listed
    under several categories is counted once per category.
    """
    n_rows = 0
    for table in reco_table.values():
        n_rows += len(table)
    payload = {
        'status': 'ok',
        'categories': len(reco_table),
        'models': n_rows,
    }
    return jsonify(payload)


if __name__ == '__main__':
    # Load all artifacts once, then serve on $PORT (default 7860) on all
    # interfaces, with Flask debug mode off.
    clf = boot()
    port = int(os.environ.get('PORT', 7860))
    banner = (
        f"\nAPI prete : http://localhost:{port}",
        "  POST /recommend  {prompt, mode, top_n}",
        "  GET  /health",
    )
    for line in banner:
        print(line)
    app.run(host='0.0.0.0', port=port, debug=False)