| """ | |
| Treina o classificador local de Tratabilidade. | |
| EstratΓ©gia: | |
| 1. Carrega as 100 questΓ΅es curated (50 alta EE / 50 baixa EE) | |
| 2. [--api] ObtΓ©m scores reais via Claude API (usa cache SQLite automaticamente) | |
| 3. [padrΓ£o] Usa labels binΓ‘rias como proxy: 0 β 0.05, 1 β 0.80 | |
| 4. Treina Ridge regression sobre embeddings all-MiniLM-L6-v2 | |
| 5. Salva em data/models/tractability/classifier.pkl | |
| Modos: | |
| python -m src.classifier.train_tractability # labels binΓ‘rias (rΓ‘pido, ~30s) | |
| python -m src.classifier.train_tractability --api # scores reais via API (~5 min, mais preciso) | |
| python -m src.classifier.train_tractability --eval # avalia modelo jΓ‘ treinado | |
| """ | |
from __future__ import annotations

import io
import os
import sys

# Make the project root importable when this file is run directly as a script
# (two levels up from src/classifier/), so `from src....` imports resolve.
sys.path.insert(0, os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")))

# Force UTF-8 on Windows stdout so accented/emoji output doesn't raise
# UnicodeEncodeError on legacy console encodings (e.g. cp1252).
if sys.stdout.encoding and sys.stdout.encoding.lower() != "utf-8":
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
| def _get_scores_from_cache(questions: list[str]) -> dict[str, float]: | |
| """Puxa scores jΓ‘ calculados do cache SQLite (gratuito).""" | |
| try: | |
| from src.db.historico import get_trat_cache | |
| result = {} | |
| for q in questions: | |
| cached = get_trat_cache(q) | |
| if cached: | |
| result[q] = cached["prob_tractable"] * cached["confidence"] | |
| return result | |
| except Exception: | |
| return {} | |
def _get_scores_from_api(questions: list[str]) -> dict[str, float]:
    """Compute scores via the Claude API (transparently served from cache when possible)."""
    from src.ee.tratabilidade import tratabilidade

    total = len(questions)
    scores: dict[str, float] = {}
    for idx, question in enumerate(questions, 1):
        print(f" [{idx:3d}/{total}] {question[:70]}")
        response = tratabilidade(question)
        scores[question] = response["prob_tractable"] * response["confidence"]
    return scores
| def _evaluate(questions: list[str], targets: list[float]) -> None: | |
| """Avalia o modelo treinado contra os targets.""" | |
| import numpy as np | |
| from src.classifier.tractability_local import predict_local | |
| preds = [predict_local(q)["prob_tractable"] for q in questions] | |
| errors = [abs(p - t) for p, t in zip(preds, targets)] | |
| mae = np.mean(errors) | |
| # AcurΓ‘cia binΓ‘ria (threshold 0.4) | |
| correct = sum((p >= 0.4) == (t >= 0.4) for p, t in zip(preds, targets)) | |
| acc = correct / len(questions) | |
| print(f"\n AvaliaΓ§Γ£o sobre {len(questions)} exemplos:") | |
| print(f" MAE : {mae:.4f}") | |
| print(f" AcurΓ‘cia : {acc:.1%} (threshold 0.40)") | |
| print() | |
| # Casos com maior erro | |
| top_erros = sorted(zip(errors, questions, targets, preds), reverse=True)[:5] | |
| print(" Top-5 maiores erros:") | |
| for err, q, t, p in top_erros: | |
| print(f" [{err:.3f}] target={t:.2f} pred={p:.2f} | {q[:65]}") | |
| def main(use_api: bool = False, only_eval: bool = False) -> None: | |
| from src.classifier.tractability_local import is_trained | |
| from src.eval.curated import get_curated | |
| print("=" * 65) | |
| print(" Treino β Classificador Local de Tratabilidade") | |
| print("=" * 65) | |
| curated = get_curated() | |
| questions = [q.text for q in curated] | |
| labels = [q.label for q in curated] | |
| print( | |
| f"\n {len(curated)} questΓ΅es carregadas " | |
| f"({sum(labels)} alta EE / {len(labels)-sum(labels)} baixa EE)" | |
| ) | |
| # ββ Modo avaliaΓ§Γ£o ββββββββββββββββββββββββββββββββββββββββββββββ | |
| if only_eval: | |
| if not is_trained(): | |
| print(" β Modelo nΓ£o treinado. Execute sem --eval primeiro.") | |
| return | |
| targets = [0.05 if l == 0 else 0.80 for l in labels] | |
| _evaluate(questions, targets) | |
| return | |
| # ββ Determina scores alvo βββββββββββββββββββββββββββββββββββββββ | |
| if use_api: | |
| print("\n Consultando cache SQLite...") | |
| cached = _get_scores_from_cache(questions) | |
| print(f" {len(cached)}/{len(questions)} jΓ‘ em cache.") | |
| pending = [q for q in questions if q not in cached] | |
| if pending: | |
| print(f" Chamando API para {len(pending)} questΓ΅es pendentes...") | |
| api_scores = _get_scores_from_api(pending) | |
| cached.update(api_scores) | |
| target_scores = [ | |
| cached.get(q, 0.05 if labels[i] == 0 else 0.80) for i, q in enumerate(questions) | |
| ] | |
| print(f" {len(cached)}/{len(questions)} scores obtidos.") | |
| else: | |
| target_scores = [0.05 if l == 0 else 0.80 for l in labels] | |
| print("\n Modo: labels binarias (0=0.05, 1=0.80)") | |
| print(" Dica: adicione --api para scores reais (mais precisos).") | |
| # ββ Treina ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print() | |
| from src.classifier.tractability_local import train_and_save | |
| metrics = train_and_save(questions, target_scores, alpha=50.0) | |
| # ββ Resultados ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n βββ MΓ©tricas de treino ββββββββββββββββββββββββββββββββ") | |
| print(f" Exemplos : {metrics['n_train']}") | |
| print(f" RΒ² : {metrics['r2']:.4f}") | |
| print(f" RMSE : {metrics['rmse']:.4f}") | |
| if metrics["cv_rmse_mean"] is not None: | |
| print(f" RMSE CV 5-fold: {metrics['cv_rmse_mean']:.4f} Β± {metrics['cv_rmse_std']:.4f}") | |
| print(f" Modelo salvo : {metrics['model_path']}") | |
| print(" βββββββββββββββββββββββββββββββββββββββββββββββββββββββ") | |
| # ββ Teste rΓ‘pido ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n Teste rΓ‘pido (4 questΓ΅es):") | |
| from src.classifier.tractability_local import predict_local | |
| test_cases = [ | |
| ("What are the molecular mechanisms of CRISPR-Cas9 DNA repair?", "β ALTO"), | |
| ("What is the essence of life?", "β BAIXO"), | |
| ("How does transformer architecture scaling affect performance?", "β ALTO"), | |
| ("What is the true nature of consciousness?", "β BAIXO"), | |
| ] | |
| for q, expected in test_cases: | |
| r = predict_local(q) | |
| score = r["prob_tractable"] | |
| label = "β " if (score >= 0.4) == (expected == "β ALTO") else "β οΈ" | |
| print(f" {label} [{score:.3f}] {expected} | {q[:60]}") | |
| # AvaliaΓ§Γ£o completa | |
| _evaluate(questions, target_scores) | |
| print(" β tratabilidade.py usarΓ‘ o modelo local automaticamente.") | |
| print() | |
if __name__ == "__main__":
    # Minimal flag parsing: presence of --api / --eval anywhere in argv.
    cli_flags = set(sys.argv[1:])
    main(use_api="--api" in cli_flags, only_eval="--eval" in cli_flags)