histlearn commited on
Commit
233b2df
·
verified ·
1 Parent(s): c13089e

feat: migrar para bge-m3 (568M params, mean pooling, fold 04→fold_01)

Browse files
README.md CHANGED
@@ -10,7 +10,7 @@ app_file: app.py
10
  pinned: false
11
  short_description: Classificador de utilidade para community notes em PT-BR.
12
  models:
13
- - Qwen/Qwen3-Embedding-4B
14
  ---
15
 
16
  # Notinhas — endpoint de utilidade (FT-Solo)
@@ -20,8 +20,8 @@ note* em português, devolve a probabilidade de ela ser classificada como "útil
20
  (`label_binary_strict = 1`), junto com uma leitura opcional da contribuição de
21
  cada palavra.
22
 
23
- Arquitetura: **Qwen3-Embedding-4B + LoRA + cabeça linear**, idêntica ao
24
- `predict_from_text` do notebook `explicabilidade_qwen4b_redesign` em modo fiel
25
  (fold 01).
26
 
27
  ## Estrutura do repositório
@@ -38,7 +38,7 @@ Arquitetura: **Qwen3-Embedding-4B + LoRA + cabeça linear**, idêntica ao
38
  ├── fold_01_adapter/ # Pasta do adapter LoRA
39
  │ ├── adapter_config.json
40
  │ └── adapter_model.safetensors
41
- └── fold_01_head.pt # State dict do nn.Linear(2560, 1)
42
  ```
43
 
44
  ## Setup — do zero até o Space no ar
@@ -51,13 +51,13 @@ Na UI do Hugging Face:
51
  2. SDK: **Gradio**.
52
  3. Hardware: **T4 small** (recomendado — caber na memória em bf16 e inferência
53
  em ~0,5 s). **A10G small** dá latência ainda menor. **ZeroGPU** funciona mas
54
- com cold-start mais longo. **CPU** roda, porém cada inferência leva 20–40 s.
55
  4. Visibility: **Private**.
56
 
57
  ### 2. Popular `artifacts/`
58
 
59
  Os pesos vêm do pipeline do projeto. O zip base do Drive (`artefatos_projeto.zip`)
60
- traz as pastas `qwen4b_adapters/` e `qwen4b_heads/`. Rode localmente:
61
 
62
  ```bash
63
  pip install gdown
@@ -91,7 +91,7 @@ git commit -m "feat: endpoint inicial FT-Solo"
91
  git push
92
  ```
93
 
94
- O adapter do Qwen3-Embedding-4B em LoRA costuma ficar entre **20 e 80 MB**
95
  (dependendo do rank e dos módulos-alvo). A cabeça é ~20 KB. Tudo cabe
96
  confortavelmente sem apertar quota.
97
 
@@ -99,7 +99,7 @@ confortavelmente sem apertar quota.
99
 
100
  Em **Settings → Variables and secrets**:
101
 
102
- - `HF_TOKEN` — só necessário se `Qwen/Qwen3-Embedding-4B` virar gated no futuro.
103
  Hoje o modelo é público, então você pode ignorar.
104
 
105
  ### 5. Primeiro boot
@@ -107,7 +107,7 @@ Em **Settings → Variables and secrets**:
107
  Na primeira inicialização o Space:
108
 
109
  1. Instala `requirements.txt` (~1 min).
110
- 2. Baixa `Qwen/Qwen3-Embedding-4B` da HF (~8 GB, ~23 min).
111
  3. Carrega adapter + head (~5 s).
112
  4. Fica pronto — e o warm-up do modelo já aconteceu, o primeiro request é rápido.
113
 
 
10
  pinned: false
11
  short_description: Classificador de utilidade para community notes em PT-BR.
12
  models:
13
+ - BAAI/bge-m3
14
  ---
15
 
16
  # Notinhas — endpoint de utilidade (FT-Solo)
 
20
  (`label_binary_strict = 1`), junto com uma leitura opcional da contribuição de
21
  cada palavra.
22
 
23
+ Arquitetura: **bge-m3 (568M params) + LoRA + cabeça linear**, idêntica ao
24
+ `predict_from_text` do notebook FT-Solo em modo fiel
25
  (fold 01).
26
 
27
  ## Estrutura do repositório
 
38
  ├── fold_01_adapter/ # Pasta do adapter LoRA
39
  │ ├── adapter_config.json
40
  │ └── adapter_model.safetensors
41
+ └── fold_01_head.pt # State dict do nn.Linear(1024, 1)
42
  ```
43
 
44
  ## Setup — do zero até o Space no ar
 
51
  2. SDK: **Gradio**.
52
  3. Hardware: **T4 small** (recomendado — caber na memória em bf16 e inferência
53
  em ~0,5 s). **A10G small** dá latência ainda menor. **ZeroGPU** funciona mas
54
+ com cold-start mais longo. **CPU** roda inferência ~4–8 s com bge-m3 (vs 20–40 s do Qwen3).
55
  4. Visibility: **Private**.
56
 
57
  ### 2. Popular `artifacts/`
58
 
59
  Os pesos vêm do pipeline do projeto. O zip base do Drive (`artefatos_projeto.zip`)
60
+ traz as pastas com adapters e heads bge-m3. Rode localmente:
61
 
62
  ```bash
63
  pip install gdown
 
91
  git push
92
  ```
93
 
94
+ O adapter bge-m3 em LoRA costuma ficar entre **20 e 60 MB**
95
  (dependendo do rank e dos módulos-alvo). A cabeça é ~20 KB. Tudo cabe
96
  confortavelmente sem apertar quota.
97
 
 
99
 
100
  Em **Settings → Variables and secrets**:
101
 
102
+ - `HF_TOKEN` — só necessário se `BAAI/bge-m3` virar gated no futuro.
103
  Hoje o modelo é público, então você pode ignorar.
104
 
105
  ### 5. Primeiro boot
 
107
  Na primeira inicialização o Space:
108
 
109
  1. Instala `requirements.txt` (~1 min).
110
+ 2. Baixa `BAAI/bge-m3` da HF (~2 GB, ~3060 s).
111
  3. Carrega adapter + head (~5 s).
112
  4. Fica pronto — e o warm-up do modelo já aconteceu, o primeiro request é rápido.
113
 
