Spaces:
Sleeping
Sleeping
File size: 7,121 Bytes
c31002d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | """
Treina o classificador local de Tratabilidade.
Estratégia:
1. Carrega as 100 questões curated (50 alta EE / 50 baixa EE)
2. [--api] Obtém scores reais via Claude API (usa cache SQLite automaticamente)
3. [padrão] Usa labels binárias como proxy: 0 → 0.05, 1 → 0.80
4. Treina Ridge regression sobre embeddings all-MiniLM-L6-v2
5. Salva em data/models/tractability/classifier.pkl
Modos:
python -m src.classifier.train_tractability # labels binárias (rápido, ~30s)
python -m src.classifier.train_tractability --api # scores reais via API (~5 min, mais preciso)
python -m src.classifier.train_tractability --eval # avalia modelo já treinado (contra os alvos binários proxy)
"""
from __future__ import annotations
import io
import os
import sys
# Make the project root (two directories above this file) importable so the
# `src.*` imports below work when this file is run directly.
sys.path.insert(0, os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")))

# Force UTF-8 on stdout when the console reports another encoding (typically a
# legacy Windows code page), so the accented text and symbols printed by this
# script cannot raise UnicodeEncodeError.
# reconfigure() changes the existing stream in place and keeps its line
# buffering. Replacing sys.stdout with a fresh TextIOWrapper defaults to
# line_buffering=False, which would delay the per-question progress output of
# the long --api run until the buffer fills.
if sys.stdout.encoding and sys.stdout.encoding.lower() != "utf-8":
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    else:
        # Fallback for stdout replacements that are not TextIOWrapper instances.
        sys.stdout = io.TextIOWrapper(
            sys.stdout.buffer, encoding="utf-8", errors="replace", line_buffering=True
        )
def _get_scores_from_cache(questions: list[str]) -> dict[str, float]:
"""Puxa scores jΓ‘ calculados do cache SQLite (gratuito)."""
try:
from src.db.historico import get_trat_cache
result = {}
for q in questions:
cached = get_trat_cache(q)
if cached:
result[q] = cached["prob_tractable"] * cached["confidence"]
return result
except Exception:
return {}
def _get_scores_from_api(questions: list[str]) -> dict[str, float]:
    """Score every question through ``tratabilidade`` (the Claude-backed scorer).

    Prints one progress line per question. Per the original note,
    ``tratabilidade`` consults the cache on its own.
    NOTE(review): that caching lives in src.ee.tratabilidade — confirm there.

    Returns:
        Mapping ``question -> prob_tractable * confidence``.
    """
    from src.ee.tratabilidade import tratabilidade

    total = len(questions)
    scores: dict[str, float] = {}
    for idx, question in enumerate(questions, start=1):
        print(f" [{idx:3d}/{total}] {question[:70]}")
        response = tratabilidade(question)
        prob = response["prob_tractable"]
        conf = response["confidence"]
        scores[question] = prob * conf
    return scores
def _evaluate(questions: list[str], targets: list[float]) -> None:
"""Avalia o modelo treinado contra os targets."""
import numpy as np
from src.classifier.tractability_local import predict_local
preds = [predict_local(q)["prob_tractable"] for q in questions]
errors = [abs(p - t) for p, t in zip(preds, targets)]
mae = np.mean(errors)
# AcurΓ‘cia binΓ‘ria (threshold 0.4)
correct = sum((p >= 0.4) == (t >= 0.4) for p, t in zip(preds, targets))
acc = correct / len(questions)
print(f"\n AvaliaΓ§Γ£o sobre {len(questions)} exemplos:")
print(f" MAE : {mae:.4f}")
print(f" AcurΓ‘cia : {acc:.1%} (threshold 0.40)")
print()
# Casos com maior erro
top_erros = sorted(zip(errors, questions, targets, preds), reverse=True)[:5]
print(" Top-5 maiores erros:")
for err, q, t, p in top_erros:
print(f" [{err:.3f}] target={t:.2f} pred={p:.2f} | {q[:65]}")
def _proxy_target(label: int) -> float:
    """Map a binary curated label to its proxy target score (0 -> 0.05, otherwise 0.80)."""
    return 0.05 if label == 0 else 0.80


def main(use_api: bool = False, only_eval: bool = False) -> None:
    """Train (or, with ``only_eval``, only evaluate) the local tractability classifier.

    Args:
        use_api: Use real scores as regression targets: SQLite cache first,
            then the Claude API for the remaining questions. Questions still
            without a score fall back to the binary-label proxy.
        only_eval: Skip training and evaluate the already-saved model against
            the binary-label proxy targets.
    """
    from src.classifier.tractability_local import is_trained, predict_local, train_and_save
    from src.eval.curated import get_curated

    print("=" * 65)
    print(" Treino — Classificador Local de Tratabilidade")
    print("=" * 65)

    curated = get_curated()
    questions = [q.text for q in curated]
    labels = [q.label for q in curated]
    n_high = sum(labels)
    print(
        f"\n {len(curated)} questões carregadas "
        f"({n_high} alta EE / {len(labels) - n_high} baixa EE)"
    )
    proxy_targets = [_proxy_target(label) for label in labels]

    # ── Evaluation-only mode ─────────────────────────────────────────
    if only_eval:
        if not is_trained():
            print(" ❌ Modelo não treinado. Execute sem --eval primeiro.")
            return
        _evaluate(questions, proxy_targets)
        return

    # ── Target scores ────────────────────────────────────────────────
    if use_api:
        print("\n Consultando cache SQLite...")
        scores = _get_scores_from_cache(questions)
        print(f" {len(scores)}/{len(questions)} já em cache.")
        # dict.fromkeys de-duplicates while keeping order, so a repeated
        # question is never sent to the (paid) API twice.
        pending = list(dict.fromkeys(q for q in questions if q not in scores))
        if pending:
            print(f" Chamando API para {len(pending)} questões pendentes...")
            scores.update(_get_scores_from_api(pending))
        # Any question still without a real score falls back to its proxy target.
        target_scores = [scores.get(q, proxy) for q, proxy in zip(questions, proxy_targets)]
        print(f" {len(scores)}/{len(questions)} scores obtidos.")
    else:
        target_scores = proxy_targets
        print("\n Modo: labels binarias (0=0.05, 1=0.80)")
        print(" Dica: adicione --api para scores reais (mais precisos).")

    # ── Training ─────────────────────────────────────────────────────
    print()
    metrics = train_and_save(questions, target_scores, alpha=50.0)

    # ── Results ──────────────────────────────────────────────────────
    print("\n ─── Métricas de treino ────────────────────────────────")
    print(f" Exemplos : {metrics['n_train']}")
    print(f" R² : {metrics['r2']:.4f}")
    print(f" RMSE : {metrics['rmse']:.4f}")
    if metrics.get("cv_rmse_mean") is not None:
        print(f" RMSE CV 5-fold: {metrics['cv_rmse_mean']:.4f} ± {metrics['cv_rmse_std']:.4f}")
    print(f" Modelo salvo : {metrics['model_path']}")
    print(" ───────────────────────────────────────────────────────")

    # ── Quick sanity check ───────────────────────────────────────────
    print("\n Teste rápido (4 questões):")
    threshold = 0.4  # same cut-off as the default accuracy threshold in _evaluate
    # (question, expected to be highly tractable?) — compared as a boolean rather
    # than against a display string, so the label text can change freely.
    quick_tests = [
        ("What are the molecular mechanisms of CRISPR-Cas9 DNA repair?", True),
        ("What is the essence of life?", False),
        ("How does transformer architecture scaling affect performance?", True),
        ("What is the true nature of consciousness?", False),
    ]
    for q, expect_high in quick_tests:
        score = predict_local(q)["prob_tractable"]
        expected = "↑ ALTO" if expect_high else "↓ BAIXO"
        mark = "✅" if (score >= threshold) == expect_high else "⚠️"
        print(f" {mark} [{score:.3f}] {expected} | {q[:60]}")

    # Full evaluation against the targets actually used for training.
    _evaluate(questions, target_scores)
    print(" ✅ tratabilidade.py usará o modelo local automaticamente.")
    print()
if __name__ == "__main__":
    # argparse rejects unknown/misspelled flags (e.g. "--evl") instead of silently
    # ignoring them and falling through to a full training run, and adds --help.
    import argparse

    _parser = argparse.ArgumentParser(
        description="Treina o classificador local de Tratabilidade.",
    )
    _parser.add_argument(
        "--api",
        dest="use_api",
        action="store_true",
        help="usa scores reais (cache SQLite + Claude API) como alvo",
    )
    _parser.add_argument(
        "--eval",
        dest="only_eval",
        action="store_true",
        help="apenas avalia o modelo já treinado",
    )
    _args = _parser.parse_args()
    main(
        use_api=_args.use_api,
        only_eval=_args.only_eval,
    )
|