histlearn commited on
Commit
dd755d8
·
verified ·
1 Parent(s): 73a0f89

feat: adiciona quickstart.ipynb (habilita badges Colab/Kaggle que estavam vazios)

Browse files
Files changed (1) hide show
  1. examples/quickstart.ipynb +172 -0
examples/quickstart.ipynb ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Quickstart — Community Notes Reranker (PT-BR)\n",
8
+ "\n",
9
+ "Notebook mínimo de inferência. Baixa o modelo (base Qwen3-Reranker-0.6B + adapter LoRA), monta o template e devolve a probabilidade de utilidade para um par `(tweet, nota)`.\n",
10
+ "\n",
11
+ "- **Runtime sugerido:** GPU (T4 basta). CPU também funciona, mas ~5-10s por inferência.\n",
12
+ "- **Modos disponíveis:** *fold único* (rápido) e *ensemble dos 5 folds* (reproduz exatamente o número reportado no model card).\n"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "metadata": {},
18
+ "source": [
19
+ "# Instala dependencias se necessario\n",
20
+ "import sys, subprocess, importlib\n",
21
+ "for mod, pkg in [(\"torch\",\"torch\"), (\"transformers\",\"transformers\"),\n",
22
+ " (\"peft\",\"peft\"), (\"huggingface_hub\",\"huggingface_hub\")]:\n",
23
+ " try: importlib.import_module(mod)\n",
24
+ " except Exception:\n",
25
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", pkg], check=True)\n"
26
+ ],
27
+ "outputs": [],
28
+ "execution_count": null
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "metadata": {},
33
+ "source": [
34
+ "## Carrega base + um fold (modo rápido)"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "metadata": {},
40
+ "source": [
41
+ "import json, torch\n",
42
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
43
+ "from peft import PeftModel\n",
44
+ "from huggingface_hub import snapshot_download\n",
45
+ "\n",
46
+ "REPO = \"histlearn/community-notes-reranker-ptbr\"\n",
47
+ "path = snapshot_download(REPO, allow_patterns=[\"manifesto.json\", \"adapter_fold_1/*\"])\n",
48
+ "m = json.load(open(f\"{path}/manifesto.json\"))\n",
49
+ "\n",
50
+ "tok = AutoTokenizer.from_pretrained(m[\"base_model\"], padding_side=\"left\")\n",
51
+ "dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
52
+ "base = AutoModelForCausalLM.from_pretrained(m[\"base_model\"], torch_dtype=dtype)\n",
53
+ "model = PeftModel.from_pretrained(base, f\"{path}/adapter_fold_1\")\n",
54
+ "if torch.cuda.is_available(): model.cuda()\n",
55
+ "model.eval()\n",
56
+ "print(f\"Modelo pronto em: {model.device}\")\n"
57
+ ],
58
+ "outputs": [],
59
+ "execution_count": null
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {},
64
+ "source": [
65
+ "## Função de inferência"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "metadata": {},
71
+ "source": [
72
+ "def util_prob(tweet: str, nota: str) -> float:\n",
73
+ " \"\"\"Probabilidade de que a comunidade marcaria a nota como util.\n",
74
+ " Threshold otimo medido sob CV (Platt scaling) = 0.38.\"\"\"\n",
75
+ " text = (m[\"prompt_prefixo\"] + \"<Instruct>: \" + m[\"instrucao\"] +\n",
76
+ " \"\\n<Query>: \" + tweet + \"\\n<Document>: \" + nota + m[\"prompt_sufixo\"])\n",
77
+ " enc = tok(text, return_tensors=\"pt\", truncation=True, max_length=m[\"max_length\"]).to(model.device)\n",
78
+ " with torch.no_grad():\n",
79
+ " logits = model(**enc).logits[:, -1, :]\n",
80
+ " return float(torch.sigmoid(logits[:, m[\"id_yes\"]] - logits[:, m[\"id_no\"]]).item())\n"
81
+ ],
82
+ "outputs": [],
83
+ "execution_count": null
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "## Exemplo"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "metadata": {},
95
+ "source": [
96
+ "tweet = (\"Lula anunciou que o salario minimo subira para R$ 5 mil em 2026.\")\n",
97
+ "nota = (\"E falso. Em 12/12/2024, o presidente Lula anunciou que o salario minimo \"\n",
98
+ " \"subiria de R$ 1.412 para R$ 1.518 a partir de janeiro de 2025, segundo a \"\n",
99
+ " \"Agencia Brasil (https://agenciabrasil.ebc.com.br). Nao ha qualquer \"\n",
100
+ " \"anuncio oficial de valor proximo a R$ 5 mil.\")\n",
101
+ "\n",
102
+ "p = util_prob(tweet, nota)\n",
103
+ "print(f\"P(util) = {p:.4f}\")\n",
104
+ "print(f\"Classificacao (threshold 0.38): {'UTIL' if p >= 0.38 else 'NAO-UTIL'}\")\n"
105
+ ],
106
+ "outputs": [],
107
+ "execution_count": null
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "## Ensemble dos 5 folds (reproduz o número do model card)\n",
114
+ "\n",
115
+ "Para resultados estatisticamente comparáveis aos reportados (macro-F1 0.7920), use a média das probabilidades dos 5 adapters."
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "metadata": {},
121
+ "source": [
122
+ "from huggingface_hub import snapshot_download\n",
123
+ "path_full = snapshot_download(REPO, allow_patterns=[\"manifesto.json\", \"adapter_fold_*/*\"])\n",
124
+ "\n",
125
+ "def util_prob_ensemble(tweet: str, nota: str) -> float:\n",
126
+ " probs = []\n",
127
+ " for k in range(1, 6):\n",
128
+ " m_k = PeftModel.from_pretrained(base, f\"{path_full}/adapter_fold_{k}\")\n",
129
+ " m_k.eval()\n",
130
+ " text = (m[\"prompt_prefixo\"] + \"<Instruct>: \" + m[\"instrucao\"] +\n",
131
+ " \"\\n<Query>: \" + tweet + \"\\n<Document>: \" + nota + m[\"prompt_sufixo\"])\n",
132
+ " enc = tok(text, return_tensors=\"pt\", truncation=True, max_length=m[\"max_length\"]).to(m_k.device)\n",
133
+ " with torch.no_grad():\n",
134
+ " l = m_k(**enc).logits[:, -1, :]\n",
135
+ " probs.append(float(torch.sigmoid(l[:, m[\"id_yes\"]] - l[:, m[\"id_no\"]]).item()))\n",
136
+ " # Libera o adapter k antes de carregar o k+1\n",
137
+ " if hasattr(m_k, \"unload\"):\n",
138
+ " m_k.unload()\n",
139
+ " return sum(probs) / 5\n",
140
+ "\n",
141
+ "p_ens = util_prob_ensemble(tweet, nota)\n",
142
+ "print(f\"P(util) ensemble = {p_ens:.4f}\")\n"
143
+ ],
144
+ "outputs": [],
145
+ "execution_count": null
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "metadata": {},
150
+ "source": [
151
+ "## Próximos passos\n",
152
+ "\n",
153
+ "- Para documentação completa, métricas e contexto do projeto, ver o [model card](https://huggingface.co/histlearn/community-notes-reranker-ptbr).\n",
154
+ "- Para reproduzir o treino fold-a-fold ou regenerar os artefatos, ver o notebook `02_pipeline_experimento.ipynb` no [Space do projeto](https://huggingface.co/spaces/histlearn/communitynotesbr).\n",
155
+ "- Para o dataset bruto, ver [`histlearn/notas-comunidade-ptbr`](https://huggingface.co/datasets/histlearn/notas-comunidade-ptbr).\n"
156
+ ]
157
+ }
158
+ ],
159
+ "metadata": {
160
+ "kernelspec": {
161
+ "display_name": "Python 3",
162
+ "language": "python",
163
+ "name": "python3"
164
+ },
165
+ "language_info": {
166
+ "name": "python",
167
+ "version": "3.11"
168
+ }
169
+ },
170
+ "nbformat": 4,
171
+ "nbformat_minor": 5
172
+ }