# CC_linear_probe / inference.py
# Last commit: "Update inference.py" by urbas (216f636, verified)
import torch
import torch.nn as nn
from linear_probe import ConstitutionalProbe
class StreamingClassifier:
    """Score a token sequence with a per-token probe smoothed by an EMA.

    For every unmasked position ``t`` the probe logit is squashed with a
    sigmoid to ``z_t`` and folded into an exponential moving average::

        s_t = a * s_{t-1} + (1 - a) * z_t

    with the state starting at 0 before the first unmasked position. The
    exchange is flagged if the smoothed score ever exceeds ``threshold``.
    """

    def __init__(
        self,
        model: "ConstitutionalProbe",
        threshold: float,
        ema_alpha: float = 0.9,
        device: str = "cpu",
    ):
        """Wrap ``model`` for streaming scoring.

        Args:
            model: Probe called as ``model(input_ids, attention_mask)``. It is
                switched to eval mode here but is NOT moved to ``device``;
                it must already live there.
            threshold: Flag when the EMA score strictly exceeds this value.
            ema_alpha: Smoothing factor ``a`` in ``[0, 1]``; larger values
                weight past positions more heavily.
            device: Device the input tensors are moved to before scoring.

        Raises:
            ValueError: If ``ema_alpha`` is outside ``[0, 1]``.
        """
        # Validate before touching the model so a bad config fails fast.
        if not 0.0 <= ema_alpha <= 1.0:
            raise ValueError(f"ema_alpha must be in [0, 1], got {ema_alpha}")
        self.model = model
        self.model.eval()
        self.threshold = threshold
        self.alpha = ema_alpha
        self.device = device

    def _alpha_from_window(self, window_size: int) -> float:
        """Return the EMA factor ``a`` equivalent to an ``N``-step window.

        Uses the span parameterisation ``a = 1 - 2 / (N + 1)`` (the new
        sample's weight is ``2 / (N + 1)``). Not called within this class.

        Raises:
            ValueError: If ``window_size`` is smaller than 1, which would
                yield a factor outside ``[0, 1)``.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        return 1.0 - 2.0 / (window_size + 1)

    @torch.no_grad()
    def score_exchange(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[bool, float, list[float]]:
        """Run the probe over one exchange and return the EMA-smoothed verdict.

        Args:
            input_ids: Tokens of a single sequence with a leading batch
                dimension of size 1 (the batch dim is squeezed away).
            attention_mask: Mask aligned with the tokens; positions equal to
                0 are skipped and do not update the EMA.

        Returns:
            ``(flagged, peak_score, ema_trace)``: whether any smoothed score
            exceeded ``threshold``, the largest smoothed score (0.0 when no
            position is unmasked), and the smoothed score at each unmasked
            position, in order.

        Raises:
            ValueError: If, after dropping the batch dimension, the model
                output or the mask does not hold exactly one value per
                position (e.g. a batch size greater than 1).
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        # The output is treated as one logit per token position.
        logits = self.model(input_ids, attention_mask).squeeze(0)
        mask = attention_mask.squeeze(0)

        seq_len = logits.shape[0]
        if logits.numel() != seq_len or mask.numel() != seq_len:
            raise ValueError(
                "score_exchange expects a single sequence with one logit and "
                f"one mask value per position; got logits {tuple(logits.shape)} "
                f"and mask {tuple(mask.shape)} after squeezing the batch dim"
            )

        # One host transfer per tensor instead of two .item() syncs per token.
        probs = torch.sigmoid(logits).reshape(-1).tolist()
        keep = (mask.reshape(-1) != 0).tolist()

        ema_score = 0.0
        peak_score = 0.0
        ema_trace: list[float] = []
        flagged = False
        for z_t, keep_t in zip(probs, keep):
            if not keep_t:
                continue  # masked positions leave the EMA untouched
            ema_score = self.alpha * ema_score + (1 - self.alpha) * z_t
            ema_trace.append(ema_score)
            peak_score = max(peak_score, ema_score)
            if ema_score > self.threshold:
                flagged = True
        return flagged, peak_score, ema_trace