app.py CHANGED
@@ -1,390 +1,386 @@
1
- """Gradio app — endpoint de utilidade para community notes em PT-BR.
2
-
3
- Expõe:
4
- - UI web com três abas: Prever / Explicar / Sobre.
5
- - API HTTP em /gradio_api/call/predict e /gradio_api/call/explain (gerada
6
- automaticamente pelo Gradio a partir dos api_name).
7
-
8
- Para clientes Python, use gradio_client:
9
-
10
- from gradio_client import Client
11
- c = Client("<user>/<space>", hf_token="hf_...")
12
- score = c.predict("texto da nota...", api_name="/predict")
13
- """
14
- from __future__ import annotations
15
-
16
- import html
17
- import logging
18
- import os
19
- import traceback
20
- from pathlib import Path
21
-
22
- import gradio as gr
23
-
24
- from config import (
25
- CONFIDENCE_BOUNDS_ALTA,
26
- CONFIDENCE_BOUNDS_MEDIA,
27
- THRESHOLD_UTIL,
28
- )
29
- from inference import DEVICE, explain_occlusion, predict_one, warmup
30
-
31
- logging.basicConfig(
32
- level=logging.INFO,
33
- format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
34
- )
35
- log = logging.getLogger("app")
36
-
37
- # ---------------------------------------------------------------------------
38
- # CSS do projeto
39
- # ---------------------------------------------------------------------------
40
- APP_DIR = Path(__file__).resolve().parent
41
- STYLE_PATH = APP_DIR / "styles.css"
42
- CUSTOM_CSS = STYLE_PATH.read_text(encoding="utf-8") if STYLE_PATH.exists() else ""
43
-
44
-
45
-
46
- # ---------------------------------------------------------------------------
47
- # Warm-up agressivo — queremos que o primeiro request não pague cold-start
48
- # ---------------------------------------------------------------------------
49
- MODEL_READY: bool
50
- MODEL_ERROR: str | None
51
-
52
- try:
53
- warmup()
54
- MODEL_READY = True
55
- MODEL_ERROR = None
56
- log.info("Modelo carregado no startup. Device=%s", DEVICE)
57
- except Exception as exc: # noqa: BLE001 — queremos pegar qualquer falha de carregamento
58
- MODEL_READY = False
59
- MODEL_ERROR = f"{type(exc).__name__}: {exc}"
60
- log.error("Falha ao carregar modelo no startup:\n%s", traceback.format_exc())
61
-
62
-
63
- # ---------------------------------------------------------------------------
64
- # Helpers de apresentação
65
- # ---------------------------------------------------------------------------
66
- def _confidence_band(p: float) -> str:
67
- lo_a, hi_a = CONFIDENCE_BOUNDS_ALTA
68
- lo_m, hi_m = CONFIDENCE_BOUNDS_MEDIA
69
- if p <= lo_a or p >= hi_a:
70
- return "Alta"
71
- if p <= lo_m or p >= hi_m:
72
- return "Média"
73
- return "Baixa"
74
-
75
-
76
- def _label(p: float) -> str:
77
- return "Útil" if p >= THRESHOLD_UTIL else "Não-útil"
78
-
79
-
80
- def _score_card_html(p: float) -> str:
81
- """Card principal do resultado — usando classes CSS do projeto."""
82
- lbl = _label(p)
83
- band = _confidence_band(p)
84
-
85
- lbl_class = "notinhas-badge-util" if lbl == "Útil" else "notinhas-badge-nao-util"
86
-
87
- if band == "Alta":
88
- band_class = lbl_class
89
- elif band == "Média":
90
- band_class = "notinhas-badge-media"
91
- else:
92
- band_class = "notinhas-badge-baixa"
93
-
94
- return f"""
95
- <div class="notinhas-card">
96
- <div style="display:flex;justify-content:space-between;align-items:center;gap:12px;flex-wrap:wrap;">
97
- <div style="display:flex;gap:8px;flex-wrap:wrap;">
98
- <span class="notinhas-badge {lbl_class}">{lbl}</span>
99
- <span class="notinhas-badge {band_class}">Confiança {band}</span>
100
- </div>
101
-
102
- <div style="text-align:right;">
103
- <div class="notinhas-score-label">P(útil)</div>
104
- <div class="notinhas-score-value">{p:.4f}</div>
105
- </div>
106
- </div>
107
- </div>
108
- """
109
-
110
-
111
- def _contrib_color(v: float, v_max: float) -> str:
112
- if v_max <= 0:
113
- return "transparent"
114
- intensity = min(1.0, abs(v) / v_max)
115
- alpha = 0.15 + 0.65 * intensity # 0.15 .. 0.80
116
- if v > 0:
117
- return f"rgba(95, 168, 143, {alpha:.3f})" # verde (PALETA['util'] do notebook)
118
- return f"rgba(224, 123, 107, {alpha:.3f})" # coral (PALETA['nao_util'])
119
-
120
-
121
- def _highlighted_text_html(tokens: list[str], contribs: list[float]) -> str:
122
- if not tokens:
123
- return "<em>(sem palavras para destacar)</em>"
124
- v_max = max((abs(c) for c in contribs), default=1e-9) or 1e-9
125
- spans = []
126
- for tok, c in zip(tokens, contribs):
127
- bg = _contrib_color(c, v_max)
128
- spans.append(
129
- f'<span style="background:{bg};padding:2px 4px;border-radius:4px;'
130
- f'margin:0 1px;" title="Δ={c:+.6f}">{html.escape(tok)}</span>'
131
- )
132
- return (
133
- '<div style="font-size:15px;line-height:2;color:#212529;'
134
- 'font-family:system-ui, -apple-system, sans-serif;padding:4px;">'
135
- + " ".join(spans)
136
- + "</div>"
137
- )
138
-
139
-
140
- def _top_tokens_table_html(
141
- tokens: list[str], contribs: list[float], k: int = 5
142
- ) -> str:
143
- pairs = list(zip(tokens, contribs))
144
- pos = sorted([p for p in pairs if p[1] > 0], key=lambda x: -x[1])[:k]
145
- neg = sorted([p for p in pairs if p[1] < 0], key=lambda x: x[1])[:k]
146
-
147
- def _row(tok: str, v: float, side: str) -> str:
148
- color = "#1b4332" if side == "pos" else "#9d0208"
149
- return (
150
- f'<tr><td style="padding:5px 8px;color:{color};">'
151
- f"{html.escape(tok)}</td>"
152
- f'<td style="padding:5px 8px;text-align:right;color:{color};'
153
- f'font-variant-numeric:tabular-nums;">{v:+.6f}</td></tr>'
154
- )
155
-
156
- empty = '<tr><td colspan="2" style="padding:6px;color:#9aa1aa;"><em>—</em></td></tr>'
157
- pos_rows = "".join(_row(t, v, "pos") for t, v in pos) or empty
158
- neg_rows = "".join(_row(t, v, "neg") for t, v in neg) or empty
159
-
160
- all_same_side = (not neg and pos) or (not pos and neg)
161
- if not neg and pos:
162
- side_warning = (
163
- '<p style="font-size:12px;color:#6c757d;margin:10px 4px 0 4px;line-height:1.5;">'
164
- '⚠️ <strong>Nenhuma palavra puxando para não-útil identificada.</strong> '
165
- 'O método leave-one-out compara a frase completa com cada ablação de uma palavra. '
166
- 'Quando todas as contribuições são positivas, a frase completa pontua '
167
- 'marginalmente <em>mais</em> do que qualquer subconjunto — comum em textos '
168
- 'muito curtos ou frases com sentido idiomático. '
169
- 'O texto permanece Não-útil porque P(útil) está longe do limiar (0.5); '
170
- 'o que o define é a <em>ausência</em> de características úteis '
171
- '(fontes, dados, neutralidade), não palavras negativas específicas.'
172
- '</p>'
173
- )
174
- elif not pos and neg:
175
- side_warning = (
176
- '<p style="font-size:12px;color:#6c757d;margin:10px 4px 0 4px;line-height:1.5;">'
177
- '⚠️ <strong>Nenhuma palavra puxando para útil identificada.</strong> '
178
- 'Todas as palavras reduzem marginalmente P(útil) quando presentes.'
179
- '</p>'
180
- )
181
- else:
182
- side_warning = ""
183
-
184
- return f"""
185
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:14px;margin-top:12px;
186
- font-family:system-ui, -apple-system, sans-serif;">
187
- <div style="background:#fcfcfd;border:1px solid #eef2f7;border-radius:12px;padding:12px;">
188
- <div style="font-size:13px;font-weight:700;color:#1b4332;margin-bottom:6px;">
189
- Empurram para útil
190
- </div>
191
- <table style="width:100%;border-collapse:collapse;font-size:13px;">{pos_rows}</table>
192
- </div>
193
- <div style="background:#fcfcfd;border:1px solid #eef2f7;border-radius:12px;padding:12px;">
194
- <div style="font-size:13px;font-weight:700;color:#9d0208;margin-bottom:6px;">
195
- Empurram para não-útil
196
- </div>
197
- <table style="width:100%;border-collapse:collapse;font-size:13px;">{neg_rows}</table>
198
- </div>
199
- </div>
200
- """ + side_warning
201
-
202
-
203
- # ---------------------------------------------------------------------------
204
- # Handlers — retornam HTML para a UI + JSON para a API
205
- # ---------------------------------------------------------------------------
206
- def handle_predict(text: str):
207
- text = (text or "").strip()
208
- if not text:
209
- return "<em>Forneça um texto.</em>", {"error": "empty_input"}
210
- if not MODEL_READY:
211
- err = MODEL_ERROR or "modelo indisponível"
212
- return (
213
- f"<em>Modelo indisponível: {html.escape(err)}</em>",
214
- {"error": "model_unavailable", "detail": err},
215
- )
216
-
217
- p = predict_one(text)
218
- return (
219
- _score_card_html(p),
220
- {
221
- "proba_util": p,
222
- "label": _label(p),
223
- "confidence_band": _confidence_band(p),
224
- },
225
- )
226
-
227
-
228
- def handle_explain(text: str):
229
- text = (text or "").strip()
230
- if not text:
231
- return "<em>Forneça um texto.</em>", "", "", {"error": "empty_input"}
232
- if not MODEL_READY:
233
- err = MODEL_ERROR or "modelo indisponível"
234
- return (
235
- f"<em>Modelo indisponível: {html.escape(err)}</em>",
236
- "",
237
- "",
238
- {"error": "model_unavailable", "detail": err},
239
- )
240
-
241
- result = explain_occlusion(text)
242
- p = result["proba_full"]
243
- tokens = result["tokens"]
244
- contribs = result["contributions"]
245
-
246
- return (
247
- _score_card_html(p),
248
- _highlighted_text_html(tokens, contribs),
249
- _top_tokens_table_html(tokens, contribs),
250
- {
251
- "proba_util": p,
252
- "label": _label(p),
253
- "confidence_band": _confidence_band(p),
254
- "tokens": tokens,
255
- "contributions": contribs,
256
- },
257
- )
258
-
259
-
260
- # ---------------------------------------------------------------------------
261
- # UI
262
- # ---------------------------------------------------------------------------
263
- EXAMPLE_UTIL = (
264
- "Segundo dados oficiais do Ministério da Saúde, o número citado no tweet é falso. "
265
- "A fonte correta pode ser conferida no link: https://www.gov.br/saude/..."
266
- )
267
- EXAMPLE_NAO = "Essa nota é claramente desnecessária, é opinião pessoal do autor."
268
-
269
- INTRO_MD = """
270
- # Notinhas — endpoint de utilidade (FT-Solo)
271
-
272
- Classificador de utilidade para **community notes em português**, baseado em
273
- **Qwen3-Embedding-4B + LoRA + cabeça linear** (modo fiel do FT-Solo, fold 01).
274
-
275
- - **Prever** — score + label + faixa de confiança.
276
- - **Explicar** — o mesmo + contribuição de cada palavra via leave-one-out.
277
- - **Sobre** — detalhes técnicos e limitações.
278
- """
279
-
280
-
281
- with gr.Blocks(
282
- title="Notinhas — endpoint de utilidade (FT-Solo)",
283
- theme=gr.themes.Base(),
284
- css=CUSTOM_CSS,
285
- ) as demo:
286
- gr.Markdown(INTRO_MD)
287
-
288
- if not MODEL_READY:
289
- gr.Markdown(
290
- f"""
291
- > ⚠️ **Modelo não carregou.** Detalhe: `{html.escape(MODEL_ERROR or '')}`
292
- >
293
- > Verifique que `artifacts/fold_01_adapter/` e `artifacts/fold_01_head.pt` estão presentes
294
- > no repositório do Space. Se o modelo base exigir autenticação, configure `HF_TOKEN` em
295
- > **Settings → Variables and secrets**.
296
- """
297
- )
298
-
299
- with gr.Tab("Prever"):
300
- with gr.Row():
301
- with gr.Column(scale=2):
302
- inp_p = gr.Textbox(
303
- label="Texto da nota",
304
- placeholder="Cole aqui o texto em português...",
305
- lines=7,
306
- max_lines=25,
307
- )
308
- btn_p = gr.Button("Prever", variant="primary")
309
- gr.Examples(examples=[[EXAMPLE_UTIL], [EXAMPLE_NAO]], inputs=[inp_p])
310
- with gr.Column(scale=3):
311
- out_card_p = gr.HTML(label="Resultado")
312
- out_json_p = gr.JSON(label="Resposta da API")
313
-
314
- btn_p.click(
315
- handle_predict,
316
- inputs=[inp_p],
317
- outputs=[out_card_p, out_json_p],
318
- api_name="predict",
319
- )
320
-
321
- with gr.Tab("Explicar"):
322
- with gr.Row():
323
- with gr.Column(scale=2):
324
- inp_e = gr.Textbox(
325
- label="Texto da nota",
326
- placeholder="Cole aqui o texto em português...",
327
- lines=7,
328
- max_lines=25,
329
- )
330
- btn_e = gr.Button("Explicar", variant="primary")
331
- gr.Examples(examples=[[EXAMPLE_UTIL], [EXAMPLE_NAO]], inputs=[inp_e])
332
- with gr.Column(scale=3):
333
- out_card_e = gr.HTML(label="Resultado")
334
- out_hl = gr.HTML(label="Contribuição por palavra")
335
- out_tbl = gr.HTML(label="Top tokens por lado")
336
- out_json_e = gr.JSON(label="Resposta da API")
337
-
338
- btn_e.click(
339
- handle_explain,
340
- inputs=[inp_e],
341
- outputs=[out_card_e, out_hl, out_tbl, out_json_e],
342
- api_name="explain",
343
- )
344
-
345
- with gr.Tab("Sobre"):
346
- gr.Markdown(
347
- f"""
348
- ### Detalhes técnicos
349
-
350
- - **Modelo base**: `Qwen/Qwen3-Embedding-4B` (embedding, 2.560 dims, last-token pooling).
351
- - **Adaptação**: LoRA treinado com alvo `label_binary_strict` (recorte A do projeto).
352
- - **Cabeça**: `nn.Linear(2560, 1)` → sigmoid.
353
- - **Prompt de instrução** (idêntico ao treino):
354
-
355
- > `Instruct: Represent the following Brazilian Portuguese community note for binary classification of helpfulness.`
356
- > `Query: <texto>`
357
-
358
- - **max_length**: 256 tokens.
359
- - **Dispositivo atual**: `{DEVICE}`.
360
- - **Fold servido**: 01 (melhor fold segundo o manifesto do pipeline).
361
-
362
- ### Método de explicação
363
-
364
- A aba **Explicar** usa **occlusion word-level** (leave-one-out): para cada palavra
365
- separada por espaço, calculamos `Δ = P(texto completo) − P(texto sem a palavra)`.
366
-
367
- - Δ positivo palavra puxando para **útil** (verde).
368
- - Δ negativo ⇒ palavra puxando para **não-útil** (coral).
369
-
370
- É uma aproximação rápida do SHAP Partition usado no notebook de explicabilidade
371
- (~1–2 s vs ~12–15 s em GPU), com resultados visualmente comparáveis para notas curtas.
372
-
373
- ### Limitações
374
-
375
- - O rótulo `helpful` mede **aceitabilidade bipartidária**, não qualidade editorial.
376
- A galeria curada do notebook mostra casos onde vizinhos semânticos idênticos
377
- recebem rótulos opostos por razões políticas.
378
- - Textos são truncados em 256 tokens.
379
- - Este endpoint serve um único fold. Para produção com ganho marginal de robustez,
380
- subir para ensemble dos 5 folds (média de probabilidades).
381
- """
382
- )
383
-
384
-
385
- if __name__ == "__main__":
386
- demo.queue(default_concurrency_limit=1).launch(
387
- server_name="0.0.0.0",
388
- server_port=int(os.environ.get("PORT", 7860)),
389
- show_api=True,
390
- )
 
