rocker417 committed on
Commit
35e14e9
·
verified ·
1 Parent(s): db28624

Upload modeling_sac.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_sac.py +159 -0
modeling_sac.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAC Patch Segmenter model for outpost deployment.
2
+
3
+ Accepts PIL images directly, runs U-Net segmentation, and returns
4
+ patch detection score + mask fraction.
5
+
6
+ Usage (inside outpost):
7
+ result = model.predict(image=pil_image)
8
+ # returns {"score": 0.85, "mask_fraction": 0.12}
9
+
10
+ Reference: Liu et al., CVPR 2022, "Segment and Complete"
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Optional
16
+
17
+ import sys
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from PIL import Image
22
+ from torchvision import transforms
23
+ from transformers import PreTrainedModel
24
+
25
+ from .configuration_sac import SACPatchSegmenterConfig
26
+
27
+
28
def _log(msg):
    """Write a ``[SAC-DEBUG]``-prefixed message to stderr and flush it."""
    line = f"[SAC-DEBUG] {msg}"
    print(line, file=sys.stderr, flush=True)
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # U-Net architecture (matches joellliu/SegmentAndComplete coco_at.pth)
34
+ # ---------------------------------------------------------------------------
35
+
36
+
37
class _DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by BatchNorm and ReLU.

    ``mid_ch`` sets the channel count between the two convolutions; when it
    is omitted (or falsy) ``out_ch`` is used instead.
    """

    def __init__(self, in_ch, out_ch, mid_ch=None):
        super().__init__()
        hidden = mid_ch if mid_ch else out_ch
        # Layer order (and therefore the Sequential indices) is kept as-is so
        # that parameter names under ``double_conv`` stay unchanged.
        stages = [
            nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        out = self.double_conv(x)
        return out
48
+
49
+
50
class _Down(nn.Module):
    """Downscaling stage: 2x2 max-pool followed by a ``_DoubleConv``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = _DoubleConv(in_ch, out_ch)
        # Pool first, then convolve; indices 0/1 inside ``maxpool_conv`` are
        # part of the parameter names.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Halve the spatial size of ``x`` and apply the double convolution."""
        reduced = self.maxpool_conv(x)
        return reduced
57
+
58
+
59
class _Up(nn.Module):
    """Upscaling stage: bilinear 2x upsample, pad, concat with skip, convolve."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        half = in_ch // 2
        self.conv = _DoubleConv(in_ch, out_ch, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2`` and fuse both feature maps.

        ``x1`` is the deeper feature map; ``x2`` is the skip connection whose
        height/width (dims 2 and 3) the upsampled ``x1`` is padded to.
        """
        upsampled = self.up(x1)
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left = gap_w // 2
        top = gap_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        aligned = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        fused = torch.cat([x2, aligned], dim=1)
        return self.conv(fused)
