airlabshare commited on
Commit
fe33821
·
verified ·
1 Parent(s): 69f5470

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +46 -4
model.py CHANGED
@@ -43,17 +43,49 @@ class AnyThermalSegmentationModel(PreTrainedModel):
43
  super().__init__(config)
44
  self.backbone = Dinov2Model(config)
45
 
46
- # Wrapped in .model to match checkpoint "head.model.3.weight"
47
  self.head = nn.Module()
48
  self.head.model = nn.Sequential(
49
  nn.Conv2d(config.hidden_size, 64, kernel_size=3, padding=1),
50
  nn.ReLU(inplace=True),
51
- nn.Dropout2d(p=0.0), # Index 2 (Placeholder for structure)
52
  nn.Conv2d(64, config.num_labels, kernel_size=1)
53
  )
 
 
 
 
 
54
  self.post_init()
55
 
56
- def forward(self, pixel_values, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  outputs = self.backbone(pixel_values, **kwargs)
58
  features = outputs.last_hidden_state[:, 1:, :]
59
  B, L, C = features.shape
@@ -61,7 +93,17 @@ class AnyThermalSegmentationModel(PreTrainedModel):
61
  features = features.permute(0, 2, 1).reshape(B, C, H, W)
62
 
63
  logits = self.head.model(features)
64
- return F.interpolate(logits, scale_factor=14, mode='bilinear', align_corners=False)
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  # =============================================================================
 
43
  super().__init__(config)
44
  self.backbone = Dinov2Model(config)
45
 
46
+ # Head definition matches your NonlinearHead64
47
  self.head = nn.Module()
48
  self.head.model = nn.Sequential(
49
  nn.Conv2d(config.hidden_size, 64, kernel_size=3, padding=1),
50
  nn.ReLU(inplace=True),
51
+ nn.Dropout2d(p=0.0),
52
  nn.Conv2d(64, config.num_labels, kernel_size=1)
53
  )
54
+
55
+ # Define Normalization constants as buffers so they move to GPU automatically
56
+ self.register_buffer("norm_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
57
+ self.register_buffer("norm_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))
58
+
59
  self.post_init()
60
 
61
+ def preprocess_input(self, x):
62
+ """
63
+ Replicates preprocess_dinov2:
64
+ 1. Resize to nearest multiple of 14
65
+ 2. Normalize with ViT stats
66
+ """
67
+ B, C, H, W = x.shape
68
+ patch_size = 14
69
+
70
+ # 1. Dynamic Resize (Snap to grid)
71
+ new_H = (H // patch_size) * patch_size
72
+ new_W = (W // patch_size) * patch_size
73
+
74
+ if new_H != H or new_W != W:
75
+ x = F.interpolate(x, size=(new_H, new_W), mode='bilinear', align_corners=False)
76
+
77
+ # 2. Normalize
78
+
79
+ if x.max() > 1.0: x = x / 255.0
80
+
81
+ x = (x - self.norm_mean) / self.norm_std
82
+ return x
83
+
84
+ def forward(self, pixel_values, labels=None, **kwargs):
85
+ # --- APPLY PREPROCESSING HERE ---
86
+ pixel_values = self.preprocess_input(pixel_values)
87
+ # --------------------------------
88
+
89
  outputs = self.backbone(pixel_values, **kwargs)
90
  features = outputs.last_hidden_state[:, 1:, :]
91
  B, L, C = features.shape
 
93
  features = features.permute(0, 2, 1).reshape(B, C, H, W)
94
 
95
  logits = self.head.model(features)
96
+
97
+ # Upscale back to input size
98
+ logits = F.interpolate(logits, size=pixel_values.shape[-2:], mode='bilinear', align_corners=False)
99
+
100
+ loss = None
101
+ if labels is not None:
102
+ loss_fct = nn.CrossEntropyLoss()
103
+ loss = loss_fct(logits, labels)
104
+ return {"loss": loss, "logits": logits}
105
+
106
+ return logits
107
 
108
 
109
  # =============================================================================