1
+ """Gradio app — endpoint de utilidade para community notes em PT-BR.
2
+
3
+ Expõe:
4
+ - UI web com três abas: Prever / Explicar / Sobre.
5
+ - API HTTP em /gradio_api/call/predict e /gradio_api/call/explain (gerada
6
+ automaticamente pelo Gradio a partir dos api_name).
7
+
8
+ Para clientes Python, use gradio_client:
9
+
10
+ from gradio_client import Client
11
+ c = Client("<user>/<space>", hf_token="hf_...")
12
+ score = c.predict("texto da nota...", api_name="/predict")
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import html
17
+ import logging
18
+ import os
19
+ import traceback
20
+ from pathlib import Path
21
+
22
+ import gradio as gr
23
+
24
+ from config import (
25
+ CONFIDENCE_BOUNDS_ALTA,
26
+ CONFIDENCE_BOUNDS_MEDIA,
27
+ THRESHOLD_UTIL,
28
+ )
29
+ from inference import DEVICE, explain_occlusion, predict_one, warmup
30
+
31
+ logging.basicConfig(
32
+ level=logging.INFO,
33
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
34
+ )
35
+ log = logging.getLogger("app")
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # CSS do projeto
39
+ # ---------------------------------------------------------------------------
40
+ APP_DIR = Path(__file__).resolve().parent
41
+ STYLE_PATH = APP_DIR / "styles.css"
42
+ CUSTOM_CSS = STYLE_PATH.read_text(encoding="utf-8") if STYLE_PATH.exists() else ""
43
+
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Warm-up agressivo — queremos que o primeiro request não pague cold-start
48
+ # ---------------------------------------------------------------------------
49
+ MODEL_READY: bool
50
+ MODEL_ERROR: str | None
51
+
52
+ try:
53
+ warmup()
54
+ MODEL_READY = True
55
+ MODEL_ERROR = None
56
+ log.info("Modelo carregado no startup. Device=%s", DEVICE)
57
+ except Exception as exc: # noqa: BLE001 — queremos pegar qualquer falha de carregamento
58
+ MODEL_READY = False
59
+ MODEL_ERROR = f"{type(exc).__name__}: {exc}"
60
+ log.error("Falha ao carregar modelo no startup:\n%s", traceback.format_exc())
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Helpers de apresentação
65
+ # ---------------------------------------------------------------------------
66
+ def _confidence_band(p: float) -> str:
67
+ lo_a, hi_a = CONFIDENCE_BOUNDS_ALTA
68
+ lo_m, hi_m = CONFIDENCE_BOUNDS_MEDIA
69
+ if p <= lo_a or p >= hi_a:
70
+ return "Alta"
71
+ if p <= lo_m or p >= hi_m:
72
+ return "Média"
73
+ return "Baixa"
74
+
75
+
76
+ def _label(p: float) -> str:
77
+ return "Útil" if p >= THRESHOLD_UTIL else "Não-útil"
78
+
79
+
80
+ def _score_card_html(p: float) -> str:
81
+ """Card principal do resultado — usando classes CSS do projeto."""
82
+ lbl = _label(p)
83
+ band = _confidence_band(p)
84
+
85
+ lbl_class = "notinhas-badge-util" if lbl == "Útil" else "notinhas-badge-nao-util"
86
+
87
+ if band == "Alta":
88
+ band_class = lbl_class
89
+ elif band == "Média":
90
+ band_class = "notinhas-badge-media"
91
+ else:
92
+ band_class = "notinhas-badge-baixa"
93
+
94
+ return f"""
95
+ <div class="notinhas-card">
96
+ <div style="display:flex;justify-content:space-between;align-items:center;gap:12px;flex-wrap:wrap;">
97
+ <div style="display:flex;gap:8px;flex-wrap:wrap;">
98
+ <span class="notinhas-badge {lbl_class}">{lbl}</span>
99
+ <span class="notinhas-badge {band_class}">Confiança {band}</span>
100
+ </div>
101
+
102
+ <div style="text-align:right;">
103
+ <div class="notinhas-score-label">P(útil)</div>
104
+ <div class="notinhas-score-value">{p:.4f}</div>
105
+ </div>
106
+ </div>
107
+ </div>
108
+ """
109
+
110
+
111
+ def _contrib_color(v: float, v_max: float) -> str:
112
+ if v_max <= 0:
113
+ return "transparent"
114
+ intensity = min(1.0, abs(v) / v_max)
115
+ alpha = 0.15 + 0.65 * intensity # 0.15 .. 0.80
116
+ if v > 0:
117
+ return f"rgba(95, 168, 143, {alpha:.3f})" # verde (PALETA['util'] do notebook)
118
+ return f"rgba(224, 123, 107, {alpha:.3f})" # coral (PALETA['nao_util'])
119
+
120
+
121
+ def _highlighted_text_html(tokens: list[str], contribs: list[float]) -> str:
122
+ if not tokens:
123
+ return "<em>(sem palavras para destacar)</em>"
124
+ v_max = max((abs(c) for c in contribs), default=1e-9) or 1e-9
125
+ spans = []
126
+ for tok, c in zip(tokens, contribs):
127
+ bg = _contrib_color(c, v_max)
128
+ spans.append(
129
+ f'<span style="background:{bg};padding:2px 4px;border-radius:4px;'
130
+ f'margin:0 1px;" title="Δ={c:+.6f}">{html.escape(tok)}</span>'
131
+ )
132
+ return (
133
+ '<div style="font-size:15px;line-height:2;color:#212529;'
134
+ 'font-family:system-ui, -apple-system, sans-serif;padding:4px;">'
135
+ + " ".join(spans)
136
+ + "</div>"
137
+ )
138
+
139
+
140
+ def _top_tokens_table_html(
141
+ tokens: list[str], contribs: list[float], k: int = 5
142
+ ) -> str:
143
+ pairs = list(zip(tokens, contribs))
144
+ pos = sorted([p for p in pairs if p[1] > 0], key=lambda x: -x[1])[:k]
145
+ neg = sorted([p for p in pairs if p[1] < 0], key=lambda x: x[1])[:k]
146
+
147
+ def _row(tok: str, v: float, side: str) -> str:
148
+ color = "#1b4332" if side == "pos" else "#9d0208"
149
+ return (
150
+ f'<tr><td style="padding:5px 8px;color:{color};">'
151
+ f"{html.escape(tok)}</td>"
152
+ f'<td style="padding:5px 8px;text-align:right;color:{color};'
153
+ f'font-variant-numeric:tabular-nums;">{v:+.6f}</td></tr>'
154
+ )
155
+
156
+ empty = '<tr><td colspan="2" style="padding:6px;color:#9aa1aa;"><em>—</em></td></tr>'
157
+ pos_rows = "".join(_row(t, v, "pos") for t, v in pos) or empty
158
+ neg_rows = "".join(_row(t, v, "neg") for t, v in neg) or empty
159
+
160
+ all_same_side = (not neg and pos) or (not pos and neg)
161
+ if not neg and pos:
162
+ side_warning = (
163
+ '<p style="font-size:12px;color:#6c757d;margin:10px 4px 0 4px;line-height:1.5;">'
164
+ '⚠️ <strong>Nenhuma palavra puxando para não-útil identificada.</strong> '
165
+ 'O método leave-one-out compara a frase completa com cada ablação de uma palavra. '
166
+ 'Quando todas as contribuições são positivas, a frase completa pontua '
167
+ 'marginalmente <em>mais</em> do que qualquer subconjunto — comum em textos '
168
+ 'muito curtos ou frases com sentido idiomático. '
169
+ 'O texto permanece Não-útil porque P(útil) está longe do limiar (0.5); '
170
+ 'o que o define é a <em>ausência</em> de características úteis '
171
+ '(fontes, dados, neutralidade), não palavras negativas específicas.'
172
+ '</p>'
173
+ )
174
+ elif not pos and neg:
175
+ side_warning = (
176
+ '<p style="font-size:12px;color:#6c757d;margin:10px 4px 0 4px;line-height:1.5;">'
177
+ '⚠️ <strong>Nenhuma palavra puxando para útil identificada.</strong> '
178
+ 'Todas as palavras reduzem marginalmente P(útil) quando presentes.'
179
+ '</p>'
180
+ )
181
+ else:
182
+ side_warning = ""
183
+
184
+ return f"""
185
+ <div style="display:grid;grid-template-columns:1fr 1fr;gap:14px;margin-top:12px;
186
+ font-family:system-ui, -apple-system, sans-serif;">
187
+ <div style="background:#fcfcfd;border:1px solid #eef2f7;border-radius:12px;padding:12px;">
188
+ <div style="font-size:13px;font-weight:700;color:#1b4332;margin-bottom:6px;">
189
+ Empurram para útil
190
+ </div>
191
+ <table style="width:100%;border-collapse:collapse;font-size:13px;">{pos_rows}</table>
192
+ </div>
193
+ <div style="background:#fcfcfd;border:1px solid #eef2f7;border-radius:12px;padding:12px;">
194
+ <div style="font-size:13px;font-weight:700;color:#9d0208;margin-bottom:6px;">
195
+ Empurram para não-útil
196
+ </div>
197
+ <table style="width:100%;border-collapse:collapse;font-size:13px;">{neg_rows}</table>
198
+ </div>
199
+ </div>
200
+ """ + side_warning
201
+
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # Handlers — retornam HTML para a UI + JSON para a API
205
+ # ---------------------------------------------------------------------------
206
+ def handle_predict(text: str):
207
+ text = (text or "").strip()
208
+ if not text:
209
+ return "<em>Forneça um texto.</em>", {"error": "empty_input"}
210
+ if not MODEL_READY:
211
+ err = MODEL_ERROR or "modelo indisponível"
212
+ return (
213
+ f"<em>Modelo indisponível: {html.escape(err)}</em>",
214
+ {"error": "model_unavailable", "detail": err},
215
+ )
216
+
217
+ p = predict_one(text)
218
+ return (
219
+ _score_card_html(p),
220
+ {
221
+ "proba_util": p,
222
+ "label": _label(p),
223
+ "confidence_band": _confidence_band(p),
224
+ },
225
+ )
226
+
227
+
228
+ def handle_explain(text: str):
229
+ text = (text or "").strip()
230
+ if not text:
231
+ return "<em>Forneça um texto.</em>", "", "", {"error": "empty_input"}
232
+ if not MODEL_READY:
233
+ err = MODEL_ERROR or "modelo indisponível"
234
+ return (
235
+ f"<em>Modelo indisponível: {html.escape(err)}</em>",
236
+ "",
237
+ "",
238
+ {"error": "model_unavailable", "detail": err},
239
+ )
240
+
241
+ result = explain_occlusion(text)
242
+ p = result["proba_full"]
243
+ tokens = result["tokens"]
244
+ contribs = result["contributions"]
245
+
246
+ return (
247
+ _score_card_html(p),
248
+ _highlighted_text_html(tokens, contribs),
249
+ _top_tokens_table_html(tokens, contribs),
250
+ {
251
+ "proba_util": p,
252
+ "label": _label(p),
253
+ "confidence_band": _confidence_band(p),
254
+ "tokens": tokens,
255
+ "contributions": contribs,
256
+ },
257
+ )
258
+
259
+
260
+ # ---------------------------------------------------------------------------
261
+ # UI
262
+ # ---------------------------------------------------------------------------
263
+ EXAMPLE_UTIL = (
264
+ "Segundo dados oficiais do Ministério da Saúde, o número citado no tweet é falso. "
265
+ "A fonte correta pode ser conferida no link: https://www.gov.br/saude/..."
266
+ )
267
+ EXAMPLE_NAO = "Essa nota é claramente desnecessária, é opinião pessoal do autor."
268
+
269
+ INTRO_MD = """
270
+ # Notinhas — endpoint de utilidade (FT-Solo)
271
+
272
+ Classificador de utilidade para **community notes em português**, baseado em
273
+ **bge-m3 (568M params) + LoRA + cabeça linear** (modo fiel do FT-Solo, fold 01).
274
+
275
+ - **Prever** — score + label + faixa de confiança.
276
+ - **Explicar** — o mesmo + contribuição de cada palavra via leave-one-out.
277
+ - **Sobre** — detalhes técnicos e limitações.
278
+ """
279
+
280
+
281
+ with gr.Blocks(
282
+ title="Notinhas — endpoint de utilidade (FT-Solo)",
283
+ theme=gr.themes.Base(),
284
+ css=CUSTOM_CSS,
285
+ ) as demo:
286
+ gr.Markdown(INTRO_MD)
287
+
288
+ if not MODEL_READY:
289
+ gr.Markdown(
290
+ f"""
291
+ > ⚠️ **Modelo não carregou.** Detalhe: `{html.escape(MODEL_ERROR or '')}`
292
+ >
293
+ > Verifique que `artifacts/fold_01_adapter/` e `artifacts/fold_01_head.pt` estão presentes
294
+ > no repositório do Space. Se o modelo base exigir autenticação, configure `HF_TOKEN` em
295
+ > **Settings → Variables and secrets**.
296
+ """
297
+ )
298
+
299
+ with gr.Tab("Prever"):
300
+ with gr.Row():
301
+ with gr.Column(scale=2):
302
+ inp_p = gr.Textbox(
303
+ label="Texto da nota",
304
+ placeholder="Cole aqui o texto em português...",
305
+ lines=7,
306
+ max_lines=25,
307
+ )
308
+ btn_p = gr.Button("Prever", variant="primary")
309
+ gr.Examples(examples=[[EXAMPLE_UTIL], [EXAMPLE_NAO]], inputs=[inp_p])
310
+ with gr.Column(scale=3):
311
+ out_card_p = gr.HTML(label="Resultado")
312
+ out_json_p = gr.JSON(label="Resposta da API")
313
+
314
+ btn_p.click(
315
+ handle_predict,
316
+ inputs=[inp_p],
317
+ outputs=[out_card_p, out_json_p],
318
+ api_name="predict",
319
+ )
320
+
321
+ with gr.Tab("Explicar"):
322
+ with gr.Row():
323
+ with gr.Column(scale=2):
324
+ inp_e = gr.Textbox(
325
+ label="Texto da nota",
326
+ placeholder="Cole aqui o texto em português...",
327
+ lines=7,
328
+ max_lines=25,
329
+ )
330
+ btn_e = gr.Button("Explicar", variant="primary")
331
+ gr.Examples(examples=[[EXAMPLE_UTIL], [EXAMPLE_NAO]], inputs=[inp_e])
332
+ with gr.Column(scale=3):
333
+ out_card_e = gr.HTML(label="Resultado")
334
+ out_hl = gr.HTML(label="Contribuição por palavra")
335
+ out_tbl = gr.HTML(label="Top tokens por lado")
336
+ out_json_e = gr.JSON(label="Resposta da API")
337
+
338
+ btn_e.click(
339
+ handle_explain,
340
+ inputs=[inp_e],
341
+ outputs=[out_card_e, out_hl, out_tbl, out_json_e],
342
+ api_name="explain",
343
+ )
344
+
345
+ with gr.Tab("Sobre"):
346
+ gr.Markdown(
347
+ f"""
348
+ ### Detalhes técnicos
349
+
350
+ - **Modelo base**: `BAAI/bge-m3` (embedding, 1.024 dims, mean pooling, 568M params).
351
+ - **Adaptação**: LoRA treinado com alvo `label_binary_strict` (recorte A do projeto).
352
+ - **Cabeça**: `nn.Linear(1024, 1)` → sigmoid.
353
+ - **Prompt de instrução**: nenhum — texto cru (bge-m3 não usa prefix de instrução).
354
+ - **max_length**: 256 tokens.
355
+ - **Dispositivo atual**: `{DEVICE}`.
356
+ - **Fold servido**: 01 (melhor fold segundo o manifesto do pipeline).
357
+
358
+ ### Método de explicação
359
+
360
+ A aba **Explicar** usa **occlusion word-level** (leave-one-out): para cada palavra
361
+ separada por espaço, calculamos `Δ = P(texto completo) − P(texto sem a palavra)`.
362
+
363
+ - Δ positivo ⇒ palavra puxando para **útil** (verde).
364
+ - Δ negativo palavra puxando para **não-útil** (coral).
365
+
366
+ É uma aproximação rápida do SHAP Partition usado no notebook de explicabilidade
367
+ (~1–2 s vs ~12–15 s em GPU), com resultados visualmente comparáveis para notas curtas.
368
+
369
+ ### Limitações
370
+
371
+ - O rótulo `helpful` mede **aceitabilidade bipartidária**, não qualidade editorial.
372
+ A galeria curada do notebook mostra casos onde vizinhos semânticos idênticos
373
+ recebem rótulos opostos por razões políticas.
374
+ - Textos são truncados em 256 tokens.
375
+ - Este endpoint serve um único fold. Para produção com ganho marginal de robustez,
376
+ subir para ensemble dos 5 folds (média de probabilidades).
377
+ """
378
+ )
379
+
380
+
381
+ if __name__ == "__main__":
382
+ demo.queue(default_concurrency_limit=1).launch(
383
+ server_name="0.0.0.0",
384
+ server_port=int(os.environ.get("PORT", 7860)),
385
+ show_api=True,
386
+ )
 
 
 
 
artifacts/fold_01_adapter/README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- base_model: Qwen/Qwen3-Embedding-4B
3
  library_name: peft
4
  tags:
5
- - base_model:adapter:Qwen/Qwen3-Embedding-4B
6
  - lora
7
  - transformers
8
  ---
 
1
  ---
2
+ base_model: BAAI/bge-m3
3
  library_name: peft
4
  tags:
5
+ - base_model:adapter:BAAI/bge-m3
6
  - lora
7
  - transformers
8
  ---
artifacts/fold_01_adapter/adapter_config.json CHANGED
@@ -3,10 +3,10 @@
3
  "alpha_pattern": {},
4
  "arrow_config": null,
5
  "auto_mapping": {
6
- "base_model_class": "Qwen3Model",
7
- "parent_library": "transformers.models.qwen3.modeling_qwen3"
8
  },
9
- "base_model_name_or_path": "Qwen/Qwen3-Embedding-4B",
10
  "bias": "none",
11
  "corda_config": null,
12
  "ensure_weight_tying": false,
@@ -21,7 +21,7 @@
21
  "loftq_config": {},
22
  "lora_alpha": 32,
23
  "lora_bias": false,
24
- "lora_dropout": 0.05,
25
  "megatron_config": null,
26
  "megatron_core": "megatron.core",
27
  "modules_to_save": null,
@@ -32,13 +32,10 @@
32
  "rank_pattern": {},
33
  "revision": null,
34
  "target_modules": [
35
- "down_proj",
36
- "q_proj",
37
- "up_proj",
38
- "gate_proj",
39
- "k_proj",
40
- "o_proj",
41
- "v_proj"
42
  ],
43
  "target_parameters": null,
44
  "task_type": null,
 
3
  "alpha_pattern": {},
4
  "arrow_config": null,
5
  "auto_mapping": {
6
+ "base_model_class": "XLMRobertaModel",
7
+ "parent_library": "transformers.models.xlm_roberta.modeling_xlm_roberta"
8
  },
9
+ "base_model_name_or_path": "BAAI/bge-m3",
10
  "bias": "none",
11
  "corda_config": null,
12
  "ensure_weight_tying": false,
 
21
  "loftq_config": {},
22
  "lora_alpha": 32,
23
  "lora_bias": false,
24
+ "lora_dropout": 0.1,
25
  "megatron_config": null,
26
  "megatron_core": "megatron.core",
27
  "modules_to_save": null,
 
32
  "rank_pattern": {},
33
  "revision": null,
34
  "target_modules": [
35
+ "key",
36
+ "query",
37
+ "value",
38
+ "dense"
 
 
 
39
  ],
40
  "target_parameters": null,
41
  "task_type": null,
artifacts/fold_01_adapter/adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:326493c0cc026b088e80be86dc28fe61e21db919e52b602250e11abb6bac59b5
3
- size 132184864
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93d21f9a247eb8ce530e04b1f85055f7e405f5d0875ef646d6914de0d2a234a5
3
+ size 28482384
artifacts/fold_01_head.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a66a6088bce2a00b93377ecc4f8243e061eccdc4679f4920fd691b35a0523ab
3
- size 12365
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67ae73baff19fd870815c742171fe57d174bd984ccfd7f58751a37b44bbbda9c
3
+ size 6093
config.py CHANGED
@@ -1,54 +1,29 @@
1
- """Constantes compartilhadas pelo Space.
2
-
3
- Mantemos tudo em um único módulo para facilitar trocas (ex: substituir o fold
4
- selecionado, apontar para um tokenizer diferente em debug, etc.).
5
- """
6
  from __future__ import annotations
7
 
8
  import os
9
  from pathlib import Path
10
 
11
- # ---------------------------------------------------------------------------
12
- # Modelo base (baixado da Hugging Face no primeiro startup do Space)
13
- # ---------------------------------------------------------------------------
14
- MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
15
 
16
- # ---------------------------------------------------------------------------
17
- # Inferência — parâmetros IDÊNTICOS aos do notebook (seção 6, predict_from_text)
18
- # ---------------------------------------------------------------------------
19
  MAX_LENGTH = 256
20
  BATCH_SIZE = 8
21
 
22
- # Este prompt é parte do contrato do modelo — foi usado no fine-tuning.
23
- # Mudá-lo quebra o alinhamento entre o que o adapter viu e o que recebe agora.
24
- TASK_PROMPT = (
25
- "Represent the following Brazilian Portuguese community note "
26
- "for binary classification of helpfulness."
27
- )
28
 
29
- # ---------------------------------------------------------------------------
30
- # Paths dos artefatos (resolvidos a partir da raiz do repo do Space)
31
- # ---------------------------------------------------------------------------
32
  ROOT = Path(__file__).resolve().parent
33
  ARTIFACTS_DIR = ROOT / "artifacts"
34
-
35
- # Obrigatórios para servir predição.
36
  ADAPTER_PATH = ARTIFACTS_DIR / "fold_01_adapter"
37
- HEAD_PATH = ARTIFACTS_DIR / "fold_01_head.pt"
38
 
39
- # ---------------------------------------------------------------------------
40
- # Classificação (thresholds de apresentação — não afetam a probabilidade em si)
41
- # ---------------------------------------------------------------------------
42
  THRESHOLD_UTIL = 0.5
 
 
43
 
44
- # Faixas de confiança em função de p diretamente (evita imprecisão float do |p-0.5|):
45
- # Alta → p ≤ 0.10 ou p ≥ 0.90
46
- # Média → p ≤ 0.30 ou p ≥ 0.70
47
- # Baixa → 0.30 < p < 0.70
48
- CONFIDENCE_BOUNDS_ALTA = (0.10, 0.90) # fora desses limites = Alta
49
- CONFIDENCE_BOUNDS_MEDIA = (0.30, 0.70) # fora desses limites = Média
50
-
51
- # ---------------------------------------------------------------------------
52
- # Secrets (opcionais — definir em Settings → Secrets no Space)
53
- # ---------------------------------------------------------------------------
54
- HF_TOKEN = os.environ.get("HF_TOKEN") # só necessário se o modelo base virar gated
 
1
+ """Constantes compartilhadas pelo Space (bge-m3 FT-Solo)."""
 
 
 
 
2
  from __future__ import annotations
3
 
4
  import os
5
  from pathlib import Path
6
 
7
+ # Modelo base — bge-m3 (568M params, ~7x menor que Qwen3-4B)
8
+ MODEL_NAME = "BAAI/bge-m3"
 
 
9
 
10
+ # Inferência
 
 
11
  MAX_LENGTH = 256
12
  BATCH_SIZE = 8
13
 
14
+ # bge-m3 NÃO usa prompt de instrução. None mantém compatibilidade.
15
+ TASK_PROMPT = None
 
 
 
 
16
 
17
+ # Paths
 
 
18
  ROOT = Path(__file__).resolve().parent
19
  ARTIFACTS_DIR = ROOT / "artifacts"
 
 
20
  ADAPTER_PATH = ARTIFACTS_DIR / "fold_01_adapter"
21
+ HEAD_PATH = ARTIFACTS_DIR / "fold_01_head.pt"
22
 
23
+ # Classificação
 
 
24
  THRESHOLD_UTIL = 0.5
25
+ CONFIDENCE_BOUNDS_ALTA = (0.10, 0.90)
26
+ CONFIDENCE_BOUNDS_MEDIA = (0.30, 0.70)
27
 
28
+ # Secret opcional
29
+ HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
 
 
 
inference.py CHANGED
@@ -1,11 +1,7 @@
1
  """Carregamento do modelo e inferência.
