Spaces:
Sleeping
Sleeping
| """ | |
| Fase 3 — Inferencia guiada por EE (best-of-N sampling). | |
| Estrategia: | |
| 1. Policy model gera N candidatos para q_bad | |
| 2. Cada candidato passa pelo Filtro Stage 1: EE(q_cand) > EE(q_bad) + epsilon | |
| 3. Entre os aprovados, seleciona o de maior score(alpha) | |
| 4. Fallback: se nenhum aprovado, retorna o de maior EE entre todos | |
| Policy backends suportados: | |
| - hf_inference : HuggingFace Inference API (gratuita, zero-cost) — PADRAO | |
| - gguf : modelo GGUF local via llama-cpp-python (opcional) | |
| - claude : Claude API (requer ANTHROPIC_API_KEY) | |
| - local : modelo PEFT local (DPO checkpoint) com transformers + PEFT | |
| Traducao pt-br: | |
| - Padrao: Helsinki-NLP MarianMT local (zero-cost, ~300 MB por modelo) | |
| - Fallback: Claude API (se ANTHROPIC_API_KEY definida) | |
| Otimizacoes (Onda 1): | |
| - Geracao paralela: 8 chamadas simultaneas via ThreadPoolExecutor | |
| - Scoring paralelo: 8 scores EE simultaneos via ThreadPoolExecutor | |
| - Prompt caching: cache_control ephemeral (Claude API) | |
| - Cache de tratabilidade: in-memory + SQLite | |
| Configuracao via .env: | |
| INFERENCE_BACKEND = hf_inference | gguf | claude | local (default: auto) | |
| INFERENCE_MODEL_DIR = data/models/dpo_policy/tier3/final | |
| DPO_MODEL = gpt2 (base model para backend local) | |
| INFERENCE_N = 8 (candidatos por query) | |
| INFERENCE_ALPHA = 0.5 (peso EE vs proximidade no ranking) | |
| INFERENCE_MAX_NEW_TOKENS = 80 | |
| INFERENCE_TEMPERATURE = 1.1 | |
| INFERENCE_TOP_P = 0.95 | |
| TRANSLATE_BACKEND = local | claude (default: local) | |
| Uso (ingles): | |
| .venv\\Scripts\\python -m src.rl.inference "What is the essence of life?" | |
| .venv\\Scripts\\python -m src.rl.inference --batch caminho/para/perguntas.txt | |
| .venv\\Scripts\\python -m src.rl.inference --demo | |
| Uso (portugues — traducao automatica): | |
| .venv\\Scripts\\python -m src.rl.inference --pt "O que e a consciencia?" | |
| .venv\\Scripts\\python -m src.rl.inference --pt --batch caminho/para/perguntas_pt.txt | |
| .venv\\Scripts\\python -m src.rl.inference --pt --demo | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from concurrent.futures import ThreadPoolExecutor | |
| from concurrent.futures import as_completed | |
| from dataclasses import dataclass | |
| from dataclasses import field | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
# Values found in .env take precedence over variables already set in the environment.
load_dotenv(override=True)

# ---------------------------------------------------------------------------
# Configuration (read once, at import time)
# ---------------------------------------------------------------------------
# Which policy backend generates candidates ("auto", "hf_inference", "gguf", "claude", "local").
BACKEND = os.getenv("INFERENCE_BACKEND", "auto")
# Which translator the pt-br flow uses ("local" MarianMT or "claude").
TRANSLATE_BACKEND = os.getenv("TRANSLATE_BACKEND", "local")
# Base checkpoint for the "local" backend; the PEFT adapter (if any) lives in MODEL_DIR.
BASE_MODEL = os.getenv("DPO_MODEL", "gpt2")
MODEL_DIR = Path(os.getenv("INFERENCE_MODEL_DIR", "data/models/dpo_policy/tier3/final"))
def _env_int(key: str, default: int) -> int:
    """Return environment variable ``key`` parsed as an int.

    Falls back to ``default`` when the variable is unset or is not a valid integer.
    """
    raw = os.getenv(key, str(default))
    try:
        return int(raw)
    except ValueError:
        return default
def _env_float(key: str, default: float) -> float:
    """Return environment variable ``key`` parsed as a float.

    Falls back to ``default`` when the variable is unset or cannot be parsed.
    """
    raw = os.environ.get(key, str(default))
    try:
        parsed = float(raw)
    except ValueError:
        parsed = default
    return parsed
# Best-of-N sampling and ranking knobs (each overridable through .env).
N_CANDIDATES = _env_int("INFERENCE_N", 8)
MAX_NEW_TOKENS = _env_int("INFERENCE_MAX_NEW_TOKENS", 80)
ALPHA = _env_float("INFERENCE_ALPHA", 0.5)
TEMPERATURE = _env_float("INFERENCE_TEMPERATURE", 1.1)
TOP_P = _env_float("INFERENCE_TOP_P", 0.95)

CORPUS_DIR = Path(os.getenv("CORPUS_DIR", "data/corpus"))

# Plain-text prompt for the "local" causal-LM backend. The reformulation is parsed
# from the text generated after the final "Reformulated question:" marker.
PROMPT_TEMPLATE = "\n\n".join(
    (
        "You are an expert in philosophy of science. "
        "Reformulate the following research question to make it more epistemically tractable: "
        "operationalizable, methodologically grounded, and answerable with existing tools.",
        "Original question: {q_bad}",
        "Reformulated question:",
    )
)
# System prompts are module constants so the identical text can be sent with
# cache_control (prompt caching) on every Claude call.
_GENERATION_SYSTEM = " ".join(
    (
        "You are an expert in philosophy of science.",
        "Your task is to reformulate research questions to make them more epistemically tractable: "
        "operationalizable, methodologically grounded, and answerable with existing tools.",
        "The user's question is enclosed in <question> tags.",
        "Respond with ONLY the reformulated question — no explanation, no preamble, no tags.",
    )
)


def _translation_system(source_lang: str, target_lang: str) -> str:
    """Build the translation system prompt for one language direction."""
    return (
        f"Translate the research question enclosed in <question> tags from {source_lang} to {target_lang}. "
        "Preserve the exact meaning and academic tone. "
        "Respond with ONLY the translated question, nothing else."
    )


# Keyed by translation direction: "pt_to_en" or "en_to_pt".
_TRANSLATE_SYSTEMS = {
    "pt_to_en": _translation_system("Portuguese", "English"),
    "en_to_pt": _translation_system("English", "Portuguese (Brazilian)"),
}
# Anthropic client shared by every Claude call in this module, created on first use.
# NOTE(review): the lazy initialisation below is not lock-protected; callers in this
# module create it before fanning out to worker threads.
_claude_client = None


def _get_claude_client():
    """Return the shared Anthropic client, creating it on the first call.

    Raises:
        RuntimeError: if ANTHROPIC_API_KEY is not set.
    """
    global _claude_client
    if _claude_client is not None:
        return _claude_client

    import anthropic

    key = os.getenv("ANTHROPIC_API_KEY")
    if not key:
        raise RuntimeError(
            "ANTHROPIC_API_KEY nao definido no .env. "
            "Configure-o ou use INFERENCE_BACKEND=local."
        )
    _claude_client = anthropic.Anthropic(api_key=key)
    return _claude_client
def pr(text: str) -> None:
    """Print ``text``, falling back to raw UTF-8 bytes when the console encoding can't represent it."""
    try:
        print(text)
    except UnicodeEncodeError:
        encoded = (text + "\n").encode("utf-8", errors="replace")
        sys.stdout.buffer.write(encoded)
