JasonFinley0821 commited on
Commit
58428bb
·
1 Parent(s): 32728eb

feat : update fpn class file

Browse files
Files changed (1) hide show
  1. models/fpn_inception.py +13 -1
models/fpn_inception.py CHANGED
@@ -3,6 +3,9 @@ import torch.nn as nn
3
  from torchsummary import summary
4
  from pretrainedmodels import inceptionresnetv2
5
  import torch.nn.functional as F
 
 
 
6
 
7
  class FPNHead(nn.Module):
8
  def __init__(self, num_in, num_mid, num_out):
@@ -91,7 +94,16 @@ class FPN(nn.Module):
91
  """
92
 
93
  super().__init__()
94
- self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
 
 
 
 
 
 
 
 
 
95
 
96
  self.enc0 = self.inception.conv2d_1a
97
  self.enc1 = nn.Sequential(
 
3
  from torchsummary import summary
4
  from pretrainedmodels import inceptionresnetv2
5
  import torch.nn.functional as F
6
+ import os
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
  class FPNHead(nn.Module):
11
  def __init__(self, num_in, num_mid, num_out):
 
94
  """
95
 
96
  super().__init__()
97
+ #self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
98
+ self.inception = inceptionresnetv2(num_classes=1000, pretrained=None)
99
+ # 2️⃣ 載入本地權重
100
+ weight_path = os.path.join("model", "inceptionresnetv2_imagenet.pth")
101
+ if os.path.exists(weight_path):
102
+ state_dict = torch.load(weight_path, map_location=device)
103
+ self.inception.load_state_dict(state_dict)
104
+ print("Loaded local inceptionresnetv2_imagenet.pth successfully!")
105
+ else:
106
+ print(f"Warning: {weight_path} not found. Using randomly initialized weights.")
107
 
108
  self.enc0 = self.inception.conv2d_1a
109
  self.enc1 = nn.Sequential(