prostochel097 commited on
Commit
3d4b617
·
verified ·
1 Parent(s): f28ff29

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -19
model.py CHANGED
@@ -1,8 +1,16 @@
1
  import torch
2
  import torch.nn as nn
3
- import json
4
- from safetensors.torch import load_file
5
 
 
 
 
 
 
 
 
 
 
6
  class ConvBlock(nn.Module):
7
  def __init__(self, in_c, out_c):
8
  super().__init__()
@@ -16,9 +24,12 @@ class ConvBlock(nn.Module):
16
  )
17
  def forward(self, x): return self.conv(x)
18
 
19
- class AlphaUNet(nn.Module):
20
- def __init__(self):
21
- super().__init__()
 
 
 
22
  # Encoder
23
  self.enc1 = ConvBlock(3, 32)
24
  self.pool = nn.MaxPool2d(2)
@@ -47,17 +58,4 @@ class AlphaUNet(nn.Module):
47
  d1 = torch.cat([d1, e1], dim=1)
48
  d1 = self.dec1(d1)
49
 
50
- return self.sigmoid(self.final(d1))
51
-
52
- # Метод для удобной загрузки
53
- @classmethod
54
- def from_pretrained(cls, path="."):
55
- model = cls()
56
- # Грузим Safetensors
57
- try:
58
- state_dict = load_file(f"{path}/model.safetensors")
59
- model.load_state_dict(state_dict)
60
- print(">>> AlphaDepth loaded from .safetensors")
61
- except FileNotFoundError:
62
- print("Error: model.safetensors not found.")
63
- return model
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
 
4
 
5
+ # 1. СНАЧАЛА ОПРЕДЕЛЯЕМ КОНФИГ
6
+ class AlphaDepthConfig(PretrainedConfig):
7
+ model_type = "alpha-depth"
8
+
9
+ def __init__(self, input_size=[3, 128, 128], **kwargs):
10
+ self.input_size = input_size
11
+ super().__init__(**kwargs)
12
+
13
+ # 2. ВСПОМОГАТЕЛЬНЫЕ БЛОКИ
14
  class ConvBlock(nn.Module):
15
  def __init__(self, in_c, out_c):
16
  super().__init__()
 
24
  )
25
  def forward(self, x): return self.conv(x)
26
 
27
+ # 3. САМА МОДЕЛЬ (Наследуемся от PreTrainedModel!)
28
+ class AlphaUNet(PreTrainedModel):
29
+ config_class = AlphaDepthConfig
30
+
31
+ def __init__(self, config):
32
+ super().__init__(config)
33
  # Encoder
34
  self.enc1 = ConvBlock(3, 32)
35
  self.pool = nn.MaxPool2d(2)
 
58
  d1 = torch.cat([d1, e1], dim=1)
59
  d1 = self.dec1(d1)
60
 
61
+ return self.sigmoid(self.final(d1))