# communitynotesbr / inference.py
# histlearn's picture
# refactor: single-fold fold_04 + Platt scaling (remove ensemble)
# 1c84192 verified
"""Carregamento do modelo e inferência (bge-m3 FT-Solo, single-fold calibrado).
Platt scaling pós-treino: P_calib = sigmoid(CALIB_A * logit(P_raw) + CALIB_B).
Com CALIB_A=1.0, CALIB_B=0.0 (defaults) a transformação é identidade.
"""
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Iterable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer
from config import (
ADAPTER_PATH,
BATCH_SIZE,
CALIB_A,
CALIB_B,
HEAD_PATH,
HF_TOKEN,
MAX_LENGTH,
MODEL_NAME,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Run on the GPU whenever CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Reduced-precision dtype shared by the encoder weights and autocast:
# bfloat16 on a CUDA device that reports bf16 support, float16 in every
# other case (including CPU).
if DEVICE == "cuda" and torch.cuda.is_bf16_supported():
    AMP_DTYPE = torch.bfloat16
else:
    AMP_DTYPE = torch.float16
def build_instruction_text(text: str) -> str:
    """Return the string that is handed to the tokenizer for one input.

    bge-m3 takes no instruction prompt, so a string is passed through
    unchanged. Any non-string value is replaced by the empty string.
    """
    if not isinstance(text, str):
        return ""
    return text
def mean_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average hidden states along dim 1, ignoring masked (padding) positions.

    The mask gets a trailing axis and is cast to float, so it broadcasts
    against the hidden states and zeroes out positions where it is 0.
    Presumably the inputs are (batch, seq, hidden) and (batch, seq) --
    confirm against the caller.

    The divisor is clamped to at least 1e-9, so a row whose mask is all
    zeros yields zeros instead of a division by zero.
    """
    weights = attention_mask.unsqueeze(-1).float()
    weighted_sum = (last_hidden_states * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_count
@lru_cache(maxsize=1)
def load_model():
    """Build the inference stack and return it as ``(tokenizer, encoder, head)``.

    The function takes no arguments and is wrapped in
    ``lru_cache(maxsize=1)``, so the loading work runs on the first call
    and later calls receive the same cached tuple.

    Raises:
        FileNotFoundError: if ``ADAPTER_PATH`` or ``HEAD_PATH`` does not exist.
    """
    # Check both artefacts up front so nothing is loaded when one is missing.
    if not ADAPTER_PATH.exists():
        raise FileNotFoundError(f"Adapter LoRA não encontrado em {ADAPTER_PATH}.")
    if not HEAD_PATH.exists():
        raise FileNotFoundError(f"Cabeça classificadora não encontrada em {HEAD_PATH}.")

    # --- tokenizer ---------------------------------------------------------
    logger.info("Carregando tokenizer de %s", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        padding_side="right",
        token=HF_TOKEN,
    )
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- base encoder + LoRA adapter ---------------------------------------
    logger.info("Carregando encoder base %s (dtype=%s, device=%s)", MODEL_NAME, AMP_DTYPE, DEVICE)
    backbone = AutoModel.from_pretrained(
        MODEL_NAME,
        low_cpu_mem_usage=True,
        torch_dtype=AMP_DTYPE,
        token=HF_TOKEN,
    )
    backbone = backbone.to(DEVICE)

    logger.info("Anexando adapter LoRA de %s", ADAPTER_PATH)
    encoder = PeftModel.from_pretrained(backbone, str(ADAPTER_PATH), is_trainable=False)
    encoder = encoder.to(DEVICE)
    encoder.eval()

    # --- linear classification head ----------------------------------------
    logger.info("Carregando cabeça linear de %s", HEAD_PATH)
    # NOTE(review): torch.load unpickles the file, so HEAD_PATH must point to
    # a trusted artefact.
    checkpoint = torch.load(HEAD_PATH, map_location="cpu")
    # The file may hold the state dict directly or wrapped under "state_dict".
    if isinstance(checkpoint, dict):
        weights = checkpoint.get("state_dict", checkpoint)
    else:
        weights = checkpoint
    # The second dimension of the stored weight is the head's input width.
    n_inputs = int(weights["weight"].shape[1])
    head = nn.Linear(n_inputs, 1)
    head.load_state_dict(weights)
    head = head.to(DEVICE).eval()

    logger.info("Modelo pronto. In_features da cabeça: %d", n_inputs)
    return tokenizer, encoder, head
def warmup() -> None:
    """Load the model now rather than on the first prediction.

    Calls ``load_model()`` and discards what it returns; the intent is to
    avoid a cold start later.
    """
    load_model()  # return value intentionally unused
@torch.no_grad()
def predict_batch(texts: Iterable[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
    """Return the calibrated probability of "useful" for each text.

    Args:
        texts: Iterable of input texts. A single ``str`` is treated as one
            text, not as an iterable of characters. Non-string items are
            scored as the empty string (see ``build_instruction_text``).
        batch_size: Number of texts per forward pass. Must be positive.

    Returns:
        1-D ``float64`` array with one probability per text, in input order.
        An empty input yields an empty array. Raw probabilities are clipped
        to ``[1e-6, 1 - 1e-6]``. When ``CALIB_A``/``CALIB_B`` differ from the
        identity ``(1.0, 0.0)``, Platt scaling is applied on top:
        ``sigmoid(CALIB_A * logit(p_raw) + CALIB_B)``.

    Raises:
        ValueError: if ``batch_size`` is zero or negative and there is at
            least one text to score.
    """
    tokenizer, encoder, head = load_model()

    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        return np.zeros(0, dtype=np.float64)

    # A non-positive batch size used to fail with an obscure ValueError
    # (from range() for 0, from np.concatenate([]) for negatives). Raise the
    # same exception type explicitly, with a message that names the cause.
    if batch_size <= 0:
        raise ValueError(
            f"batch_size deve ser um inteiro positivo; recebido {batch_size!r}."
        )

    # Loop-invariant autocast settings: mixed precision is enabled on CUDA only.
    use_amp = DEVICE == "cuda"
    autocast_device = "cuda" if use_amp else "cpu"

    preds = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        instr = [build_instruction_text(t) for t in batch]
        toks = tokenizer(
            instr,
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        ).to(DEVICE)
        with torch.inference_mode(), torch.autocast(
            device_type=autocast_device, dtype=AMP_DTYPE, enabled=use_amp
        ):
            out = encoder(**toks)
            emb = mean_pool(out.last_hidden_state, toks["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)
            # Without autocast (CPU) the embedding dtype can differ from the
            # head's weight dtype, so cast explicitly before the linear layer.
            logits = head(emb.to(head.weight.dtype)).squeeze(-1)
            preds.append(torch.sigmoid(logits).float().cpu().numpy())

    # Clip away exact 0/1 so the logit below stays finite.
    p_raw = np.clip(np.concatenate(preds).astype(np.float64), 1e-6, 1 - 1e-6)

    # Identity calibration: skip the logit/sigmoid round trip.
    if CALIB_A == 1.0 and CALIB_B == 0.0:
        return p_raw

    # Platt scaling: P_calib = sigmoid(A * logit(P_raw) + B), where
    # sigmoid(x) = 1 / (1 + exp(-x)) -- the minus sign inside exp is required.
    logit_raw = np.log(p_raw / (1.0 - p_raw))
    return 1.0 / (1.0 + np.exp(-(CALIB_A * logit_raw + CALIB_B)))
def predict_one(text: str) -> float:
    """Score a single text and return its calibrated probability as a float.

    Convenience wrapper: runs ``predict_batch`` on a one-element list and
    unwraps the first (only) entry of the result.
    """
    scores = predict_batch([text])
    return float(scores[0])
def explain_occlusion(text: str, batch_size: int = BATCH_SIZE) -> dict:
    """Word-level leave-one-out attribution.

    The text is split on whitespace. For each word, a variant with that
    word removed is scored, and the word's contribution is
    ``P(text) - P(text without the word)``. The full text is scored exactly
    as given, while each variant is re-joined with single spaces.

    Args:
        text: Text to explain.
        batch_size: Forwarded to ``predict_batch``.

    Returns:
        Dict with ``"proba_full"`` (float score of the full text),
        ``"tokens"`` (the whitespace-split words) and ``"contributions"``
        (one delta per word, same order). When the text contains no words,
        both lists are empty and only the full text is scored.
    """
    words = text.split()
    if not words:
        return {"proba_full": predict_one(text), "tokens": [], "contributions": []}

    # Build one variant per word position, each with that word dropped.
    ablated = []
    for idx in range(len(words)):
        kept = words[:idx] + words[idx + 1 :]
        ablated.append(" ".join(kept))

    # Score the original and all variants in a single call; index 0 is the original.
    scores = predict_batch([text, *ablated], batch_size=batch_size)
    baseline = float(scores[0])
    deltas = baseline - scores[1:]
    return {
        "proba_full": baseline,
        "tokens": words,
        "contributions": deltas.tolist(),
    }