| # --------------------------------------------------------------------------- | |
| # Resultado de inferencia | |
| # --------------------------------------------------------------------------- | |
@dataclass
class InferenceResult:
    """Outcome of one best-of-N reformulation run.

    Declared as a dataclass so the keyword constructor used by ``reformular`` works and
    ``field(default_factory=list)`` gives every instance its own ``candidates`` list
    (without the decorator, construction with these keywords raises TypeError).
    """

    q_bad: str  # original (input) question
    best: str  # chosen reformulation; the input itself when no candidate was generated
    ee_bad: float  # EE of the input question
    ee_best: float  # EE of ``best``
    score_best: float  # ranking score of ``best`` (0.0 when no candidate was generated)
    stage1_pass: bool  # True when ``best`` passed the Stage-1 EE filter, False for fallback
    candidates: list[dict] = field(default_factory=list)  # every scored candidate

    def summary(self) -> str:
        """Return a short multi-line report; text fields are truncated to 80 characters."""
        lines = [
            f" Input : {self.q_bad[:80]}",
            f" Output : {self.best[:80]}",
            f" EE : {self.ee_bad:.3f} -> {self.ee_best:.3f} "
            f"({'PASS' if self.stage1_pass else 'FALLBACK'})",
            f" Score : {self.score_best:.3f} (alpha={ALPHA})",
            f" N cand : {len(self.candidates)}",
        ]
        return "\n".join(lines)
