File size: 1,615 Bytes
6578134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216f636
6578134
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
from linear_probe import ConstitutionalProbe

class StreamingClassifier:
    """Score a single exchange token by token with an exponential moving average.

    Per-token logits returned by ``model`` are passed through a sigmoid to
    obtain ``z_t`` and smoothed as::

        s_t = a * s_{t-1} + (1 - a) * z_t,    with s_0 = 0

    The exchange is flagged when any ``s_t`` exceeds ``threshold``.  Because
    the average starts at 0, the first few ``s_t`` are pulled toward 0 when
    ``a`` is close to 1.
    """

    def __init__(
            self,
            model: "ConstitutionalProbe",
            threshold: float,
            ema_alpha: float = 0.9,
            device: str = "cpu"
    ):
        """Store the probe and the scoring hyper-parameters.

        Args:
            model: Callable probe invoked as ``model(input_ids, attention_mask)``.
                It is switched to eval mode here.  It is NOT moved to
                ``device``; the caller is responsible for placing it there.
            threshold: An exchange is flagged when the smoothed score is
                strictly greater than this value.
            ema_alpha: Weight ``a`` given to the previous smoothed score.
            device: Device the input tensors are moved to before scoring.
        """
        self.model = model
        self.model.eval()

        self.threshold = threshold
        self.alpha = ema_alpha
        self.device = device

    def _alpha_from_window(self, window_size: int) -> float:
        """Return the EMA weight ``1 - 2 / (window_size + 1)`` for a window length.

        Raises:
            ValueError: If ``window_size`` is smaller than 1, which would
                yield a negative weight or a division by zero.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        return 1.0 - (2.0 / (window_size + 1))

    @staticmethod
    def _token_vector(tensor: torch.Tensor, name: str) -> torch.Tensor:
        """Reduce a single-exchange tensor to a 1-D per-token vector."""
        vec = tensor.squeeze(0)
        # Tolerate a trailing singleton dimension, e.g. (T, 1) probe outputs.
        if vec.dim() == 2 and vec.shape[-1] == 1:
            vec = vec.squeeze(-1)
        if vec.dim() > 1:
            raise ValueError(
                f"{name} must describe a single exchange (batch size 1), "
                f"got shape {tuple(tensor.shape)}"
            )
        # reshape(-1) also turns a 0-dim tensor (one token) into length 1.
        return vec.reshape(-1)

    @torch.no_grad()
    def score_exchange(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
    ) -> tuple[bool, float, list[float]]:
        """Run the probe on one exchange and smooth its per-token scores.

        Args:
            input_ids: Token ids for a single exchange; forwarded to the model.
            attention_mask: Mask aligned with the model's per-token output.
                Positions whose value equals 0 are skipped and do not update
                the moving average.

        Returns:
            A tuple ``(flagged, peak_score, ema_trace)`` where ``flagged`` is
            True if any smoothed score exceeded ``self.threshold``,
            ``peak_score`` is the largest smoothed score seen (0.0 if no
            token was scored) and ``ema_trace`` holds the smoothed score
            after each unmasked token, in order.

        Raises:
            ValueError: If the logits or the mask cannot be reduced to a
                single 1-D sequence (for example a batch larger than 1).
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        logits = self.model(input_ids, attention_mask)

        # Apply the sigmoid once and convert with a single tolist() call;
        # calling .item() per token would copy each element to the host
        # separately.
        probs = torch.sigmoid(self._token_vector(logits, "logits")).tolist()
        mask = self._token_vector(attention_mask, "attention_mask").tolist()

        ema_score = 0.0
        peak_score = 0.0
        ema_trace: list[float] = []
        flagged = False

        for t, z_t in enumerate(probs):
            if mask[t] == 0:
                continue

            ema_score = self.alpha * ema_score + (1 - self.alpha) * z_t
            ema_trace.append(ema_score)

            if ema_score > peak_score:
                peak_score = ema_score

            if ema_score > self.threshold:
                flagged = True

        return flagged, peak_score, ema_trace