histlearn's picture
Update app.py
0704c04 verified
"""Gradio app — endpoint de utilidade para community notes em PT-BR.
Expõe:
- UI web com três abas: Prever / Explicar / Sobre.
- API HTTP em /gradio_api/call/predict e /gradio_api/call/explain (gerada
automaticamente pelo Gradio a partir dos api_name).
Para clientes Python, use gradio_client:
from gradio_client import Client
c = Client("<user>/<space>", hf_token="hf_...")
score = c.predict("texto da nota...", api_name="/predict")
"""
from __future__ import annotations
import base64
import html
import logging
import os
import traceback
from pathlib import Path
import gradio as gr
from config import (
CONFIDENCE_BOUNDS_ALTA,
CONFIDENCE_BOUNDS_MEDIA,
THRESHOLD_UTIL,
)
from inference import DEVICE, explain_occlusion, predict_one, warmup
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
log = logging.getLogger("app")
# ---------------------------------------------------------------------------
# CSS do projeto
# ---------------------------------------------------------------------------
APP_DIR = Path(__file__).resolve().parent
STYLE_PATH = APP_DIR / "styles.css"
CUSTOM_CSS = STYLE_PATH.read_text(encoding="utf-8") if STYLE_PATH.exists() else ""
LOGO_PATH = APP_DIR / "logo_notas_svg_suavizado_transparente.svg"
_logo_src = ""
if LOGO_PATH.exists():
_logo_b64 = base64.b64encode(LOGO_PATH.read_bytes()).decode()
_logo_src = f"data:image/svg+xml;base64,{_logo_b64}"
# ---------------------------------------------------------------------------
# Warm-up agressivo — queremos que o primeiro request não pague cold-start
# ---------------------------------------------------------------------------
MODEL_READY: bool
MODEL_ERROR: str | None
try:
warmup()
MODEL_READY = True
MODEL_ERROR = None
log.info("Modelo carregado no startup. Device=%s", DEVICE)
except Exception as exc: # noqa: BLE001 — queremos pegar qualquer falha de carregamento
MODEL_READY = False
MODEL_ERROR = f"{type(exc).__name__}: {exc}"
log.error("Falha ao carregar modelo no startup:\n%s", traceback.format_exc())
# ---------------------------------------------------------------------------
# Helpers de apresentação
# ---------------------------------------------------------------------------
def _confidence_band(p: float) -> str:
lo_a, hi_a = CONFIDENCE_BOUNDS_ALTA
lo_m, hi_m = CONFIDENCE_BOUNDS_MEDIA
if p <= lo_a or p >= hi_a:
return "Alta"
if p <= lo_m or p >= hi_m:
return "Média"
return "Baixa"
def _label(p: float) -> str:
return "Útil" if p >= THRESHOLD_UTIL else "Não-útil"
def _score_card_html(p: float) -> str:
"""Card principal do resultado — usando classes CSS do projeto."""
lbl = _label(p)
band = _confidence_band(p)
lbl_class = "notinhas-badge-util" if lbl == "Útil" else "notinhas-badge-nao-util"
if band == "Alta":
band_class = lbl_class
elif band == "Média":
band_class = "notinhas-badge-media"
else:
band_class = "notinhas-badge-baixa"
return f"""
<div class="notinhas-card">
<div style="display:flex;justify-content:space-between;align-items:center;gap:12px;flex-wrap:wrap;">
<div style="display:flex;gap:8px;flex-wrap:wrap;">
<span class="notinhas-badge {lbl_class}">{lbl}</span>
<span class="notinhas-badge {band_class}">Confiança {band}</span>
</div>
<div style="text-align:right;">
<div class="notinhas-score-label">P(útil)</div>
<div class="notinhas-score-value">{p:.4f}</div>
</div>
</div>
</div>
"""
def _contrib_color(v: float, v_max: float) -> str:
if v_max <= 0:
return "transparent"
intensity = min(1.0, abs(v) / v_max)
alpha = 0.15 + 0.65 * intensity # 0.15 .. 0.80
if v > 0:
return f"rgba(95, 168, 143, {alpha:.3f})" # verde (PALETA['util'] do notebook)
return f"rgba(224, 123, 107, {alpha:.3f})" # coral (PALETA['nao_util'])
def _highlighted_text_html(tokens: list[str], contribs: list[float]) -> str:
if not tokens:
return "<em>(sem palavras para destacar)</em>"
v_max = max((abs(c) for c in contribs), default=1e-9) or 1e-9
spans = []
for tok, c in zip(tokens, contribs):
bg = _contrib_color(c, v_max)
spans.append(
f'<span style="background:{bg};padding:2px 4px;border-radius:4px;'
f'margin:0 1px;" title="Δ={c:+.6f}">{html.escape(tok)}</span>'
)
return (
'<div style="font-size:15px;line-height:2;color:#212529;'
'font-family:system-ui, -apple-system, sans-serif;padding:4px;">'
+ " ".join(spans)
+ "</div>"
)
def _top_tokens_table_html(
tokens: list[str], contribs: list[float], k: int = 5
) -> str:
pairs = list(zip(tokens, contribs))
pos = sorted([p for p in pairs if p[1] > 0], key=lambda x: -x[1])[:k]
neg = sorted([p for p in pairs if p[1] < 0], key=lambda x: x[1])[:k]
def _row(tok: str, v: float, side: str) -> str:
color = "#1b4332" if side == "pos" else "#9d0208"
return (
f'<tr><td style="padding:5px 8px;color:{color};">'
f"{html.escape(tok)}</td>"
f'<td style="padding:5px 8px;text-align:right;color:{color};'
f'font-variant-numeric:tabular-nums;">{v:+.6f}</td></tr>'
)
empty = '<tr><td colspan="2" style="padding:6px;color:#9aa1aa;"><em>—</em></td></tr>'
pos_rows = "".join(_row(t, v, "pos") for t, v in pos) or empty
neg_rows = "".join(_row(t, v, "neg") for t, v in neg) or empty
all_same_side = (not neg and pos) or (not pos and neg)
if not neg and pos:
side_warning = (
'<p style="font-size:12px;color:#6c757d;margin:10px 4px 0 4px;line-height:1.5;">'
'⚠️ <strong>Nenhuma palavra puxando para não-útil identificada.</strong> '
'O método leave-one-out compara a frase completa com cada ablação de uma palavra. '
'Quando todas as contribuições são positivas, a frase completa pontua '
'marginalmente <em>mais</em> do que qualquer subconjunto — comum em textos '
'muito curtos ou frases com sentido idiomático. '
'O texto permanece Não-útil porque P(útil) está longe do limiar (0.5); '
'o que o define é a <em>ausência</em> de características úteis '
'(fontes, dados, neutralidade), não palavras negativas específicas.'
'</p>'
)
elif not pos and neg:
side_warning = (
'<p style="font-size:12px;color:#6c757d;margin:10px 4px 0 4px;line-height:1.5;">'
'⚠️ <strong>Nenhuma palavra puxando para útil identificada.</strong> '
'Todas as palavras reduzem marginalmente P(útil) quando presentes.'
'</p>'
)
else:
side_warning = ""
return f"""
<div style="display:grid;grid-template-columns:1fr 1fr;gap:14px;margin-top:12px;
font-family:system-ui, -apple-system, sans-serif;">
<div style="background:#fcfcfd;border:1px solid #eef2f7;border-radius:12px;padding:12px;">
<div style="font-size:13px;font-weight:700;color:#1b4332;margin-bottom:6px;">
Empurram para útil
</div>
<table style="width:100%;border-collapse:collapse;font-size:13px;">{pos_rows}</table>
</div>
<div style="background:#fcfcfd;border:1px solid #eef2f7;border-radius:12px;padding:12px;">
<div style="font-size:13px;font-weight:700;color:#9d0208;margin-bottom:6px;">
Empurram para não-útil
</div>
<table style="width:100%;border-collapse:collapse;font-size:13px;">{neg_rows}</table>
</div>
</div>
""" + side_warning
# ---------------------------------------------------------------------------
# Handlers — retornam HTML para a UI + JSON para a API
# ---------------------------------------------------------------------------
def handle_predict(text: str):
text = (text or "").strip()
if not text:
return "<em>Forneça um texto.</em>", {"error": "empty_input"}
if not MODEL_READY:
err = MODEL_ERROR or "modelo indisponível"
return (
f"<em>Modelo indisponível: {html.escape(err)}</em>",
{"error": "model_unavailable", "detail": err},
)
p = predict_one(text)
return (
_score_card_html(p),
{
"proba_util": p,
"label": _label(p),
"confidence_band": _confidence_band(p),
},
)
def handle_explain(text: str):
text = (text or "").strip()
if not text:
return "<em>Forneça um texto.</em>", "", "", {"error": "empty_input"}
if not MODEL_READY:
err = MODEL_ERROR or "modelo indisponível"
return (
f"<em>Modelo indisponível: {html.escape(err)}</em>",
"",
"",
{"error": "model_unavailable", "detail": err},
)
result = explain_occlusion(text)
p = result["proba_full"]
tokens = result["tokens"]
contribs = result["contributions"]
return (
_score_card_html(p),
_highlighted_text_html(tokens, contribs),
_top_tokens_table_html(tokens, contribs),
{
"proba_util": p,
"label": _label(p),
"confidence_band": _confidence_band(p),
"tokens": tokens,
"contributions": contribs,
},
)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
EXAMPLE_UTIL = (
"Segundo dados oficiais do Ministério da Saúde, o número citado no tweet é falso. "
"A fonte correta pode ser conferida no link: https://www.gov.br/saude/..."
)
EXAMPLE_NAO = "Essa nota é claramente desnecessária, é opinião pessoal do autor."
_APP_TITLE = "Notinhas — endpoint de utilidade"
INTRO_MD = """
Classificador de utilidade para **community notes em português**, baseado em
**bge-m3 (568M params) + LoRA + cabeça linear** (calibrado via Platt scaling).
- **Prever** — score + label + faixa de confiança.
- **Explicar** — o mesmo + contribuição de cada palavra via leave-one-out.
- **Sobre** — detalhes técnicos e limitações.
"""
with gr.Blocks(
title="Notinhas — endpoint de utilidade (FT-Solo)",
theme=gr.themes.Base(),
css=CUSTOM_CSS,
) as demo:
if _logo_src:
gr.HTML(
'<div style="display:flex;align-items:center;gap:20px;padding:12px 0 4px;">'
f'<img src="{_logo_src}" style="height:80px;width:auto;flex-shrink:0;" alt="Notinhas">'
f'<h1 style="margin:0;font-size:1.75em;font-weight:700;line-height:1.2;">{_APP_TITLE}</h1>'
'</div>'
)
else:
gr.Markdown(f"# {_APP_TITLE}")
gr.Markdown(INTRO_MD)
if not MODEL_READY:
gr.Markdown(
f"""
> ⚠️ **Modelo não carregou.** Detalhe: `{html.escape(MODEL_ERROR or '')}`
>
> Verifique que `artifacts/fold_01_adapter/` e `artifacts/fold_01_head.pt` estão presentes
> no repositório do Space. Se o modelo base exigir autenticação, configure `HF_TOKEN` em
> **Settings → Variables and secrets**.
"""
)
with gr.Tab("Prever"):
with gr.Row():
with gr.Column(scale=2):
inp_p = gr.Textbox(
label="Texto da nota",
placeholder="Cole aqui o texto em português...",
lines=7,
max_lines=25,
)
btn_p = gr.Button("Prever", variant="primary")
gr.Examples(examples=[[EXAMPLE_UTIL], [EXAMPLE_NAO]], inputs=[inp_p])
with gr.Column(scale=3):
out_card_p = gr.HTML(label="Resultado")
out_json_p = gr.JSON(label="Resposta da API")
btn_p.click(
handle_predict,
inputs=[inp_p],
outputs=[out_card_p, out_json_p],
api_name="predict",
)
with gr.Tab("Explicar"):
with gr.Row():
with gr.Column(scale=2):
inp_e = gr.Textbox(
label="Texto da nota",
placeholder="Cole aqui o texto em português...",
lines=7,
max_lines=25,
)
btn_e = gr.Button("Explicar", variant="primary")
gr.Examples(examples=[[EXAMPLE_UTIL], [EXAMPLE_NAO]], inputs=[inp_e])
with gr.Column(scale=3):
out_card_e = gr.HTML(label="Resultado")
out_hl = gr.HTML(label="Contribuição por palavra")
out_tbl = gr.HTML(label="Top tokens por lado")
out_json_e = gr.JSON(label="Resposta da API")
btn_e.click(
handle_explain,
inputs=[inp_e],
outputs=[out_card_e, out_hl, out_tbl, out_json_e],
api_name="explain",
)
with gr.Tab("Sobre"):
gr.Markdown(
f"""
### Detalhes técnicos
- **Modelo base**: `BAAI/bge-m3` (embedding, 1.024 dims, mean pooling, 568M params).
- **Adaptação**: LoRA treinado com alvo `label_binary_strict` (recorte A do projeto).
- **Fold servido**: `fold_04` (melhor fold segundo o manifesto do pipeline).
- **Cabeça**: `nn.Linear(1024, 1)` → sigmoid.
- **Calibração**: Platt scaling pós-treino — `P_calib = sigmoid(CALIB_A × logit(P_raw) + CALIB_B)`. Com os defaults `CALIB_A=1.0, CALIB_B=0.0` equivale a identidade; ajuste em `config.py` com base num conjunto de validação.
- **Prompt de instrução**: nenhum — texto cru (bge-m3 não usa prefix de instrução).
- **max_length**: 256 tokens.
- **Dispositivo atual**: `{DEVICE}`.
### Método de explicação
A aba **Explicar** usa **occlusion word-level** (leave-one-out): para cada palavra
separada por espaço, calculamos `Δ = P(texto completo) − P(texto sem a palavra)`.
- Δ positivo ⇒ palavra puxando para **útil** (verde).
- Δ negativo ⇒ palavra puxando para **não-útil** (coral).
É uma aproximação rápida do SHAP Partition usado no notebook de explicabilidade
(~1–2 s vs ~12–15 s em GPU), com resultados visualmente comparáveis para notas curtas.
### Limitações
- O rótulo `helpful` mede **aceitabilidade bipartidária**, não qualidade editorial.
A galeria curada do notebook mostra casos onde vizinhos semânticos idênticos
recebem rótulos opostos por razões políticas.
- Textos são truncados em 256 tokens.
- Este endpoint serve um único fold. Para produção com ganho marginal de robustez,
subir para ensemble dos 5 folds (média de probabilidades).
"""
)
if __name__ == "__main__":
demo.queue(default_concurrency_limit=1).launch(
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860)),
show_api=True,
)