| # --------------------------------------------------------------------------- | |
| # Corpus index (lazy) | |
| # --------------------------------------------------------------------------- | |
# Corpus index shared across calls; built lazily by _get_index().
_corpus_index = None


def _get_index():
    """Return the shared corpus index, building it from CORPUS_DIR on first use.

    If the corpus directory is missing (FileNotFoundError), a warning is printed and an
    index whose searches always return ``[]`` is cached instead.
    """
    global _corpus_index
    if _corpus_index is not None:
        return _corpus_index

    from src.corpus.index import build_index

    try:
        built = build_index(CORPUS_DIR)
    except FileNotFoundError:
        pr(" [aviso] Corpus nao encontrado — respondibilidade sera 0.")
        built = _NullIndex()
    _corpus_index = built
    return _corpus_index
class _NullIndex:
    """Stand-in corpus index used when no corpus is available: every search finds nothing."""

    def search(self, *_args, **_kwargs):
        """Accept any search arguments and return an empty result list."""
        return list()
| # --------------------------------------------------------------------------- | |
| # EE scoring | |
| # --------------------------------------------------------------------------- | |
def _score_candidate(q_cand: str, q_bad: str) -> dict:
    """Score one candidate reformulation against the original question.

    Returns a dict with keys text, ee, score, prox, resp, tract, nt. If computing the
    metrics raises, every numeric field is 0.0 and the exception text is stored under
    "error" instead of propagating.
    """
    from src.ee.reward import compute_ee
    from src.ee.reward import compute_score

    index = _get_index()
    try:
        metrics = compute_ee(q_cand, q_bad, index)
        ranking = compute_score(metrics, alpha=ALPHA)
        return {
            "text": q_cand,
            "ee": metrics.ee,
            "score": ranking,
            "prox": metrics.prox,
            "resp": metrics.respondibilidade,
            "tract": metrics.tratabilidade,
            "nt": metrics.nao_trivialidade,
        }
    except Exception as exc:
        zeroed = dict.fromkeys(("ee", "score", "prox", "resp", "tract", "nt"), 0.0)
        return {"text": q_cand, **zeroed, "error": str(exc)}
| # --------------------------------------------------------------------------- | |
| # Backend LOCAL (PEFT + transformers) | |
| # --------------------------------------------------------------------------- | |
# Cached text-generation pipeline for the "local" backend (built on first use).
_local_pipeline = None


def _load_local_pipeline():
    """Build (once) and return the transformers text-generation pipeline for the local backend.

    Model resolution, in order:
      1. ``MODEL_DIR/adapter_config.json`` exists -> load ``BASE_MODEL``, apply the PEFT
         adapter stored in ``MODEL_DIR`` and merge it into the base weights.
      2. ``MODEL_DIR`` exists -> load it as a complete causal-LM checkpoint.
      3. Otherwise -> warn and load the plain ``BASE_MODEL``.

    On CUDA machines, branches 1 and 2 load with ``device_map="auto"``. The result is cached
    in the module-level ``_local_pipeline``.
    """
    global _local_pipeline
    if _local_pipeline is not None:
        return _local_pipeline

    import torch
    from transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer
    from transformers import pipeline

    has_cuda = torch.cuda.is_available()
    device_map = "auto" if has_cuda else None

    pr(f" Carregando modelo local: {MODEL_DIR}")
    if (MODEL_DIR / "adapter_config.json").exists():
        from peft import PeftModel

        tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR))
        base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map=device_map)
        model = PeftModel.from_pretrained(base, str(MODEL_DIR))
        model = model.merge_and_unload()
        pr(" PEFT adapter carregado e mesclado.")
    elif MODEL_DIR.exists():
        tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR))
        model = AutoModelForCausalLM.from_pretrained(str(MODEL_DIR), device_map=device_map)
        pr(" Modelo completo carregado.")
    else:
        pr(f" [aviso] MODEL_DIR nao encontrado: {MODEL_DIR.name}")
        pr(f" Usando modelo base: {BASE_MODEL}")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

    # Some tokenizers (e.g. GPT-2, the default BASE_MODEL) have no pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    pipe_kwargs = {"model": model, "tokenizer": tokenizer}
    # A model dispatched with device_map="auto" carries ``hf_device_map`` and is already
    # placed on its device(s). transformers refuses (or overrides, depending on version)
    # an explicit ``device`` for such models, so pass it only when there is no device map.
    if getattr(model, "hf_device_map", None) is None:
        pipe_kwargs["device"] = 0 if has_cuda else -1
    _local_pipeline = pipeline("text-generation", **pipe_kwargs)
    return _local_pipeline
