Spaces:
Sleeping
Sleeping
Commit ·
58428bb
1
Parent(s): 32728eb
feat : update fpn class file
Browse files- models/fpn_inception.py +13 -1
models/fpn_inception.py
CHANGED
|
@@ -3,6 +3,9 @@ import torch.nn as nn
|
|
| 3 |
from torchsummary import summary
|
| 4 |
from pretrainedmodels import inceptionresnetv2
|
| 5 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class FPNHead(nn.Module):
|
| 8 |
def __init__(self, num_in, num_mid, num_out):
|
|
@@ -91,7 +94,16 @@ class FPN(nn.Module):
|
|
| 91 |
"""
|
| 92 |
|
| 93 |
super().__init__()
|
| 94 |
-
self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
self.enc0 = self.inception.conv2d_1a
|
| 97 |
self.enc1 = nn.Sequential(
|
|
|
|
| 3 |
from torchsummary import summary
|
| 4 |
from pretrainedmodels import inceptionresnetv2
|
| 5 |
import torch.nn.functional as F
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 9 |
|
| 10 |
class FPNHead(nn.Module):
|
| 11 |
def __init__(self, num_in, num_mid, num_out):
|
|
|
|
| 94 |
"""
|
| 95 |
|
| 96 |
super().__init__()
|
| 97 |
+
#self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
|
| 98 |
+
self.inception = inceptionresnetv2(num_classes=1000, pretrained=None)
|
| 99 |
+
# 2️⃣ 載入本地權重
|
| 100 |
+
weight_path = os.path.join("model", "inceptionresnetv2_imagenet.pth")
|
| 101 |
+
if os.path.exists(weight_path):
|
| 102 |
+
state_dict = torch.load(weight_path, map_location=device)
|
| 103 |
+
self.inception.load_state_dict(state_dict)
|
| 104 |
+
print("Loaded local inceptionresnetv2_imagenet.pth successfully!")
|
| 105 |
+
else:
|
| 106 |
+
print(f"Warning: {weight_path} not found. Using randomly initialized weights.")
|
| 107 |
|
| 108 |
self.enc0 = self.inception.conv2d_1a
|
| 109 |
self.enc1 = nn.Sequential(
|