File size: 7,121 Bytes
c31002d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
Treina o classificador local de Tratabilidade.

Estratégia:
  1. Carrega as 100 questões curated (50 alta EE / 50 baixa EE)
  2. [--api] Obtém scores reais via Claude API (usa cache SQLite automaticamente)
  3. [padrão] Usa labels binárias como proxy: 0 → 0.05, 1 → 0.80
  4. Treina Ridge regression sobre embeddings all-MiniLM-L6-v2
  5. Salva em data/models/tractability/classifier.pkl

Modos:
  python -m src.classifier.train_tractability          # labels binárias (rápido, ~30s)
  python -m src.classifier.train_tractability --api    # scores reais via API (~5 min, mais preciso)
  python -m src.classifier.train_tractability --eval   # avalia modelo já treinado
"""

from __future__ import annotations

import io
import os
import sys

# Put the directory two levels above this file (presumably the project root)
# at the front of sys.path so the absolute `src.*` imports used by the
# functions below can resolve when this file is executed directly.
sys.path.insert(0, os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")))

# Force UTF-8 on stdout (the original note targets Windows consoles): when
# stdout reports some other encoding, re-wrap its underlying byte buffer as
# UTF-8. errors="replace" substitutes unencodable characters instead of raising.
if sys.stdout.encoding and sys.stdout.encoding.lower() != "utf-8":
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")


def _get_scores_from_cache(questions: list[str]) -> dict[str, float]:
    """Collect scores that are already stored in the tractability cache.

    The lookup is best-effort: a missing or failing cache backend must never
    abort a training run, so no exception escapes this function. Failures are
    reported on stderr instead of being swallowed silently, and any scores
    gathered before a failing lookup are kept rather than thrown away.

    Args:
        questions: Question texts to look up.

    Returns:
        Mapping from question text to ``prob_tractable * confidence`` for each
        question whose cache entry is truthy. Empty when the cache module
        cannot be imported or nothing is cached.
    """
    try:
        from src.db.historico import get_trat_cache
    except Exception as exc:  # best-effort boundary: report and carry on
        print(f"  [aviso] cache indisponivel ({exc!r}); seguindo sem cache.", file=sys.stderr)
        return {}

    scores: dict[str, float] = {}
    try:
        for question in questions:
            entry = get_trat_cache(question)
            if entry:
                scores[question] = entry["prob_tractable"] * entry["confidence"]
    except Exception as exc:  # best-effort boundary: keep what was read so far
        print(
            f"  [aviso] leitura do cache interrompida ({exc!r}); "
            f"{len(scores)} score(s) aproveitado(s).",
            file=sys.stderr,
        )
    return scores


def _get_scores_from_api(questions: list[str]) -> dict[str, float]:
    """Score every question through ``tratabilidade`` and return the results.

    One progress line is printed per question (1-based position, total, and
    the first 70 characters of the text) before the call is made.

    NOTE(review): the original note says this goes through the Claude API and
    that ``tratabilidade`` uses its cache automatically; neither is visible
    from this function, so confirm against ``src.ee.tratabilidade``.

    Args:
        questions: Question texts to score.

    Returns:
        Mapping from question text to ``prob_tractable * confidence``.
    """
    from src.ee.tratabilidade import tratabilidade

    total = len(questions)
    scores: dict[str, float] = {}
    for position, question in enumerate(questions, start=1):
        print(f"  [{position:3d}/{total}] {question[:70]}")
        verdict = tratabilidade(question)
        scores[question] = verdict["prob_tractable"] * verdict["confidence"]
    return scores


def _evaluate(questions: list[str], targets: list[float]) -> None:
    """Print error metrics of the local model against *targets*.

    Each question is scored with ``predict_local`` and its ``prob_tractable``
    is compared with the matching target. The report contains the mean
    absolute error, a binary accuracy (prediction and target fall on the same
    side of the 0.4 threshold) and the five pairs with the largest error.

    Questions and targets are paired positionally; if the lengths differ, the
    metrics cover only the pairs that exist. With no pairs at all a short
    notice is printed instead of dividing by zero.

    Args:
        questions: Question texts to score.
        targets: Expected score for each question, in the same order.
    """
    if not questions or not targets:
        # Nothing to compare: avoids a NaN mean and a 0/0 accuracy ratio.
        print("\n  Avaliação sobre 0 exemplos: nada a avaliar.")
        return

    import numpy as np

    from src.classifier.tractability_local import predict_local

    threshold = 0.4
    preds = [predict_local(q)["prob_tractable"] for q in questions]
    errors = [abs(p - t) for p, t in zip(preds, targets)]
    n_pairs = len(errors)  # same denominator for MAE and accuracy
    mae = np.mean(errors)

    # Binary accuracy: prediction and target agree on the side of the threshold.
    hits = sum((p >= threshold) == (t >= threshold) for p, t in zip(preds, targets))
    acc = hits / n_pairs

    print(f"\n  Avaliação sobre {n_pairs} exemplos:")
    print(f"  MAE       : {mae:.4f}")
    print(f"  Acurácia  : {acc:.1%}  (threshold {threshold:.2f})")
    print()

    # Largest errors first; ties fall back to ordinary tuple comparison.
    worst = sorted(zip(errors, questions, targets, preds), reverse=True)[:5]
    print("  Top-5 maiores erros:")
    for err, q, t, p in worst:
        print(f"  [{err:.3f}] target={t:.2f} pred={p:.2f} | {q[:65]}")


def _binary_targets(labels: list[int], low: float = 0.05, high: float = 0.80) -> list[float]:
    """Map binary labels to regression targets: label 0 -> *low*, any other label -> *high*."""
    return [low if label == 0 else high for label in labels]


def main(use_api: bool = False, only_eval: bool = False) -> None:
    """Train, or only evaluate, the local tractability classifier.

    The curated questions are loaded with ``get_curated`` and each item's
    ``text`` and ``label`` are used. Target scores are then chosen by mode:

    * default: every label is mapped to a proxy target (0 -> 0.05, else 0.80);
    * ``use_api``: scores come from the cache first, then from
      ``_get_scores_from_api`` for the questions still missing; any question
      left without a score falls back to its proxy target.

    ``train_and_save`` is called with those targets and ``alpha=50.0``
    (presumably the Ridge regularisation strength - confirm in
    ``tractability_local``). The returned metrics are printed, a handful of
    fixed smoke-test questions are scored, and a full evaluation follows.

    Args:
        use_api: Use cached/API scores as targets instead of the proxy targets.
        only_eval: Skip training and evaluate the saved model against the
            proxy targets. Takes precedence over ``use_api``.
    """
    from src.classifier.tractability_local import is_trained
    from src.eval.curated import get_curated

    print("=" * 65)
    print("  Treino — Classificador Local de Tratabilidade")
    print("=" * 65)

    curated = get_curated()
    questions = [item.text for item in curated]
    labels = [item.label for item in curated]
    # Label-derived targets: used directly in the default and --eval modes and
    # as the per-question fallback in --api mode.
    proxy_targets = _binary_targets(labels)

    print(
        f"\n  {len(curated)} questões carregadas  "
        f"({sum(labels)} alta EE / {len(labels)-sum(labels)} baixa EE)"
    )

    # ── Evaluation-only mode ────────────────────────────────────────
    if only_eval:
        if not is_trained():
            print("  ⚠ Modelo não treinado. Execute sem --eval primeiro.")
            return
        _evaluate(questions, proxy_targets)
        return

    # ── Choose the target scores ────────────────────────────────────
    if use_api:
        print("\n  Consultando cache SQLite...")
        cached = _get_scores_from_cache(questions)
        print(f"  {len(cached)}/{len(questions)} já em cache.")

        pending = [q for q in questions if q not in cached]
        if pending:
            print(f"  Chamando API para {len(pending)} questões pendentes...")
            cached.update(_get_scores_from_api(pending))

        target_scores = [
            cached.get(q, fallback) for q, fallback in zip(questions, proxy_targets)
        ]
        print(f"  {len(cached)}/{len(questions)} scores obtidos.")
    else:
        target_scores = proxy_targets
        print("\n  Modo: labels binarias  (0=0.05, 1=0.80)")
        print("  Dica: adicione --api para scores reais (mais precisos).")

    # ── Train ───────────────────────────────────────────────────────
    print()
    from src.classifier.tractability_local import train_and_save

    metrics = train_and_save(questions, target_scores, alpha=50.0)

    # ── Results ─────────────────────────────────────────────────────
    print("\n  ─── Métricas de treino ────────────────────────────────")
    print(f"  Exemplos     : {metrics['n_train']}")
    print(f"  R²           : {metrics['r2']:.4f}")
    print(f"  RMSE         : {metrics['rmse']:.4f}")
    # The cross-validated figures are optional in the metrics dict.
    if metrics["cv_rmse_mean"] is not None:
        print(f"  RMSE CV 5-fold: {metrics['cv_rmse_mean']:.4f} ± {metrics['cv_rmse_std']:.4f}")
    print(f"  Modelo salvo : {metrics['model_path']}")
    print("  ───────────────────────────────────────────────────────")

    # ── Quick smoke test ────────────────────────────────────────────
    from src.classifier.tractability_local import predict_local

    threshold = 0.4
    # (question, whether a score at or above the threshold is expected)
    smoke_cases = [
        ("What are the molecular mechanisms of CRISPR-Cas9 DNA repair?", True),
        ("What is the essence of life?", False),
        ("How does transformer architecture scaling affect performance?", True),
        ("What is the true nature of consciousness?", False),
    ]
    print(f"\n  Teste rápido ({len(smoke_cases)} questões):")
    for question, expect_high in smoke_cases:
        score = predict_local(question)["prob_tractable"]
        mark = "✅" if (score >= threshold) == expect_high else "⚠️"
        expected = "→ ALTO" if expect_high else "→ BAIXO"
        print(f"  {mark} [{score:.3f}] {expected} | {question[:60]}")

    # Full evaluation against the targets the model was trained on.
    _evaluate(questions, target_scores)

    print("  ✅ tratabilidade.py usará o modelo local automaticamente.")
    print()


if __name__ == "__main__":
    # Flags are detected by plain membership in the argument list; any other
    # command-line argument is ignored.
    args = sys.argv[1:]
    wants_api = "--api" in args
    wants_eval = "--eval" in args
    main(use_api=wants_api, only_eval=wants_eval)