2
 
3
- Espelha o modo 'fiel' (faithful) do FT-Solo no notebook de explicabilidade:
4
- base Qwen3-Embedding-4B + LoRA do fold 01 + cabeça linear treinada no projeto.
5
-
6
- A função `predict_from_text` do notebook está reproduzida aqui com a mesma
7
- tokenização, mesmo pooling, mesmo dtype e mesmo prompt — para que as
8
- probabilidades retornadas sejam numericamente comparáveis às OOF salvas.
9
  """
10
  from __future__ import annotations
11
 
@@ -27,7 +23,6 @@ from config import (
27
  HF_TOKEN,
28
  MAX_LENGTH,
29
  MODEL_NAME,
30
- TASK_PROMPT,
31
  )
32
 
33
  logger = logging.getLogger(__name__)
@@ -40,42 +35,26 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
  if DEVICE == "cuda":
41
  AMP_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
42
  else:
43
- # Em CPU usamos float16 nos pesos para caber em 16 GB de RAM (fp32 daria ~16 GB
44
- # nos pesos, sem sobrar para ativações). As operações em CPU rodam em fp32
45
- # via upcast automático; o dtype aqui só controla o armazenamento.
46
  # O autocast fica desligado (enabled=False abaixo) — fp16 ativo em CPU é instável.
47
  AMP_DTYPE = torch.float16
48
 
49
 
50
  # ---------------------------------------------------------------------------
51
- # Utilitários — idênticos ao notebook (seção 6)
52
  # ---------------------------------------------------------------------------
53
  def build_instruction_text(text: str) -> str:
54
- """Formata o texto no molde esperado pelo fine-tuning."""
55
- if not isinstance(text, str):
56
- text = ""
57
- return f"Instruct: {TASK_PROMPT}\nQuery: {text}"
58
 
59
 
60
- def last_token_pool(
61
  last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
62
  ) -> torch.Tensor:
63
- """Extrai o embedding do último token real.
64
-
65
- Com o tokenizer em padding_side='left', o último índice (-1) é sempre um
66
- token real para todos os elementos do batch, então podemos usar o atalho.
67
- Mantemos a branch de right-padding por paranoia.
68
- """
69
- left_padding = bool(
70
- (attention_mask[:, -1].sum() == attention_mask.shape[0]).item()
71
- )
72
- if left_padding:
73
- return last_hidden_states[:, -1]
74
- sequence_lengths = attention_mask.sum(dim=1) - 1
75
- return last_hidden_states[
76
- torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device),
77
- sequence_lengths,
78
- ]
79
 
80
 
81
  # ---------------------------------------------------------------------------
@@ -97,7 +76,7 @@ def load_model():
97
 
98
  logger.info("Carregando tokenizer de %s", MODEL_NAME)
99
  tokenizer = AutoTokenizer.from_pretrained(
100
- MODEL_NAME, padding_side="left", token=HF_TOKEN
101
  )
102
  if tokenizer.pad_token is None:
103
  tokenizer.pad_token = tokenizer.eos_token
@@ -179,7 +158,7 @@ def predict_batch(
179
  enabled=(DEVICE == "cuda"),
180
  ):
181
  out = encoder(**toks)
182
- emb = last_token_pool(out.last_hidden_state, toks["attention_mask"])
183
  emb = F.normalize(emb, p=2, dim=1)
184
  # Em CPU sem autocast, o encoder sai em fp16 e a head permanece em fp32 →
185
  # F.linear recusa. Igualar ao dtype da head resolve (inofensivo em GPU).
 
1
  """Carregamento do modelo e inferência.