70
+
71
+
72
class _OutConv(nn.Module):
    """Final 1x1 convolution projecting features to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        projection = nn.Conv2d(in_ch, out_ch, 1)
        self.conv = projection

    def forward(self, x):
        """Return the 1x1-convolved ``x`` (spatial size unchanged)."""
        projected = self.conv(x)
        return projected
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # HuggingFace PreTrainedModel wrapper
83
+ # ---------------------------------------------------------------------------
84
+
85
+
86
class SACPatchSegmenterModel(PreTrainedModel):
    """SAC U-Net patch segmenter with integrated preprocessing.

    ``predict`` accepts a PIL image, resizes it to a square of
    ``config.input_size`` pixels, runs the U-Net and summarises the
    resulting mask as a detection score plus mask fraction.  ``forward``
    runs the raw U-Net on an already-prepared tensor and returns logits.
    """

    config_class = SACPatchSegmenterConfig
    supports_gradient_checkpointing = False

    # Heuristic constants used by ``predict`` (values unchanged from the
    # original inline literals).
    _MASK_THRESHOLD = 0.5        # sigmoid probability above which a pixel counts as patch
    _MIN_MASK_FRACTION = 0.001   # below this the mask is treated as "nothing found"
    _MASK_SCORE_GAIN = 10.0      # score = min(1, fraction * gain) when a mask is found
    _WEAK_SCORE_SCALE = 0.5      # fallback score = max probability * scale

    def __init__(self, config: SACPatchSegmenterConfig) -> None:
        super().__init__(config)
        bf = config.base_filter
        self._input_size = config.input_size

        # U-Net layers.  Attribute names and channel arithmetic are kept
        # exactly as before so existing checkpoints keep loading.
        self.inc = _DoubleConv(3, bf)
        self.down1 = _Down(bf, bf * 2)
        self.down2 = _Down(bf * 2, bf * 4)
        self.down3 = _Down(bf * 4, bf * 8)
        self.down4 = _Down(bf * 8, bf * 16 // 2)
        self.up1 = _Up(bf * 16, bf * 8 // 2)
        self.up2 = _Up(bf * 8, bf * 4 // 2)
        self.up3 = _Up(bf * 4, bf * 2 // 2)
        self.up4 = _Up(bf * 2, bf)
        self.outc = _OutConv(bf, 1)

        self._to_tensor = transforms.ToTensor()

    def forward(self, pixel_values: Optional[torch.Tensor] = None, **kwargs):
        """Run the U-Net on ``pixel_values`` or delegate to ``predict``.

        Args:
            pixel_values: Input tensor fed straight to the U-Net.  A PIL
                image passed here is routed to ``predict`` instead.
            **kwargs: If a non-``None`` ``image`` keyword is present the call
                is forwarded to ``predict`` together with the other kwargs.

        Returns:
            The U-Net logits tensor, or the ``predict`` result dict when an
            image was supplied.

        Raises:
            ValueError: If neither a tensor nor an image was provided.
        """
        image = kwargs.pop("image", None)
        if image is None and isinstance(pixel_values, Image.Image):
            # Allow ``model(pil_image)`` in addition to ``model(image=pil_image)``.
            image = pixel_values
        if image is not None:
            return self.predict(image=image, **kwargs)
        if pixel_values is None:
            raise ValueError("Provide pixel_values tensor or image=PIL")
        return self._unet_forward(pixel_values)

    def _unet_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encoder/decoder pass with skip connections; returns 1-channel logits."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)

    @torch.no_grad()
    def predict(self, image: Image.Image, **kwargs) -> dict:
        """Accept a PIL image and return patch detection results.

        Args:
            image: PIL image; converted to RGB and resized to
                ``(input_size, input_size)`` with bilinear resampling.
            **kwargs: Accepted for calling-convention compatibility; ignored.

        Returns:
            ``{"score": float, "mask_fraction": float}`` where
            ``mask_fraction`` is the share of pixels whose sigmoid
            probability exceeds 0.5.  ``score`` is
            ``min(1.0, mask_fraction * 10)`` when the fraction is above
            0.001, otherwise half of the maximum pixel probability.
        """
        # ``Image.Resampling`` only exists on Pillow >= 9.1; older releases
        # expose the same constant directly on the ``Image`` module.
        resampling = getattr(Image, "Resampling", Image)
        side = self._input_size
        img = image.convert("RGB").resize((side, side), resampling.BILINEAR)
        tensor = self._to_tensor(img).unsqueeze(0).to(device=self.device, dtype=self.dtype)

        # Inference must use eval mode: in training mode BatchNorm would
        # normalise with the statistics of this single image and update its
        # running stats even under no_grad.  Remember each module's flag so
        # the caller's train/eval configuration is restored exactly.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self._unet_forward(tensor)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        prob = torch.sigmoid(logits)

        mask = (prob[0, 0] > self._MASK_THRESHOLD).float()
        mask_fraction = float(mask.sum().item()) / mask.numel()

        if mask_fraction > self._MIN_MASK_FRACTION:
            score = min(1.0, mask_fraction * self._MASK_SCORE_GAIN)
        else:
            # No meaningful mask: fall back to a damped peak probability.
            score = float(prob.max().item()) * self._WEAK_SCORE_SCALE

        return {"score": score, "mask_fraction": mask_fraction}

    @torch.no_grad()
    def score_image(self, image: Image.Image, **kwargs) -> dict:
        """Alias for predict — matches outpost calling convention."""
        return self.predict(image=image, **kwargs)