def _generate_local(q_bad: str, n: int) -> list[str]:
    """Sample ``n`` reformulations of ``q_bad`` from the local causal-LM pipeline.

    The prompt is ``PROMPT_TEMPLATE`` filled with ``q_bad``. From each sampled sequence,
    only the first non-empty line of the text generated after the prompt is kept.

    Args:
        q_bad: Question to reformulate.
        n: Number of sequences to sample; ``n <= 0`` returns ``[]`` without loading the model.

    Returns:
        Non-empty candidate strings (possibly fewer than ``n``).
    """
    if n <= 0:
        return []

    from transformers import GenerationConfig

    pipe = _load_local_pipeline()
    prompt = PROMPT_TEMPLATE.format(q_bad=q_bad)
    gen_config = GenerationConfig(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        pad_token_id=pipe.tokenizer.eos_token_id,
    )
    outputs = pipe(prompt, generation_config=gen_config)

    marker = "Reformulated question:"
    candidates = []
    for out in outputs:
        text = out["generated_text"]
        if text.startswith(prompt):
            # Drop the echoed prompt and keep only the continuation. Splitting on the
            # *last* marker would instead jump to any marker the model generated itself.
            text = text[len(prompt):]
        elif marker in text:
            text = text.split(marker)[-1]
        first_line = text.strip().split("\n")[0].strip()
        if first_line:
            candidates.append(first_line)
    return candidates
| # --------------------------------------------------------------------------- | |
| # Backend CLAUDE (Anthropic API) | |
| # --------------------------------------------------------------------------- | |
def _generate_claude(q_bad: str, n: int) -> list[str]:
    """Generate up to ``n`` reformulations of ``q_bad`` with the Claude API, in parallel.

    Makes ``n`` independent ``messages.create`` calls on a thread pool with one worker per
    call. The system prompt is marked ``cache_control: ephemeral`` for prompt caching.
    Each reply is cut to its first line. Failed calls are reported via ``pr`` and skipped,
    so fewer than ``n`` candidates may come back.

    Args:
        q_bad: Question to reformulate.
        n: Number of calls; ``n <= 0`` returns ``[]`` without contacting the API.

    Returns:
        Non-empty candidate strings, in completion order.
    """
    if n <= 0:
        # ThreadPoolExecutor raises ValueError for max_workers <= 0.
        return []

    client = _get_claude_client()
    system = [
        {
            "type": "text",
            "text": _GENERATION_SYSTEM,
            "cache_control": {"type": "ephemeral"},
        }
    ]
    user_msg = [{"role": "user", "content": f"<question>{q_bad}</question>"}]

    def _single_call(_):
        # Temperature is fixed at 1.0 here; INFERENCE_TEMPERATURE is not applied to this backend.
        msg = client.messages.create(
            model="claude-haiku-4-5",
            max_tokens=100,
            temperature=1.0,
            system=system,
            messages=user_msg,
        )
        return msg.content[0].text.strip().split("\n")[0].strip()

    candidates = []
    with ThreadPoolExecutor(max_workers=n) as ex:
        futures = [ex.submit(_single_call, i) for i in range(n)]
        for f in as_completed(futures):
            try:
                text = f.result()
                if text:
                    candidates.append(text)
            except Exception as exc:
                pr(f" [aviso] Erro na API Claude: {exc}")
    return candidates
