File size: 3,794 Bytes
d9418d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Stage 1: output-channel pruning.

The full 768-D EUPE-ViT-B output token is produced as before, the final
LayerNorm runs as before, and then only the 100 dimensions the classifier
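reads are retained. The classifier is fused into a single bias-free
Linear(100, 1) projection with fixed weights: +1 on the positive dims and
-1 on the negative dims (equivalently, ternary {+1, 0, -1} weights over the
full 768-D vector). The decision compares that score against one learnable
threshold, i.e. a bias of -threshold. Inference is identical to Stage 0 by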
construction. No compute savings; this stage just cleans the interface
and sets up the weight shapes that later stages will attack.

Usage:
    model = Stage1PersonClassifier.from_pretrained_argus('phanerozoic/argus')
    score, pred = model(image_tensor)
"""
import json, os, sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


class Stage1PersonClassifier(nn.Module):
    """EUPE-ViT-B -> 100 dim slice -> ternary linear head -> binary decision.

    pos_dims and neg_dims are index tensors into the 768-D output. The
    classifier weight matrix stored in `retained_weight` has shape (1, 100)
    where positive-dim positions are +1 and negative-dim positions are -1.
    Bias equals the negated threshold.
    """

    def __init__(self, argus_model, pos_dims, neg_dims, threshold):
        super().__init__()
        self.backbone = argus_model.backbone
        retained = list(pos_dims) + list(neg_dims)
        self.register_buffer('retained_dims', torch.tensor(retained, dtype=torch.long))
        w = torch.zeros(1, len(retained))
        w[0, : len(pos_dims)] = 1.0
        w[0, len(pos_dims):] = -1.0
        self.register_buffer('retained_weight', w)
        # Stored as a free parameter so gradient descent could retune.
        self.threshold = nn.Parameter(torch.tensor(float(threshold)))
        self.D = 768

    @torch.inference_mode()
    def forward(self, x):
        """x: (B, 3, 768, 768) normalized (ImageNet stats).

        Returns (score, pred) where score is (B,) float and pred is (B,) bool.
        """
        with torch.autocast('cuda', dtype=torch.bfloat16):
            out = self.backbone.forward_features(x)
        patches = out['x_norm_patchtokens'].float()   # (B, 2304, 768)
        ln = F.layer_norm(patches, [self.D])
        pooled = ln.max(dim=1).values                  # (B, 768)
        retained = pooled.index_select(-1, self.retained_dims)  # (B, 100)
        score = F.linear(retained, self.retained_weight).squeeze(-1)  # (B,)
        pred = score > self.threshold
        return score, pred

    @classmethod
    def from_pretrained_argus(cls, repo_or_path='phanerozoic/argus',
                              classifier_json='classifier.json'):
        """Load Argus, read classifier.json, build the wrapper."""
        from transformers import AutoModel
        argus = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
        with open(classifier_json) as f:
            c = json.load(f)
        return cls(argus, c['pos_dims'], c['neg_dims'], c['threshold'])


if __name__ == '__main__':
    # Smoke test: build the Stage 1 wrapper from a local Argus checkpoint and
    # the Stage 0 classifier.json, report parameter counts, and run one
    # forward pass on random input. Requires a CUDA device.
    classifier_path = Path(__file__).parent / '..' / 'stage_0' / 'classifier.json'
    # Reuse the classmethod: it loads the same model/JSON and closes the file
    # via a context manager.
    m = Stage1PersonClassifier.from_pretrained_argus(
        '/mnt/d/Argus', classifier_json=classifier_path)
    m = m.cuda().eval()
    n_all = sum(p.numel() for p in m.parameters())
    n_backbone = sum(p.numel() for p in m.backbone.parameters())
    n_head = n_all - n_backbone
    print(f'total params: {n_all:,}')
    print(f'backbone params: {n_backbone:,}')
    print(f'head params: {n_head} (one learnable threshold; weights are fixed buffers)')

    x = torch.randn(2, 3, 768, 768, device='cuda')
    score, pred = m(x)
    print(f'forward OK. score={score.tolist()}  pred={pred.tolist()}')