nguyenminh4099 commited on
Commit
e9c1ede
·
verified ·
1 Parent(s): 259c165

Delete modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +0 -129
modeling.py DELETED
@@ -1,129 +0,0 @@
1
- #
2
- # Copyright (c) 2025
3
- # Minh NGUYEN <vnguyen9@lakeheadu.ca>
4
- #
5
- import logging
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.init as init
9
-
10
- from transformers import PreTrainedModel
11
- from typing import Callable, Any, Optional
12
- from typing_extensions import Self
13
-
14
- from collections import OrderedDict
15
-
16
- from .configuration import AlexNetConfig
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
-
21
class AlexNet(PreTrainedModel):
    """AlexNet convolutional network wrapped as a Hugging Face ``PreTrainedModel``.

    Architecture: five convolutional layers with ReLU activations and three
    max-pooling stages, an adaptive average pool to a fixed ``6x6`` spatial
    size, and a three-layer fully connected classifier head emitting 1000
    class logits.
    """

    config_class = AlexNetConfig

    def __init__(self, config: Optional[AlexNetConfig] = None):
        """Build the network.

        Args:
            config: optional model configuration; a default ``AlexNetConfig``
                is created when none is supplied.
        """
        config = config or AlexNetConfig()

        super().__init__(config)

        self.config = config

        # Convolutional backbone. Named submodules keep state_dict keys
        # readable and stable across refactors.
        self.feature_extractor = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=4, padding=2)),
                    ("relu1", nn.ReLU(inplace=True)),
                    ("maxpool1", nn.MaxPool2d(kernel_size=3, stride=2)),

                    ("conv2", nn.Conv2d(in_channels=64, out_channels=192, kernel_size=5, padding=2, bias=True)),
                    ("relu2", nn.ReLU(inplace=True)),
                    ("maxpool2", nn.MaxPool2d(kernel_size=3, stride=2)),

                    ("conv3", nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1)),
                    # Fixed typo: key was "reul3". Safe rename — ReLU holds no
                    # parameters, so saved state_dicts are unaffected.
                    ("relu3", nn.ReLU(inplace=True)),

                    ("conv4", nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1)),
                    ("relu4", nn.ReLU(inplace=True)),

                    ("conv5", nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)),
                    ("relu5", nn.ReLU(inplace=True)),
                    ("maxpool3", nn.MaxPool2d(kernel_size=3, stride=2)),
                ]
            )
        )

        # Pools features to a fixed 6x6 grid so the classifier input size is
        # independent of the input image resolution.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(6, 6))

        # Fully connected classifier head.
        self.head = nn.Sequential(
            OrderedDict(
                [
                    ("dropout1", nn.Dropout(p=0.5)),
                    ("linear1", nn.Linear(in_features=256 * 6 * 6, out_features=4096, bias=True)),
                    ("relu1", nn.ReLU(inplace=True)),
                    ("dropout2", nn.Dropout(p=0.5)),
                    ("linear2", nn.Linear(in_features=4096, out_features=4096, bias=True)),
                    ("relu2", nn.ReLU(inplace=True)),
                    ("linear3", nn.Linear(in_features=4096, out_features=1000, bias=True)),
                ]
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a forward pass.

        Args:
            x: input image batch with 3 channels (``conv1`` expects
               ``in_channels=3``), i.e. shape ``(batch, 3, H, W)``.

        Returns:
            Logits of shape ``(batch, 1000)``.
        """
        x = self.feature_extractor(x)
        x = self.avgpool(x)
        # Collapse the (C, 6, 6) feature maps into one flat vector per sample.
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = self.head(x)

        return x

    def init_weights_(self, fn: Callable[[nn.Module], None]) -> Self:
        """Apply ``fn`` recursively to every submodule and return ``self``."""
        self.apply(fn)

        return self
85
-
86
-
87
@torch.no_grad()
def init_weights(m: nn.Module) -> None:
    """Initialize the weights of a single module (use with ``Module.apply``).

    ``nn.Linear`` layers get Xavier-uniform weights; ``nn.Conv2d`` layers get
    Kaiming-normal weights (fan-out mode, ReLU gain). Biases, when present,
    are zeroed. Any other module type is left untouched.

    Args:
        m: the module to initialize in place.
    """
    try:
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                init.zeros_(m.bias)
    except Exception:
        # Fixed: the caught exception was previously dropped from the log.
        # logger.exception records the traceback; lazy %-args avoid f-string
        # formatting cost when logging is disabled.
        logger.exception("Error initializing weights for module %s.", m)
101
-
102
-
103
def build_alexnet(*, weight_path: Optional[str] = None, **model_kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper. Our implementation is based instead on the `One weird trick` paper above.

    Args:
        weight_path (str): path to saved weights; when given, the state dict
            is loaded non-strictly on top of the fresh random initialization.
        **model_kwargs: parameters passed to the ``AlexNet`` constructor.

    Returns:
        A ready-to-use ``AlexNet`` instance.
    """

    model = AlexNet(**model_kwargs)

    # Always start from a deterministic init scheme; loaded weights (below)
    # overwrite these tensors when a checkpoint path is supplied.
    model.init_weights_(init_weights)
    logger.info("Initialized random weights.")

    if weight_path:
        try:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (consider weights_only=True if
            # checkpoints contain plain tensors only).
            state_dict = torch.load(weight_path)
            model.load_state_dict(state_dict=state_dict, strict=False, assign=True)
            logger.info(f"Loaded state dict from {weight_path!r}.")
        except Exception as e:
            logger.error(f"Error loading state dict: {e}")

    return model