Update model.py
Browse files
model.py
CHANGED
|
@@ -1,8 +1,16 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
import
|
| 4 |
-
from safetensors.torch import load_file
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
class ConvBlock(nn.Module):
|
| 7 |
def __init__(self, in_c, out_c):
|
| 8 |
super().__init__()
|
|
@@ -16,9 +24,12 @@ class ConvBlock(nn.Module):
|
|
| 16 |
)
|
| 17 |
def forward(self, x): return self.conv(x)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
# Encoder
|
| 23 |
self.enc1 = ConvBlock(3, 32)
|
| 24 |
self.pool = nn.MaxPool2d(2)
|
|
@@ -47,17 +58,4 @@ class AlphaUNet(nn.Module):
|
|
| 47 |
d1 = torch.cat([d1, e1], dim=1)
|
| 48 |
d1 = self.dec1(d1)
|
| 49 |
|
| 50 |
-
return self.sigmoid(self.final(d1))
|
| 51 |
-
|
| 52 |
-
# Метод для удобной загрузки
|
| 53 |
-
@classmethod
|
| 54 |
-
def from_pretrained(cls, path="."):
|
| 55 |
-
model = cls()
|
| 56 |
-
# Грузим Safetensors
|
| 57 |
-
try:
|
| 58 |
-
state_dict = load_file(f"{path}/model.safetensors")
|
| 59 |
-
model.load_state_dict(state_dict)
|
| 60 |
-
print(">>> AlphaDepth loaded from .safetensors")
|
| 61 |
-
except FileNotFoundError:
|
| 62 |
-
print("Error: model.safetensors not found.")
|
| 63 |
-
return model
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
| 4 |
|
| 5 |
+
# 1. СНАЧАЛА ОПРЕДЕЛЯЕМ КОНФИГ
|
| 6 |
+
class AlphaDepthConfig(PretrainedConfig):
|
| 7 |
+
model_type = "alpha-depth"
|
| 8 |
+
|
| 9 |
+
def __init__(self, input_size=[3, 128, 128], **kwargs):
|
| 10 |
+
self.input_size = input_size
|
| 11 |
+
super().__init__(**kwargs)
|
| 12 |
+
|
| 13 |
+
# 2. ВСПОМОГАТЕЛЬНЫЕ БЛОКИ
|
| 14 |
class ConvBlock(nn.Module):
|
| 15 |
def __init__(self, in_c, out_c):
|
| 16 |
super().__init__()
|
|
|
|
| 24 |
)
|
| 25 |
def forward(self, x): return self.conv(x)
|
| 26 |
|
| 27 |
+
# 3. САМА МОДЕЛЬ (Наследуемся от PreTrainedModel!)
|
| 28 |
+
class AlphaUNet(PreTrainedModel):
|
| 29 |
+
config_class = AlphaDepthConfig
|
| 30 |
+
|
| 31 |
+
def __init__(self, config):
|
| 32 |
+
super().__init__(config)
|
| 33 |
# Encoder
|
| 34 |
self.enc1 = ConvBlock(3, 32)
|
| 35 |
self.pool = nn.MaxPool2d(2)
|
|
|
|
| 58 |
d1 = torch.cat([d1, e1], dim=1)
|
| 59 |
d1 = self.dec1(d1)
|
| 60 |
|
| 61 |
+
return self.sigmoid(self.final(d1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|