2
 
3
+ Serve o FT-Solo com base BAAI/bge-m3 + LoRA do fold 01 + cabeça linear.
4
+ Pooling: mean sobre tokens reais (attention_mask). Sem prompt de instrução.
 
 
 
 
5
  """
6
  from __future__ import annotations
7
 
 
23
  HF_TOKEN,
24
  MAX_LENGTH,
25
  MODEL_NAME,
 
26
  )
27
 
28
  logger = logging.getLogger(__name__)
 
35
  if DEVICE == "cuda":
36
  AMP_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
37
  else:
38
+ # Em CPU usamos float16 nos pesos para caber em RAM. As operações em CPU
39
+ # rodam em fp32 via upcast automático; o dtype aqui controla armazenamento.
 
40
  # O autocast fica desligado (enabled=False abaixo) — fp16 ativo em CPU é instável.
41
  AMP_DTYPE = torch.float16
42
 
43
 
44
  # ---------------------------------------------------------------------------
45
+ # Utilitários
46
  # ---------------------------------------------------------------------------
47
  def build_instruction_text(text: str) -> str:
48
+ """bge-m3 não usa prompt de instrução retorna o texto cru."""
49
+ return text if isinstance(text, str) else ""
 
 
50
 
51
 
52
+ def mean_pool(
53
  last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
54
  ) -> torch.Tensor:
55
+ """Mean pooling sobre os tokens reais (mascara padding)."""
56
+ mask = attention_mask.unsqueeze(-1).float()
57
+ return (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  # ---------------------------------------------------------------------------
 
76
 
77
  logger.info("Carregando tokenizer de %s", MODEL_NAME)
78
  tokenizer = AutoTokenizer.from_pretrained(
79
+ MODEL_NAME, padding_side="right", token=HF_TOKEN
80
  )
81
  if tokenizer.pad_token is None:
82
  tokenizer.pad_token = tokenizer.eos_token
 
158
  enabled=(DEVICE == "cuda"),
159
  ):
160
  out = encoder(**toks)
161
+ emb = mean_pool(out.last_hidden_state, toks["attention_mask"])
162
  emb = F.normalize(emb, p=2, dim=1)
163
  # Em CPU sem autocast, o encoder sai em fp16 e a head permanece em fp32 →
164
  # F.linear recusa. Igualar ao dtype da head resolve (inofensivo em GPU).