phoebehxf commited on
Commit
bc63015
·
1 Parent(s): d3a435a
Files changed (1) hide show
  1. models/model.py +5 -343
models/model.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
4
  import os
5
  import clip
6
  import sys
 
7
  from models.seg_post_model.cellpose.models import CellposeModel
8
 
9
  from torchvision.ops import roi_align
@@ -53,99 +54,6 @@ class Counting_with_SD_features_track(nn.Module):
53
  self.adapter = adapter_roi_loca()
54
  self.regressor = regressor_with_SD_features_tra()
55
 
56
- class Counting_with_SD_features_loca_rand(nn.Module):
57
- def __init__(self, scale_factor, num_of_roi = 3):
58
- super(Counting_with_SD_features_loca_rand, self).__init__()
59
- self.adapter = adapter_roi_loca_rand(num_of_roi=num_of_roi)
60
- self.regressor = regressor_with_SD_features()
61
-
62
- class Counting_with_SD_features_loca_carpk(nn.Module):
63
- def __init__(self, scale_factor, num_of_roi = 3):
64
- super(Counting_with_SD_features_loca_carpk, self).__init__()
65
- self.adapter = adapter_roi_loca_carpk(num_of_roi=num_of_roi)
66
- self.regressor = regressor_with_SD_features()
67
-
68
- class Counting_with_SD_features_clip_carpk(nn.Module):
69
- def __init__(self, scale_factor, num_of_roi = 3):
70
- super(Counting_with_SD_features_clip_carpk, self).__init__()
71
- self.adapter = adapter_roi_clip_carpk(num_of_roi=num_of_roi)
72
- # self.regressor = regressor_with_SD_features()
73
-
74
- class Counting_with_SD_features_zero(nn.Module):
75
- def __init__(self, scale_factor):
76
- super(Counting_with_SD_features_zero, self).__init__()
77
- self.adapter = adapter_roi_zero()
78
- self.regressor = regressor_with_SD_features()
79
-
80
- class Counting_with_SD_features_zero_loca(nn.Module):
81
- def __init__(self, scale_factor):
82
- super(Counting_with_SD_features_zero_loca, self).__init__()
83
- self.adapter = adapter_roi_zero_loca()
84
- self.regressor = regressor_with_SD_features()
85
-
86
- class Counting_with_SD_features_zero_loca_self(nn.Module):
87
- def __init__(self, scale_factor):
88
- super(Counting_with_SD_features_zero_loca_self, self).__init__()
89
- self.adapter = adapter_roi_zero_loca()
90
- # self.regressor = regressor_with_SD_features_self()
91
- self.regressor = regressor_with_SD_features_latent()
92
-
93
- class Counting_with_SD_features_loca_v2(nn.Module):
94
- def __init__(self, scale_factor):
95
- super(Counting_with_SD_features_loca_v2, self).__init__()
96
- self.adapter = adapter_roi_loca_v2()
97
- # self.regressor = regressor_with_SD_features()
98
-
99
- class adapter1(nn.Module):
100
- def __init__(self):
101
- super(adapter1, self).__init__()
102
- self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
103
- self.pool = nn.MaxPool2d(2)
104
- self.fc = nn.Linear(128 * 64 * 64, 768)
105
- self.initialize_weights()
106
-
107
- def forward(self, x):
108
- x = self.conv1(x)
109
- x = self.pool(x)
110
- x = x.view(x.size(0), -1)
111
- x = self.fc(x)
112
- return x
113
-
114
- def initialize_weights(self):
115
- for m in self.modules():
116
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
117
- nn.init.xavier_normal_(m.weight)
118
- if m.bias is not None:
119
- nn.init.constant_(m.bias, 0)
120
-
121
- class adapter(nn.Module):
122
- def __init__(self, pool_size=[3, 3]):
123
- super(adapter, self).__init__()
124
- self.pool_size = pool_size
125
- self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
126
- self.pool = nn.MaxPool2d(2)
127
- self.fc = nn.Linear(256 * 3 * 3, 768)
128
- self.initialize_weights()
129
-
130
- def forward(self, xs):
131
- x_list = []
132
- for x in xs:
133
- x = F.adaptive_max_pool2d(x, self.pool_size, return_indices=False) # [1, 256, 3, 3]
134
- x_list.append(x)
135
- x_list = torch.cat(x_list, dim=0)
136
- x_list = torch.mean(x_list, dim=0, keepdim=True) # [1, 256, 3, 3]
137
- x = self.conv1(x_list)
138
- # x = self.pool(x)
139
- x = x.view(x.size(0), -1)
140
- x = self.fc(x)
141
- return x
142
-
143
- def initialize_weights(self):
144
- for m in self.modules():
145
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
146
- nn.init.xavier_normal_(m.weight)
147
- if m.bias is not None:
148
- nn.init.constant_(m.bias, 0)
149
 
