import torch
import torch.nn as nn

from linear_probe import ConstitutionalProbe


class StreamingClassifier:
    """Flag an exchange by smoothing per-token probe scores with an EMA.

    Every unmasked token's logit is mapped to a probability ``z_t`` with a
    sigmoid and folded into an exponential moving average::

        s_t = a * s_{t-1} + (1 - a) * z_t,    s_0 = 0

    The exchange is flagged as soon as any ``s_t`` exceeds ``threshold``.
    The average starts at 0 and is not bias-corrected, so the first few
    values are damped towards 0 when ``a`` is close to 1.
    """

    def __init__(
        self,
        model: ConstitutionalProbe,
        threshold: float,
        ema_alpha: float = 0.9,
        device: str = "cpu",
    ):
        """Store the probe and the scoring hyper-parameters.

        Args:
            model: Probe called as ``model(input_ids, attention_mask)``. It
                is switched to eval mode here but is NOT moved to ``device``;
                the caller is expected to have placed it there already.
            threshold: EMA value above which an exchange is flagged.
            ema_alpha: Smoothing factor ``a``; the weight kept from the
                previous EMA value at every step.
            device: Device the input tensors are moved to before scoring.
        """
        self.model = model
        self.model.eval()
        self.threshold = threshold
        self.alpha = ema_alpha
        self.device = device

    def _alpha_from_window(self, window_size: int) -> float:
        """Return the EMA smoothing factor for an N-token span.

        Computes ``a = 1 - 2 / (N + 1)``, the complement of the usual
        span-based EMA weight, matching this class's convention that ``a``
        multiplies the previous value.

        Raises:
            ValueError: If ``window_size`` is smaller than 1, which would
                produce a negative factor (or a division by zero at -1).
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        return 1.0 - 2.0 / (window_size + 1)

    @torch.no_grad()
    def score_exchange(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[bool, float, list[float]]:
        """Score one exchange token by token.

        Both tensors are moved to ``self.device``, passed through the probe,
        and have a leading size-1 axis dropped with ``squeeze(0)``, so the
        inputs are expected to describe a single exchange rather than a
        batch. Positions whose mask value is 0 are skipped entirely: they do
        not update the EMA and add nothing to the trace.

        Args:
            input_ids: Token ids forwarded to the probe.
            attention_mask: Mask forwarded to the probe; after ``squeeze(0)``
                it must be 1-D and cover every position of the logits.

        Returns:
            ``(flagged, peak_score, ema_trace)`` where ``flagged`` tells
            whether the EMA ever exceeded ``self.threshold``, ``peak_score``
            is the largest EMA value seen (0.0 if no token was scored), and
            ``ema_trace`` lists the EMA after each unmasked token.

        Raises:
            ValueError: If the probe does not yield exactly one logit per
                position, or the mask is not 1-D / is shorter than the
                logits.
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        logits = self.model(input_ids, attention_mask).squeeze(0)
        mask = attention_mask.squeeze(0)

        seq_len = logits.shape[0]
        if logits.numel() != seq_len:
            raise ValueError(
                "expected exactly one logit per token, got logits of shape "
                f"{tuple(logits.shape)}"
            )
        if mask.dim() != 1 or mask.shape[0] < seq_len:
            raise ValueError(
                f"attention mask of shape {tuple(mask.shape)} does not cover "
                f"{seq_len} token positions"
            )

        # One vectorized sigmoid and one host transfer per tensor, instead of
        # two ``.item()`` round-trips for every token inside the loop.
        token_probs = torch.sigmoid(logits).reshape(-1).tolist()
        mask_flags = mask[:seq_len].tolist()

        alpha = self.alpha
        threshold = self.threshold

        ema_score = 0.0
        peak_score = 0.0
        ema_trace: list[float] = []
        flagged = False

        for keep, z_t in zip(mask_flags, token_probs):
            if not keep:
                # Padding position: leave the running average untouched.
                continue

            ema_score = alpha * ema_score + (1 - alpha) * z_t
            ema_trace.append(ema_score)

            if ema_score > peak_score:
                peak_score = ema_score
            # No early exit: callers get the complete trace and true peak.
            if ema_score > threshold:
                flagged = True

        return flagged, peak_score, ema_trace