| # --------------------------------------------------------------------------- | |
| # Interface principal | |
| # --------------------------------------------------------------------------- | |
def _score_all(candidates_text: list[str], q_bad: str) -> list[dict]:
    """Score every candidate against ``q_bad`` in parallel, one worker thread per candidate.

    NOTE(review): scoring is said to call tratabilidade() (Claude API, with an internal
    cache) inside src.ee.reward — not visible here; confirm there.

    Args:
        candidates_text: Candidate reformulations; may be empty.
        q_bad: Original question (reference for the proximity term).

    Returns:
        One score dict per candidate, in the same order as ``candidates_text``, so that
        downstream ``max()`` tie-breaking is deterministic (completion order was not).
        If scoring a candidate raises, its entry is an all-zero dict with the message
        under "error", the same shape ``_score_candidate`` uses for its own failures.
    """
    if not candidates_text:
        # ThreadPoolExecutor(max_workers=0) would raise ValueError.
        return []

    results = []
    with ThreadPoolExecutor(max_workers=len(candidates_text)) as ex:
        futures = [ex.submit(_score_candidate, c, q_bad) for c in candidates_text]
        for cand, fut in zip(candidates_text, futures):
            try:
                results.append(fut.result())
            except Exception as exc:
                pr(f" [aviso] Erro ao pontuar candidato: {exc}")
                results.append(
                    {
                        "text": cand,
                        "ee": 0.0,
                        "score": 0.0,
                        "prox": 0.0,
                        "resp": 0.0,
                        "tract": 0.0,
                        "nt": 0.0,
                        "error": str(exc),
                    }
                )
    return results
def gerar_candidatos(q_bad: str, n: int = N_CANDIDATES) -> list[str]:
    """Generate ``n`` candidate reformulations of ``q_bad`` with the configured backend.

    BACKEND routing:
        "claude" -> _generate_claude (needs ANTHROPIC_API_KEY)
        "local"  -> _generate_local (local PEFT / transformers model)
        anything else (auto, hf_inference, gguf) -> src.rl.generate_free.generate
    """
    direct = {"claude": _generate_claude, "local": _generate_local}.get(BACKEND)
    if direct is not None:
        return direct(q_bad, n)

    # Zero-cost generation module handles auto / hf_inference / gguf.
    from src.rl.generate_free import generate as _generate_free

    return _generate_free(q_bad, n)
def reformular(q_bad: str, n: int = N_CANDIDATES) -> InferenceResult:
    """Run the full best-of-N pipeline on one question and return the chosen reformulation.

    The EE of ``q_bad`` scored against itself is the baseline. Candidates from
    ``gerar_candidatos`` are scored with ``_score_all``. Stage 1 keeps candidates whose EE
    exceeds ``baseline + _EPSILON`` and picks the highest "score" among them. If none pass,
    the highest-EE candidate is returned as a fallback. If generation produced nothing,
    ``q_bad`` itself is returned with ``score_best=0.0``.

    Args:
        q_bad: Original (poorly tractable) research question.
        n: Number of candidates to generate.

    Returns:
        InferenceResult with the chosen text, its metrics and every scored candidate.
    """
    from src.ee.reward import compute_ee

    baseline_ee = compute_ee(q_bad, q_bad, _get_index()).ee

    generated = gerar_candidatos(q_bad, n)
    if not generated:
        # Nothing to choose from: hand back the original question unchanged.
        return InferenceResult(
            q_bad=q_bad,
            best=q_bad,
            ee_bad=baseline_ee,
            ee_best=baseline_ee,
            score_best=0.0,
            stage1_pass=False,
            candidates=[],
        )

    scored = _score_all(generated, q_bad)

    from src.ee.reward import _EPSILON

    threshold = baseline_ee + _EPSILON
    approved = [cand for cand in scored if cand["ee"] > threshold]
    stage1_pass = len(approved) > 0
    pool, metric = (approved, "score") if stage1_pass else (scored, "ee")
    best = max(pool, key=lambda cand: cand[metric])

    return InferenceResult(
        q_bad=q_bad,
        best=best["text"],
        ee_bad=baseline_ee,
        ee_best=best["ee"],
        score_best=best["score"],
        stage1_pass=stage1_pass,
        candidates=scored,
    )
