File size: 5,423 Bytes
1c84192
a2ad1d2
1c84192
 
a2ad1d2
 
 
 
 
1c84192
a2ad1d2
 
 
 
 
 
 
 
 
1c84192
a2ad1d2
76f9a5f
 
1c84192
a2ad1d2
 
 
 
 
 
 
1c84192
e3258a2
 
 
 
 
a2ad1d2
 
 
e3258a2
233b2df
a2ad1d2
 
76f9a5f
233b2df
 
 
a2ad1d2
 
 
1c84192
 
 
 
 
 
 
a2ad1d2
e3258a2
 
 
a2ad1d2
 
 
1c84192
a2ad1d2
e3258a2
a2ad1d2
 
1c84192
e3258a2
1c84192
e3258a2
 
a2ad1d2
1c84192
 
 
 
 
 
 
76f9a5f
1c84192
 
a2ad1d2
 
 
1c84192
 
a2ad1d2
 
 
76f9a5f
1c84192
 
e3258a2
a2ad1d2
 
 
 
 
 
1c84192
a2ad1d2
1c84192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ad1d2
 
 
1c84192
a2ad1d2
 
 
 
1c84192
a2ad1d2
 
 
 
1c84192
 
 
e3258a2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Carregamento do modelo e inferência (bge-m3 FT-Solo, single-fold calibrado).

Platt scaling pós-treino: P_calib = sigmoid(CALIB_A * logit(P_raw) + CALIB_B).
Com CALIB_A=1.0, CALIB_B=0.0 (defaults) a transformação é identidade.
"""
from __future__ import annotations

import logging
from functools import lru_cache
from typing import Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer

from config import (
    ADAPTER_PATH,
    BATCH_SIZE,
    CALIB_A,
    CALIB_B,
    HEAD_PATH,
    HF_TOKEN,
    MAX_LENGTH,
    MODEL_NAME,
)

logger = logging.getLogger(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dtype used to load the encoder weights and, on CUDA, as the autocast dtype.
# - CUDA: bfloat16 when the GPU supports it, otherwise float16.
# - CPU: float32. Half precision on CPU is slow on most hardware and several
#   kernels lacked Half implementations in older PyTorch releases. Autocast is
#   disabled on CPU during inference, so this weight dtype is what actually runs.
AMP_DTYPE = (
    (torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)
    if DEVICE == "cuda"
    else torch.float32
)


def build_instruction_text(text: str) -> str:
    """Return the text to feed the tokenizer for one input.

    bge-m3 takes no instruction prompt, so a string is returned unchanged.
    Any non-``str`` input (such as ``None``) becomes the empty string.
    """
    if not isinstance(text, str):
        return ""
    return text


def mean_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the token embeddings over the real (non-padding) tokens.

    Each position is weighted by its ``attention_mask`` value, so padding
    (mask 0) contributes nothing. The per-row token count is clamped to 1e-9,
    which makes an all-padding row come out as zeros instead of NaN.
    """
    weights = attention_mask.float()[..., None]
    summed = torch.sum(last_hidden_states * weights, dim=1)
    counts = torch.clamp(torch.sum(weights, dim=1), min=1e-9)
    return summed / counts


def _head_state_from_payload(payload) -> dict:
    """Extract the linear-head state dict from whatever ``torch.load`` returned."""
    # A whole pickled module: use its parameters.
    if isinstance(payload, nn.Module):
        return payload.state_dict()
    # Either a checkpoint wrapper {"state_dict": {...}, ...} or a bare state dict.
    if isinstance(payload, dict):
        return payload.get("state_dict", payload)
    raise TypeError(f"Formato de cabeça não suportado: {type(payload).__name__}")


@lru_cache(maxsize=1)
def load_model():
    """Load and return ``(tokenizer, encoder, head)``.

    The result is cached by ``lru_cache(maxsize=1)``, so the load happens once
    per process. A failed load raises and is not cached, so the next call
    retries.

    The encoder is the ``MODEL_NAME`` base model in ``AMP_DTYPE`` with the LoRA
    adapter from ``ADAPTER_PATH`` attached, in eval mode. The head is a
    single-output ``nn.Linear`` restored from ``HEAD_PATH``. Its input width and
    whether it has a bias are inferred from the saved weights.

    Raises:
        FileNotFoundError: if the adapter directory or the head file is missing.
        TypeError: if the head file holds neither a module nor a state dict.
    """
    if not ADAPTER_PATH.exists():
        raise FileNotFoundError(f"Adapter LoRA não encontrado em {ADAPTER_PATH}.")
    if not HEAD_PATH.exists():
        raise FileNotFoundError(f"Cabeça classificadora não encontrada em {HEAD_PATH}.")

    logger.info("Carregando tokenizer de %s", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, padding_side="right", token=HF_TOKEN
    )
    # Fallback for tokenizers that ship without a padding token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Carregando encoder base %s (dtype=%s, device=%s)", MODEL_NAME, AMP_DTYPE, DEVICE)
    base_encoder = AutoModel.from_pretrained(
        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=AMP_DTYPE, token=HF_TOKEN
    ).to(DEVICE)

    logger.info("Anexando adapter LoRA de %s", ADAPTER_PATH)
    encoder = PeftModel.from_pretrained(
        base_encoder, str(ADAPTER_PATH), is_trainable=False
    ).to(DEVICE)
    encoder.eval()

    logger.info("Carregando cabeça linear de %s", HEAD_PATH)
    # NOTE(security): torch.load unpickles the file, so only load trusted
    # artifacts. Consider weights_only=True if the checkpoint holds only
    # tensors and dicts.
    payload = torch.load(HEAD_PATH, map_location="cpu")
    head_state = _head_state_from_payload(payload)
    in_feat = int(head_state["weight"].shape[1])
    # Match the checkpoint: heads saved without a bias must be rebuilt without one.
    head = nn.Linear(in_feat, 1, bias="bias" in head_state)
    head.load_state_dict(head_state)
    head = head.to(DEVICE).eval()

    logger.info("Modelo pronto. In_features da cabeça: %d", in_feat)
    return tokenizer, encoder, head


