# src/classifier/train_tractability.py
"""
Treina o classificador local de Tratabilidade.
EstratΓ©gia:
1. Carrega as 100 questΓ΅es curated (50 alta EE / 50 baixa EE)
2. [--api] ObtΓ©m scores reais via Claude API (usa cache SQLite automaticamente)
3. [padrΓ£o] Usa labels binΓ‘rias como proxy: 0 β†’ 0.05, 1 β†’ 0.80
4. Treina Ridge regression sobre embeddings all-MiniLM-L6-v2
5. Salva em data/models/tractability/classifier.pkl
Modos:
python -m src.classifier.train_tractability # labels binΓ‘rias (rΓ‘pido, ~30s)
python -m src.classifier.train_tractability --api # scores reais via API (~5 min, mais preciso)
python -m src.classifier.train_tractability --eval # avalia modelo jΓ‘ treinado
"""
from __future__ import annotations
import io
import os
import sys

# Make the repository root importable so the `src.*` absolute imports below
# resolve even when this file is executed directly as a script.
sys.path.insert(0, os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")))
# Force UTF-8 on Windows stdout: the default console codec (cp125x) cannot
# encode the accented text and emoji this script prints, so re-wrap the raw
# buffer; errors="replace" keeps printing alive even for unmappable chars.
if sys.stdout.encoding and sys.stdout.encoding.lower() != "utf-8":
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
def _get_scores_from_cache(questions: list[str]) -> dict[str, float]:
"""Puxa scores jΓ‘ calculados do cache SQLite (gratuito)."""
try:
from src.db.historico import get_trat_cache
result = {}
for q in questions:
cached = get_trat_cache(q)
if cached:
result[q] = cached["prob_tractable"] * cached["confidence"]
return result
except Exception:
return {}
def _get_scores_from_api(questions: list[str]) -> dict[str, float]:
    """Generate target scores via the Claude API (the call layer caches automatically).

    Returns a mapping question -> ``prob_tractable * confidence``.
    """
    from src.ee.tratabilidade import tratabilidade

    total = len(questions)
    scores: dict[str, float] = {}
    for pos, question in enumerate(questions, 1):
        print(f" [{pos:3d}/{total}] {question[:70]}")
        verdict = tratabilidade(question)
        scores[question] = verdict["prob_tractable"] * verdict["confidence"]
    return scores
def _evaluate(questions: list[str], targets: list[float]) -> None:
"""Avalia o modelo treinado contra os targets."""
import numpy as np
from src.classifier.tractability_local import predict_local
preds = [predict_local(q)["prob_tractable"] for q in questions]
errors = [abs(p - t) for p, t in zip(preds, targets)]
mae = np.mean(errors)
# AcurΓ‘cia binΓ‘ria (threshold 0.4)
correct = sum((p >= 0.4) == (t >= 0.4) for p, t in zip(preds, targets))
acc = correct / len(questions)
print(f"\n AvaliaΓ§Γ£o sobre {len(questions)} exemplos:")
print(f" MAE : {mae:.4f}")
print(f" AcurΓ‘cia : {acc:.1%} (threshold 0.40)")
print()
# Casos com maior erro
top_erros = sorted(zip(errors, questions, targets, preds), reverse=True)[:5]
print(" Top-5 maiores erros:")
for err, q, t, p in top_erros:
print(f" [{err:.3f}] target={t:.2f} pred={p:.2f} | {q[:65]}")
def main(use_api: bool = False, only_eval: bool = False) -> None:
from src.classifier.tractability_local import is_trained
from src.eval.curated import get_curated
print("=" * 65)
print(" Treino β€” Classificador Local de Tratabilidade")
print("=" * 65)
curated = get_curated()
questions = [q.text for q in curated]
labels = [q.label for q in curated]
print(
f"\n {len(curated)} questΓ΅es carregadas "
f"({sum(labels)} alta EE / {len(labels)-sum(labels)} baixa EE)"
)
# ── Modo avaliaΓ§Γ£o ──────────────────────────────────────────────
if only_eval:
if not is_trained():
print(" ⚠ Modelo não treinado. Execute sem --eval primeiro.")
return
targets = [0.05 if l == 0 else 0.80 for l in labels]
_evaluate(questions, targets)
return
# ── Determina scores alvo ───────────────────────────────────────
if use_api:
print("\n Consultando cache SQLite...")
cached = _get_scores_from_cache(questions)
print(f" {len(cached)}/{len(questions)} jΓ‘ em cache.")
pending = [q for q in questions if q not in cached]
if pending:
print(f" Chamando API para {len(pending)} questΓ΅es pendentes...")
api_scores = _get_scores_from_api(pending)
cached.update(api_scores)
target_scores = [
cached.get(q, 0.05 if labels[i] == 0 else 0.80) for i, q in enumerate(questions)
]
print(f" {len(cached)}/{len(questions)} scores obtidos.")
else:
target_scores = [0.05 if l == 0 else 0.80 for l in labels]
print("\n Modo: labels binarias (0=0.05, 1=0.80)")
print(" Dica: adicione --api para scores reais (mais precisos).")
# ── Treina ──────────────────────────────────────────────────────
print()
from src.classifier.tractability_local import train_and_save
metrics = train_and_save(questions, target_scores, alpha=50.0)
# ── Resultados ──────────────────────────────────────────────────
print("\n ─── MΓ©tricas de treino ────────────────────────────────")
print(f" Exemplos : {metrics['n_train']}")
print(f" RΒ² : {metrics['r2']:.4f}")
print(f" RMSE : {metrics['rmse']:.4f}")
if metrics["cv_rmse_mean"] is not None:
print(f" RMSE CV 5-fold: {metrics['cv_rmse_mean']:.4f} Β± {metrics['cv_rmse_std']:.4f}")
print(f" Modelo salvo : {metrics['model_path']}")
print(" ───────────────────────────────────────────────────────")
# ── Teste rΓ‘pido ────────────────────────────────────────────────
print("\n Teste rΓ‘pido (4 questΓ΅es):")
from src.classifier.tractability_local import predict_local
test_cases = [
("What are the molecular mechanisms of CRISPR-Cas9 DNA repair?", "β†’ ALTO"),
("What is the essence of life?", "β†’ BAIXO"),
("How does transformer architecture scaling affect performance?", "β†’ ALTO"),
("What is the true nature of consciousness?", "β†’ BAIXO"),
]
for q, expected in test_cases:
r = predict_local(q)
score = r["prob_tractable"]
label = "βœ…" if (score >= 0.4) == (expected == "β†’ ALTO") else "⚠️"
print(f" {label} [{score:.3f}] {expected} | {q[:60]}")
# AvaliaΓ§Γ£o completa
_evaluate(questions, target_scores)
print(" βœ… tratabilidade.py usarΓ‘ o modelo local automaticamente.")
print()
if __name__ == "__main__":
    # CLI: presence of a flag anywhere in argv toggles the mode.
    flags = set(sys.argv[1:])
    main(use_api="--api" in flags, only_eval="--eval" in flags)