| # --------------------------------------------------------------------------- | |
| # Traducao automatica pt-br (MarianMT local por padrao; Claude como fallback) | |
| # --------------------------------------------------------------------------- | |
def _translate_claude(text: str, direction: str) -> str:
    """Translate ``text`` with Claude (claude-haiku-4-5).

    ``direction`` selects the system prompt from ``_TRANSLATE_SYSTEMS``: 'pt_to_en' or
    'en_to_pt' (any other key raises KeyError). The system prompt is marked for ephemeral
    prompt caching. Returns the stripped text of the first content block of the reply.
    """
    client = _get_claude_client()
    cached_system = [
        {
            "type": "text",
            "text": _TRANSLATE_SYSTEMS[direction],
            "cache_control": {"type": "ephemeral"},
        }
    ]
    reply = client.messages.create(
        model="claude-haiku-4-5",
        max_tokens=150,
        system=cached_system,
        messages=[{"role": "user", "content": f"<question>{text}</question>"}],
    )
    first_block = reply.content[0]
    return first_block.text.strip()
def _translate(text: str, direction: str) -> str:
    """Translate between pt-br and English; ``direction`` is 'pt_to_en' or 'en_to_pt'.

    With TRANSLATE_BACKEND == "claude", Claude is used directly. Otherwise the local
    translator (src.ee.translate_local) is preferred. Claude is the fallback when the local
    translator reports itself unavailable or raises ImportError, provided ANTHROPIC_API_KEY
    is set.

    Raises:
        RuntimeError: local translator unavailable and no ANTHROPIC_API_KEY.
        ImportError: re-raised when it occurs and no ANTHROPIC_API_KEY is set.
    """
    if TRANSLATE_BACKEND == "claude":
        return _translate_claude(text, direction)

    # Default: local backend.
    try:
        from src.ee.translate_local import is_available
        from src.ee.translate_local import translate as _translate_local

        if is_available():
            return _translate_local(text, direction)
        # transformers missing: fall back to Claude if a key exists.
        if not os.getenv("ANTHROPIC_API_KEY"):
            raise RuntimeError(
                "Nenhum backend de traducao disponivel. "
                "Instale transformers+sentencepiece ou defina ANTHROPIC_API_KEY."
            )
        pr(" [translate] transformers nao disponivel, usando Claude API...")
        return _translate_claude(text, direction)
    except ImportError:
        if not os.getenv("ANTHROPIC_API_KEY"):
            raise
        pr(" [translate] translate_local nao encontrado, usando Claude API...")
        return _translate_claude(text, direction)
@dataclass
class PtBrResult:
    """Outcome of the Portuguese round-trip pipeline (pt -> en -> reformulate -> pt).

    Declared as a dataclass so the keyword constructor used by ``reformular_ptbr`` works
    and ``candidates`` gets a fresh list per instance via ``field(default_factory=list)``.
    """

    q_bad_pt: str  # original question in pt-br
    q_bad_en: str  # its English translation
    best_en: str  # best reformulation in English
    best_pt: str  # best reformulation translated back to pt-br
    ee_bad: float  # EE of the (English) input question
    ee_best: float  # EE of the chosen English reformulation
    score_best: float  # ranking score of the chosen reformulation
    stage1_pass: bool  # True when the choice passed the Stage-1 EE filter
    candidates: list[dict] = field(default_factory=list)  # every scored (English) candidate

    def summary(self) -> str:
        """Return a short multi-line report in both languages (text cut to 80 characters)."""
        lines = [
            f" Entrada : {self.q_bad_pt[:80]}",
            f" (ingles) : {self.q_bad_en[:80]}",
            f" Resultado: {self.best_pt[:80]}",
            f" (ingles) : {self.best_en[:80]}",
            f" EE : {self.ee_bad:.3f} -> {self.ee_best:.3f} "
            f"({'PASS' if self.stage1_pass else 'FALLBACK'})",
            f" Score : {self.score_best:.3f} (alpha={ALPHA})",
            f" N cand : {len(self.candidates)}",
        ]
        return "\n".join(lines)