def warmup() -> None:
    """Load the model eagerly so the first prediction does not pay the load cost.

    ``load_model`` is wrapped in ``lru_cache(maxsize=1)``, so later calls reuse
    what is loaded here. The returned objects are discarded.
    """
    _ = load_model()


def _apply_platt(p_raw: np.ndarray) -> np.ndarray:
    """Apply Platt scaling ``sigmoid(CALIB_A * logit(p) + CALIB_B)``; identity at A=1, B=0."""
    if CALIB_A == 1.0 and CALIB_B == 0.0:
        return p_raw
    logit_raw = np.log(p_raw / (1.0 - p_raw))
    z = CALIB_A * logit_raw + CALIB_B
    # sigmoid(z) = 1 / (1 + exp(-z)); the minus sign inside exp is required.
    return 1.0 / (1.0 + np.exp(-z))


@torch.no_grad()
def predict_batch(texts: Iterable[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
    """Return the calibrated probability of 'useful' for each text, shape ``(N,)``.

    Args:
        texts: an iterable of texts, or a single ``str``, which is treated as one text.
        batch_size: number of texts per forward pass. Must be >= 1.

    Returns:
        A float64 array with one probability per input, in input order. The raw
        probabilities are clipped to ``[1e-6, 1 - 1e-6]`` before Platt scaling.
        An empty input returns an empty array.

    Raises:
        ValueError: if ``batch_size < 1`` and there is at least one text.
    """
    tokenizer, encoder, head = load_model()

    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        return np.zeros(0, dtype=np.float64)
    if batch_size < 1:
        raise ValueError(f"batch_size deve ser >= 1 (recebido {batch_size}).")

    use_autocast = DEVICE == "cuda"
    autocast_device = "cuda" if use_autocast else "cpu"
    preds = []

    for start in range(0, len(texts), batch_size):
        batch = [build_instruction_text(t) for t in texts[start : start + batch_size]]
        toks = tokenizer(
            batch, padding=True, truncation=True,
            max_length=MAX_LENGTH, return_tensors="pt",
        ).to(DEVICE)

        with torch.inference_mode():
            # Only the encoder runs under mixed precision (CUDA only).
            with torch.autocast(
                device_type=autocast_device, dtype=AMP_DTYPE, enabled=use_autocast
            ):
                out = encoder(**toks)
            # Pooling, head and sigmoid run in float32, outside autocast. Under
            # autocast nn.Linear would compute in AMP_DTYPE and sigmoid would keep
            # that dtype, quantizing the probabilities: the bf16 step just below
            # 1.0 is ~0.004. That distorts Platt scaling and hides small deltas.
            hidden = out.last_hidden_state.float()
            emb = F.normalize(mean_pool(hidden, toks["attention_mask"]), p=2, dim=1)
            logits = head(emb.to(head.weight.dtype)).squeeze(-1)
            preds.append(torch.sigmoid(logits).float().cpu().numpy())

    # Clip so that logit() in the calibration step stays finite.
    p_raw = np.clip(np.concatenate(preds).astype(np.float64), 1e-6, 1 - 1e-6)
    return _apply_platt(p_raw)


def predict_one(text: str) -> float:
    """Return the calibrated 'useful' probability for a single text as a Python float."""
    probabilities = predict_batch([text])
    return float(probabilities[0])


def explain_occlusion(text: str, batch_size: int = BATCH_SIZE) -> dict:
    """Attribute the prediction to words by leave-one-out occlusion.

    Each whitespace-separated word is removed in turn and the text is re-scored.
    A word's contribution is ``P(baseline) - P(baseline without the word)``.

    The baseline is ``" ".join(text.split())``. The variants are rebuilt with
    single spaces, so comparing them against the raw ``text`` would add any
    whitespace effect (newlines, tabs, repeated spaces), if the tokenizer is
    sensitive to it, to every word's delta. When ``text`` is already single-space
    separated the baseline is ``text`` itself and no extra inference is run.

    Args:
        text: the text to explain.
        batch_size: forwarded to ``predict_batch``.

    Returns:
        A dict with ``proba_full`` (calibrated probability of the original
        ``text``), ``tokens`` (the words) and ``contributions`` (one delta per
        word, in the same order). Whitespace-only or empty text yields empty lists.
    """
    words = text.split()
    if not words:
        p = predict_one(text)
        return {"proba_full": p, "tokens": [], "contributions": []}

    baseline = " ".join(words)
    variants = [" ".join(words[:i] + words[i + 1:]) for i in range(len(words))]

    # Score the normalized baseline separately only when it differs from the raw text.
    heads = [text] if baseline == text else [text, baseline]
    probs = predict_batch(heads + variants, batch_size=batch_size)

    p_full = float(probs[0])
    p_base = probs[len(heads) - 1]
    contributions = (p_base - probs[len(heads):]).tolist()
    return {"proba_full": p_full, "tokens": words, "contributions": contributions}