File size: 5,583 Bytes
4d7f570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Cofiber detection threshold circuit — standalone inference.

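Loads model.safetensors and runs multi-scale object detection on a (C, H, W)
feature map (numpy array or torch tensor) from a frozen vision backbone. C
must match the width of the classifier weights stored in the file.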

Usage:
    import torch
    from model import CofiberDetector

    detector = CofiberDetector.from_safetensors("model.safetensors")
    features = torch.randn(768, 40, 40)  # from frozen ViT at stride 16
    detections = detector.detect(features, score_thresh=0.3)
"""

import json
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import load_file


def heaviside(x):
    """Unit step function: 1.0 where ``x >= 0`` and 0.0 elsewhere.

    This is the activation H(.) of a threshold gate. The comparison is
    inclusive, so H(0) = 1. ``x`` is expected to be a ``torch.Tensor``;
    the result is a float tensor of the same shape.
    """
    fires = x >= 0  # inclusive: a gate sitting exactly on its threshold fires
    return fires.float()


class CofiberDetector:
    """Depth-3 threshold circuit for multi-scale object detection.

    Layer 0: Average pool 2x (fixed weights, zero parameters)
    Layer 1: Cofiber = x - upsample(pool(x)) (fixed weights, zero parameters)
    Layer 2: Classify = H(w . cofib + b) (INT8 learned weights)

    Scale ``i`` runs at stride ``16 * 2**i``. Every scale except the coarsest
    classifies the cofiber of its map. The coarsest scale classifies the
    pooled map directly, with no pool/subtract layer.

    Attributes:
        prototypes: (num_classes, feat_dim) classifier weight array.
        biases: classifier bias array, broadcast against the
            (H*W, num_classes) logits (presumably shape (num_classes,)).
        scale_factor: quantization scale that ``from_safetensors`` divided
            the stored weights by.
        n_scales: number of pyramid levels processed by ``detect``.
    """

    def __init__(self, prototypes, biases, scale_factor, n_scales=3):
        self.prototypes = prototypes
        self.biases = biases
        self.scale_factor = scale_factor
        self.n_scales = n_scales
        self.num_classes = prototypes.shape[0]
        self.feat_dim = prototypes.shape[1]

    @classmethod
    def from_safetensors(cls, path, n_scales=3):
        """Build a detector from a safetensors checkpoint.

        Reads ``classify.weight``, ``classify.bias`` and the scalar
        ``classify.scale_factor``. The weight and bias tensors are divided by
        the scale factor to undo their INT8 quantization.

        Args:
            path: str or Path of the .safetensors file.
            n_scales: number of pyramid levels to run (default 3).

        Returns:
            A ``CofiberDetector`` holding numpy weight arrays.
        """
        tensors = load_file(str(path))
        scale_factor = tensors["classify.scale_factor"].item()
        # Dequantize from INT8 representation
        prototypes = tensors["classify.weight"] / scale_factor
        biases = tensors["classify.bias"] / scale_factor
        return cls(prototypes.numpy(), biases.numpy(), scale_factor,
                   n_scales=n_scales)

    def _pool(self, x):
        """Layer 0: Average pool 2x. Fixed weights {0.25, 0.25, 0.25, 0.25}.

        Args:
            x: array of shape (C, H, W).

        Returns:
            Array of shape (C, ceil(H/2), ceil(W/2)).
        """
        pad_h = x.shape[1] % 2
        pad_w = x.shape[2] % 2
        if pad_h or pad_w:
            # Odd sizes would make the four strided slices disagree in shape.
            # Replicate the last row/column (ceil-mode pooling). _cofiber
            # crops the upsampled map back to (H, W).
            x = np.pad(x, ((0, 0), (0, pad_h), (0, pad_w)), mode="edge")
        return (x[:, 0::2, 0::2] + x[:, 0::2, 1::2] +
                x[:, 1::2, 0::2] + x[:, 1::2, 1::2]) / 4

    def _cofiber(self, x, pooled):
        """Layer 1: Subtract. cofib = x - upsample(pool(x)). Fixed weights {1, -1}.

        Uses a nearest-neighbor upsample for exact integer arithmetic. The
        upsampled map may be one row/column larger than ``x`` when ``x`` had
        an odd size, so it is cropped back to ``x``'s spatial shape.
        """
        h, w = x.shape[1], x.shape[2]
        upsampled = np.repeat(np.repeat(pooled, 2, axis=1), 2, axis=2)[:, :h, :w]
        return x - upsampled

    def _classify(self, features, stride):
        """Layer 2: H(w . features + b). One threshold gate per (location, class).

        Args:
            features: array of shape (C, H, W) with C == feat_dim.
            stride: pixel stride of this scale; sets box centers and size.

        Returns:
            One detection dict per firing gate (logit >= 0), ordered by
            location (row-major) and then by class.
        """
        C, h, w = features.shape
        flat = features.reshape(C, -1).T  # (H*W, C)
        logits = flat @ self.prototypes.T + self.biases  # (H*W, num_classes)

        # Heaviside: fire if logit >= 0. np.nonzero scans the mask in C
        # order, which reproduces the (location, class) nested-loop order
        # without visiting every gate in Python.
        locs, classes = np.nonzero(logits >= 0)
        half = stride * 2  # box side is 4x the stride
        detections = []
        for loc, cls_idx in zip(locs.tolist(), classes.tolist()):
            yi, xi = divmod(loc, w)
            # Box: centered on patch location, size proportional to stride
            cx = (xi + 0.5) * stride
            cy = (yi + 0.5) * stride
            detections.append({
                "box": [cx - half, cy - half, cx + half, cy + half],
                "score": float(logits[loc, cls_idx]),
                "label": int(cls_idx),
                "scale": stride,
            })
        return detections

    def detect(self, features, score_thresh=0.3):
        """Run the full 3-layer circuit.

        Args:
            features: numpy array or torch.Tensor of shape (C, H, W) from a
                frozen backbone at stride 16, with C == feat_dim. Tensors are
                detached and moved to CPU first, so grad-tracking or GPU
                tensors are accepted.
            score_thresh: minimum logit value to report a detection. Only
                gates that fire (logit >= 0) are candidates, so a negative
                threshold behaves like 0.

        Returns:
            list of {"box": [x1,y1,x2,y2], "score": float, "label": int, "scale": int}

        Raises:
            ValueError: if ``features`` is not 3-D or its channel count
                differs from ``feat_dim``.
        """
        if isinstance(features, torch.Tensor):
            # Tensor.numpy() refuses tensors that require grad or live off-CPU.
            features = features.detach().cpu().numpy()
        features = np.asarray(features)
        if features.ndim != 3 or features.shape[0] != self.feat_dim:
            raise ValueError(
                f"expected features of shape ({self.feat_dim}, H, W), "
                f"got {features.shape}"
            )

        all_dets = []
        f = features
        for scale_idx in range(self.n_scales):
            stride = 16 * 2 ** scale_idx
            if scale_idx < self.n_scales - 1:
                pooled = self._pool(f)
                dets = self._classify(self._cofiber(f, pooled), stride)
                f = pooled  # next, coarser scale works on the pooled map
            else:
                # Coarsest scale: classify the pooled map directly.
                dets = self._classify(f, stride)

            all_dets.extend(d for d in dets if d["score"] >= score_thresh)

        return all_dets

    @property
    def param_count(self):
        """Number of learned scalars (classifier weights plus biases)."""
        return self.prototypes.size + self.biases.size

    @property
    def gate_count(self):
        """Total threshold gates across all scales for a 40x40 input grid.

        40x40 is the grid used in the module's usage example (a ViT at stride
        16). Use ``count_gates`` for other input sizes.
        """
        return self.count_gates(40, 40)

    def count_gates(self, height=40, width=40):
        """Total threshold gates that ``detect`` evaluates on a (height, width) map.

        Every scale has one classify gate per (location, class). Each scale
        except the coarsest also has one pool gate per pooled
        (location, channel) and one subtract gate per (location, channel).
        The coarsest scale has no pool or subtract gates, because ``detect``
        classifies that map directly.
        """
        total = 0
        h, w = height, width
        for s in range(self.n_scales):
            total += h * w * self.num_classes  # classify gates
            if s < self.n_scales - 1:
                ph, pw = (h + 1) // 2, (w + 1) // 2  # ceil, matching _pool
                total += ph * pw * self.feat_dim  # pool gates
                total += h * w * self.feat_dim  # subtract gates
                h, w = ph, pw
        return total


if __name__ == "__main__":
    # Smoke test: load the checkpoint next to this file and run it on random
    # stride-16 features.
    path = Path(__file__).parent / "model.safetensors"
    if not path.exists():
        # SystemExit with a message prints it to stderr and exits with status 1.
        # The bare exit() builtin is injected by the site module for the
        # interactive shell and is not guaranteed to exist in scripts.
        raise SystemExit(f"model.safetensors not found at {path}")

    detector = CofiberDetector.from_safetensors(path)
    print(f"Loaded cofiber detector: {detector.param_count:,} params, {detector.gate_count:,} gates")

    # Test on random features (fixed seed for reproducible counts)
    np.random.seed(42)
    features = np.random.randn(768, 40, 40).astype(np.float32)
    for thresh in (0.0, 0.3):
        dets = detector.detect(features, score_thresh=thresh)
        print(f"Detections on random input (thresh={thresh}): {len(dets)}")