def reformular_ptbr(q_bad_pt: str, n: int = N_CANDIDATES) -> PtBrResult:
    """Reformulate a Portuguese question by round-tripping through English.

    1. Translate ``q_bad_pt`` (pt-br) to English.
    2. Run ``reformular`` on the English text.
    3. Translate the chosen reformulation back to pt-br.

    Args:
        q_bad_pt: Question in Brazilian Portuguese.
        n: Number of candidates to generate.

    Returns:
        PtBrResult holding the input and output in both languages plus the EE metrics.
    """
    pr(" [1/3] Traduzindo entrada (pt -> en)...")
    question_en = _translate(q_bad_pt, "pt_to_en")
    pr(f" -> {question_en[:80]}")

    pr(" [2/3] Reformulando (pipeline EE)...")
    en_result = reformular(question_en, n)

    pr(" [3/3] Traduzindo resultado (en -> pt)...")
    answer_pt = _translate(en_result.best, "en_to_pt")
    pr(f" -> {answer_pt[:80]}")

    # Metrics are copied unchanged from the English run.
    carried = {
        name: getattr(en_result, name)
        for name in ("ee_bad", "ee_best", "score_best", "stage1_pass", "candidates")
    }
    return PtBrResult(
        q_bad_pt=q_bad_pt,
        q_bad_en=question_en,
        best_en=en_result.best,
        best_pt=answer_pt,
        **carried,
    )
| # --------------------------------------------------------------------------- | |
| # Demo interativa | |
| # --------------------------------------------------------------------------- | |
# English questions run by run_demo() (CLI: --demo).
DEMO_QUESTIONS = [
    "What is the meaning of life?",
    "What is consciousness?",
    "Does free will exist?",
    "What is the nature of time?",
    "Is there a theory of everything in physics?",
]

# Portuguese questions run through reformular_ptbr() by run_demo_pt() (CLI: --pt --demo).
DEMO_QUESTIONS_PT = [
    "O que e a consciencia?",
    "O livre-arbitrio existe?",
    "Qual e a natureza do tempo?",
    "O que causa o envelhecimento biologico?",
    "Como surgiu a vida na Terra?",
]
def run_demo() -> None:
    """Run reformular() on every DEMO_QUESTIONS entry and print summaries plus the top-3 candidates."""
    banner = "=" * 65
    pr(banner)
    pr(" Fase 3 — Demo Inferencia DPO (best-of-N + EE scoring)")
    pr(banner)

    from src.rl.generate_free import _detect_backend

    # For the zero-cost routes, also show which concrete backend "auto" etc. resolved to.
    if BACKEND in ("claude", "local"):
        backend_label = BACKEND
    else:
        backend_label = f"{BACKEND} → {_detect_backend()}"
    pr(f"\n Backend : {backend_label}")
    pr(f" N candid.: {N_CANDIDATES}")
    pr(f" Alpha : {ALPHA}")
    pr(f" Temp : {TEMPERATURE}")
    pr(f" Model dir: {MODEL_DIR.name}")
    pr("\n Carregando pipeline de scoring...")
    _get_index()  # warm the corpus index before the first question

    total = len(DEMO_QUESTIONS)
    for idx, question in enumerate(DEMO_QUESTIONS, start=1):
        pr(f"\n[{idx}/{total}] Reformulando...")
        outcome = reformular(question)
        pr(outcome.summary())
        top3 = sorted(outcome.candidates, key=lambda c: c["score"], reverse=True)[:3]
        pr("\n Top-3 candidatos:")
        for rank, cand in enumerate(top3, start=1):
            pr(f" {rank}. EE={cand['ee']:.3f} | Score={cand['score']:.3f} | {cand['text'][:70]}")

    pr(f"\n{banner}")
    pr(" Demo concluida.")
    pr(banner)
def run_demo_pt() -> None:
    """Run reformular_ptbr() on every DEMO_QUESTIONS_PT entry and print each summary."""
    rule = "=" * 65
    for header_line in (rule, " Fase 3 — Demo pt-br (traducao automatica + EE scoring)", rule):
        pr(header_line)

    from src.rl.generate_free import _detect_backend

    if BACKEND in ("claude", "local"):
        backend_label = BACKEND
    else:
        backend_label = f"{BACKEND} → {_detect_backend()}"
    pr(f"\n Backend : {backend_label}")
    pr(f" N candid.: {N_CANDIDATES}")
    # Label reflects the configured translator; _translate may still fall back to Claude at runtime.
    translator = "Claude Haiku" if TRANSLATE_BACKEND == "claude" else "MarianMT local"
    pr(f" Traducao : {translator} (pt <-> en)")
    _get_index()

    count = len(DEMO_QUESTIONS_PT)
    for pos, question in enumerate(DEMO_QUESTIONS_PT, start=1):
        pr(f"\n[{pos}/{count}] ----------------------------------------")
        pr("\n" + reformular_ptbr(question).summary())

    pr(f"\n{rule}")
    pr(" Demo pt-br concluida.")
    pr(rule)