150
  class adapter_roi(nn.Module):
151
  def __init__(self, pool_size=[3, 3]):
@@ -279,256 +187,6 @@ class adapter_roi_loca(nn.Module):
279
  nn.init.constant_(m.bias, 0)
280
 
281
 
282
- class adapter_roi_dino(nn.Module):
283
- def __init__(self, pool_size=[3, 3]):
284
- super(adapter_roi_dino, self).__init__()
285
- self.pool_size = pool_size
286
- # self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
287
- # self.pool = nn.MaxPool2d(2)
288
- self.fc = nn.Linear(1024, 768)
289
- self.initialize_weights()
290
- def forward(self, crops, dino_model):
291
- num_of_boxes = len(crops)
292
- feats = []
293
- for i in range(num_of_boxes):
294
- with torch.no_grad():
295
- feat = dino_model(crops[i])
296
-
297
- feats.append(feat)
298
- feats = torch.cat(feats, dim=0)
299
- feats = torch.mean(feats, dim=0)
300
- x = self.fc(feats)
301
- return x
302
- def initialize_weights(self):
303
- for m in self.modules():
304
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
305
- nn.init.xavier_normal_(m.weight)
306
- if m.bias is not None:
307
- nn.init.constant_(m.bias, 0)
308
-
309
-
310
-
311
- class adapter_roi_loca_v2(nn.Module):
312
- def __init__(self, pool_size=[3, 3]):
313
- super(adapter_roi_loca_v2, self).__init__()
314
- self.pool_size = pool_size
315
- self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
316
- self.pool = nn.MaxPool2d(2)
317
- self.fc = nn.Linear(256 * 3 * 3, 1024)
318
- self.initialize_weights()
319
- def forward(self, x, boxes):
320
- rois = []
321
- bs, _, h, w = x.shape
322
- boxes = torch.cat([
323
- torch.arange(
324
- bs, requires_grad=False
325
- ).to(boxes.device).repeat_interleave(3).reshape(-1, 1),
326
- boxes.flatten(0, 1),
327
- ], dim=1)
328
- rois = roi_align(
329
- x,
330
- boxes=boxes, output_size=3,
331
- spatial_scale=1.0 / 8, aligned=True
332
- )
333
- rois = torch.mean(rois, dim=0, keepdim=True)
334
- x = self.conv1(rois)
335
- x = x.view(x.size(0), -1)
336
- x = self.fc(x)
337
- return x
338
- def initialize_weights(self):
339
- for m in self.modules():
340
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
341
- nn.init.xavier_normal_(m.weight)
342
- if m.bias is not None:
343
- nn.init.constant_(m.bias, 0)
344
-
345
- class adapter_roi_zero(nn.Module):
346
- def __init__(self, reduction=4):
347
- super(adapter_roi_zero, self).__init__()
348
- self.fc1 = nn.Sequential(
349
- nn.Linear(768, 768 // reduction, bias=False),
350
- nn.ReLU()
351
- )
352
- self.fc2 = nn.Sequential(
353
- nn.Linear(768 // reduction, 768, bias=False),
354
- nn.ReLU()
355
- )
356
- self.initialize_weights()
357
- def forward(self, x):
358
- x1 = self.fc1(x)
359
- x1 = self.fc2(x1)
360
- return x + x1
361
- def initialize_weights(self):
362
- for m in self.modules():
363
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
364
- nn.init.xavier_normal_(m.weight)
365
- if m.bias is not None:
366
- nn.init.constant_(m.bias, 0)
367
-
368
- class adapter_roi_zero_loca(nn.Module):
369
- def __init__(self, reduction=4):
370
- super(adapter_roi_zero_loca, self).__init__()
371
- self.fc1 = nn.Sequential(
372
- nn.Linear(768, 768 // reduction, bias=False),
373
- nn.ReLU()
374
- )
375
- self.fc2 = nn.Sequential(
376
- nn.Linear(768 // reduction, 768, bias=False),
377
- nn.ReLU()
378
- )
379
-
380
- self.pool_size = (3, 3)
381
- self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
382
- self.pool = nn.MaxPool2d(2)
383
- self.fc = nn.Linear(256 * 3 * 3, 768)
384
-
385
- self.initialize_weights()
386
- def forward(self, feature, boxes, class_emb):
387
- x1 = self.fc1(class_emb)
388
- x1 = self.fc2(x1)
389
- class_emb = class_emb + x1
390
-
391
- rois = []
392
- bs, _, h, w = feature.shape
393
- n_box = boxes.shape[1]
394
- boxes = torch.cat([
395
- torch.arange(
396
- bs, requires_grad=False
397
- ).to(boxes.device).repeat_interleave(n_box).reshape(-1, 1),
398
- boxes.flatten(0, 1),
399
- ], dim=1)
400
- rois = roi_align(
401
- feature,
402
- boxes=boxes, output_size=3,
403
- spatial_scale=1.0 / 8, aligned=True
404
- )
405
- # rois = torch.mean(rois, dim=0, keepdim=True)
406
- x = self.conv1(rois)
407
- x = x.view(x.size(0), -1)
408
- x = self.fc(x)
409
-
410
- if len(class_emb.shape) == 3:
411
- class_emb = class_emb.squeeze(1)
412
- dist = torch.cosine_similarity(class_emb, x) # [n_box]
413
- _, topk = torch.sort(dist[:10])
414
- x_topk = x[topk[:3], :]
415
- x_topk = torch.mean(x_topk, dim=0, keepdim=True)
416
- return x_topk + class_emb
417
-
418
- def vis(self, feature, boxes, class_emb):
419
- x1 = self.fc1(class_emb)
420
- x1 = self.fc2(x1)
421
- class_emb = class_emb + x1
422
-
423
- rois = []
424
- bs, _, h, w = feature.shape
425
- n_box = boxes.shape[1]
426
- boxes = torch.cat([
427
- torch.arange(
428
- bs, requires_grad=False
429
- ).to(boxes.device).repeat_interleave(n_box).reshape(-1, 1),
430
- boxes.flatten(0, 1),
431
- ], dim=1)
432
- rois = roi_align(
433
- feature,
434
- boxes=boxes, output_size=3,
435
- spatial_scale=1.0 / 8, aligned=True
436
- )
437
- # rois = torch.mean(rois, dim=0, keepdim=True)
438
- x = self.conv1(rois)
439
- x = x.view(x.size(0), -1)
440
- x = self.fc(x)
441
-
442
- if len(class_emb.shape) == 3:
443
- class_emb = class_emb.squeeze(1)
444
- dist = torch.cosine_similarity(class_emb, x) # [n_box]
445
- _, topk = torch.sort(dist[:10])
446
- x_topk = x[topk[:3], :]
447
- x_topk = torch.mean(x_topk, dim=0, keepdim=True)
448
- return x_topk
449
-
450
- def initialize_weights(self):
451
- for m in self.modules():
452
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
453
- nn.init.xavier_normal_(m.weight)
454
- if m.bias is not None:
455
- nn.init.constant_(m.bias, 0)
456
-
457
- class adapter_roi_loca_rand(nn.Module):
458
- def __init__(self, pool_size=[3, 3],num_of_roi = 3):
459
- super(adapter_roi_loca_rand, self).__init__()
460
- self.pool_size = pool_size
461
- self.num_of_roi = num_of_roi
462
- self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
463
- self.pool = nn.MaxPool2d(2)
464
- self.fc = nn.Linear(256 * 3 * 3, 768)
465
-
466
- # # **new
467
- # self.fc1 = nn.Sequential(
468
- # nn.Linear(768, 768 // 4, bias=False),
469
- # nn.ReLU()
470
- # )
471
- # self.fc2 = nn.Sequential(
472
- # nn.Linear(768 // 4, 768, bias=False),
473
- # nn.ReLU()
474
- # )
475
- # #
476
- self.initialize_weights()
477
- def forward(self, x, boxes, rand_boxes):
478
- num_of_boxes = boxes.shape[1]
479
- bs, _, h, w = x.shape
480
- boxes = torch.cat([
481
- torch.arange(
482
- bs, requires_grad=False
483
- ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
484
- boxes.flatten(0, 1),
485
- ], dim=1)
486
- rois = roi_align(
487
- x,
488
- boxes=boxes, output_size=3,
489
- spatial_scale=1.0 / 8, aligned=True
490
- )
491
-
492
- # new
493
- num_of_boxes = rand_boxes.shape[1]
494
- bs, _, h, w = x.shape
495
- rand_boxes = torch.cat([
496
- torch.arange(
497
- bs, requires_grad=False
498
- ).to(rand_boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
499
- rand_boxes.flatten(0, 1),
500
- ], dim=1)
501
- rand_rois = roi_align(
502
- x,
503
- boxes=rand_boxes, output_size=3,
504
- spatial_scale=1.0 / 8, aligned=True
505
- )
506
-
507
- rois = torch.mean(rois, dim=0, keepdim=True)
508
-
509
- # new
510
- cos = torch.nn.CosineSimilarity(dim=1)
511
- dist = cos(rois.view(1, -1), rand_rois.view(num_of_boxes, -1)) # [n_box]
512
- _, topk = torch.sort(-dist)
513
- x_topk = rand_rois[topk[:3], ...]
514
- x_topk = torch.mean(x_topk, dim=0, keepdim=True)
515
-
516
- rois += x_topk
517
-
518
- x = self.conv1(rois)
519
- x = x.view(x.size(0), -1)
520
- x = self.fc(x)
521
- # new
522
- # x = self.fc1(x)
523
- # x = self.fc2(x)
524
- return x
525
-
526
- def initialize_weights(self):
527
- for m in self.modules():
528
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
529
- nn.init.xavier_normal_(m.weight)
530
- if m.bias is not None:
531
- nn.init.constant_(m.bias, 0)
532
 
533
 
534
  class regressor1(nn.Module):
@@ -723,6 +381,8 @@ class regressor_with_SD_features_seg_vit_c3(nn.Module):
723
 
724
 
725
  out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
 
 
726
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
727
  return out
728
 
@@ -763,6 +423,8 @@ class regressor_with_SD_features_tra(nn.Module):
763
  feat = x
764
 
765
  out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
 
 
766
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
767
  return out, 0., feat
768
 
 
4
  import os
5
  import clip
6
  import sys
7
+ import numpy as np
8
  from models.seg_post_model.cellpose.models import CellposeModel
9
 
10
  from torchvision.ops import roi_align
 
54
  self.adapter = adapter_roi_loca()
55
  self.regressor = regressor_with_SD_features_tra()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  class adapter_roi(nn.Module):
59
  def __init__(self, pool_size=[3, 3]):
 
187
  nn.init.constant_(m.bias, 0)
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
 
192
  class regressor1(nn.Module):
 
381
 
382
 
383
  out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
384
+ if out.dtype == np.uint16:
385
+ out = out.astype(np.int16)
386
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
387
  return out
388
 
 
423
  feat = x
424
 
425
  out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
426
+ if out.dtype == np.uint16:
427
+ out = out.astype(np.int16)
428
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
429
  return out, 0., feat
430