Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from linear_probe import ConstitutionalProbe | |
class StreamingClassifier:
    """Score a token sequence with an exponentially smoothed probe output.

    Each unmasked position's logit is squashed with a sigmoid to give
    ``z_t``, and a running score is updated as::

        s_t = a * s_{t-1} + (1 - a) * z_t

    The running score restarts from 0.0 on every ``score_exchange`` call.
    """

    def __init__(
        self,
        model: "ConstitutionalProbe",
        threshold: float,
        ema_alpha: float = 0.9,
        device: str = "cpu",
    ):
        """Store the probe and the scoring configuration.

        Args:
            model: Probe invoked as ``model(input_ids, attention_mask)``.
                It is switched to eval mode here. It is NOT moved to
                ``device``; the caller must place it there beforehand.
            threshold: Smoothed score above which a sequence is flagged.
            ema_alpha: Weight ``a`` given to the previous smoothed score.
            device: Device the input tensors are moved to before scoring.
        """
        self.model = model
        self.model.eval()
        self.threshold = threshold
        self.alpha = ema_alpha
        self.device = device

    def _alpha_from_window(self, window_size: int) -> float:
        """Return the EMA weight ``1 - 2 / (window_size + 1)``.

        Args:
            window_size: Smoothing span in tokens; must be at least 1.

        Raises:
            ValueError: If ``window_size`` is smaller than 1, which would
                yield a negative weight or a division by zero.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        return 1.0 - (2.0 / (window_size + 1))

    def score_exchange(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[bool, float, list[float]]:
        """Run the probe on one sequence and smooth its per-token scores.

        Args:
            input_ids: Token ids passed to the model unchanged (after being
                moved to ``self.device``).
            attention_mask: Mask passed to the model; positions whose mask
                value equals 0 are skipped when updating the running score.

        Returns:
            A ``(flagged, peak_score, ema_trace)`` tuple where ``flagged`` is
            True if the smoothed score exceeded ``self.threshold`` at any
            unmasked position, ``peak_score`` is the largest smoothed score
            seen (0.0 if nothing was scored), and ``ema_trace`` holds the
            smoothed score after each unmasked position.
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Inference only: do not build or retain an autograd graph.
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)

        # NOTE(review): assumes a batch of one with a single logit per
        # position (e.g. (1, T) or (1, T, 1)) -- confirm against the probe.
        token_logits = logits.squeeze(0)
        num_steps = token_logits.shape[0]

        # Convert once instead of calling .item() per token, which would
        # force a host/device round trip on every step.
        probs = torch.sigmoid(token_logits).reshape(num_steps).tolist()
        mask_values = attention_mask.squeeze(0).reshape(-1).tolist()

        alpha = self.alpha
        threshold = self.threshold
        ema_score = 0.0
        peak_score = 0.0
        ema_trace: list[float] = []
        flagged = False

        for t, z_t in enumerate(probs):
            if mask_values[t] == 0:
                continue  # padding does not contribute to the running score
            ema_score = alpha * ema_score + (1 - alpha) * z_t
            ema_trace.append(ema_score)
            peak_score = max(peak_score, ema_score)
            if ema_score > threshold:
                # Keep iterating so the caller still gets the full trace.
                flagged = True

        return flagged, peak_score, ema_trace