def run_batch(path: str, ptbr: bool = False) -> None:
    """Reformulate every question in a text file and save the results as JSONL.

    The file is read as UTF-8; a leading BOM (as some Windows editors write) is
    tolerated. Each non-blank line is one question. Lines whose first non-blank
    character is ``#`` are comments. Results go next to the input as
    ``<name>.results.jsonl`` (``Path.with_suffix``), one JSON object per question.

    Args:
        path: Path to the questions file.
        ptbr: If True, questions are Portuguese and go through ``reformular_ptbr``;
            otherwise ``reformular`` is used.
    """
    src = Path(path)
    questions = []
    # "utf-8-sig" strips a BOM if present; otherwise it behaves exactly like "utf-8".
    for raw_line in src.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        # Test the stripped line so indented "# ..." comments are skipped as well.
        if line and not line.startswith("#"):
            questions.append(line)

    pr(f" Processando {len(questions)} perguntas de {path}...")
    records = []
    for i, q in enumerate(questions, 1):
        pr(f"\n[{i}/{len(questions)}]")
        if ptbr:
            res = reformular_ptbr(q)
            pr(res.summary())
            records.append(
                {
                    "q_bad_pt": res.q_bad_pt,
                    "q_bad_en": res.q_bad_en,
                    "best_en": res.best_en,
                    "best_pt": res.best_pt,
                    "ee_bad": round(res.ee_bad, 4),
                    "ee_best": round(res.ee_best, 4),
                    "score_best": round(res.score_best, 4),
                    "stage1_pass": res.stage1_pass,
                }
            )
        else:
            res = reformular(q)
            pr(res.summary())
            records.append(
                {
                    "q_bad": res.q_bad,
                    "best": res.best,
                    "ee_bad": round(res.ee_bad, 4),
                    "ee_best": round(res.ee_best, 4),
                    "score_best": round(res.score_best, 4),
                    "stage1_pass": res.stage1_pass,
                }
            )

    out_path = src.with_suffix(".results.jsonl")
    with out_path.open("w", encoding="utf-8") as fh:
        for record in records:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    pr(f"\n Resultados salvos em: {out_path}")
| # --------------------------------------------------------------------------- | |
| # Entrypoint CLI | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # CLI: [--pt] (--demo | --batch FILE | QUESTION...). Without arguments, print the module docstring.
    args = sys.argv[1:]
    ptbr = "--pt" in args
    if ptbr:
        args = [a for a in args if a != "--pt"]

    if "--demo" in args:
        if ptbr:
            run_demo_pt()
        else:
            run_demo()
    elif "--batch" in args:
        idx = args.index("--batch")
        if idx + 1 >= len(args):
            pr("Uso: python -m src.rl.inference [--pt] --batch caminho/para/arquivo.txt")
            sys.exit(1)
        run_batch(args[idx + 1], ptbr=ptbr)
    elif args and not args[0].startswith("--"):
        q = " ".join(args)
        pr(f"\n Entrada: {q}")
        if ptbr:
            result = reformular_ptbr(q)
            pr("\n" + result.summary())
            heading = "\n Top candidatos (decrescente por score):"
        else:
            result = reformular(q)
            pr(result.summary())
            heading = "\n Todos os candidatos (decrescente por score):"

        # Use the same Stage-1 margin as reformular() so PASS/FAIL labels agree with
        # stage1_pass (the previous hard-coded 0.05 could diverge from _EPSILON).
        from src.ee.reward import _EPSILON

        pr(heading)
        ranked = sorted(result.candidates, key=lambda s: s["score"], reverse=True)
        for j, c in enumerate(ranked, 1):
            status = "PASS" if c["ee"] > result.ee_bad + _EPSILON else "FAIL"
            pr(f" {j:2}. [{status}] EE={c['ee']:.3f} Sc={c['score']:.3f} | {c['text'][:75]}")
    else:
        pr(__doc__)
        sys.exit(0)