ehottl commited on
Commit
3732e13
·
1 Parent(s): 8da8fe5

add model file

Browse files
Files changed (3) hide show
  1. .gitignore +207 -0
  2. cifar_net.pth +3 -0
  3. pytorch_classifier_gen.py +82 -0
.gitignore ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data*
2
+
3
+ # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
4
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks
5
+
6
+ ### JupyterNotebooks ###
7
+ # gitignore template for Jupyter Notebooks
8
+ # website: http://jupyter.org/
9
+
10
+ .ipynb_checkpoints
11
+ */.ipynb_checkpoints/*
12
+
13
+ # IPython
14
+ profile_default/
15
+ ipython_config.py
16
+
17
+ # Remove previous ipynb_checkpoints
18
+ # git rm -r .ipynb_checkpoints/
19
+
20
+ ### Python ###
21
+ # Byte-compiled / optimized / DLL files
22
+ __pycache__/
23
+ *.py[cod]
24
+ *$py.class
25
+
26
+ # C extensions
27
+ *.so
28
+
29
+ # Distribution / packaging
30
+ .Python
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ wheels/
43
+ share/python-wheels/
44
+ *.egg-info/
45
+ .installed.cfg
46
+ *.egg
47
+ MANIFEST
48
+
49
+ # PyInstaller
50
+ # Usually these files are written by a python script from a template
51
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
52
+ *.manifest
53
+ *.spec
54
+
55
+ # Installer logs
56
+ pip-log.txt
57
+ pip-delete-this-directory.txt
58
+
59
+ # Unit test / coverage reports
60
+ htmlcov/
61
+ .tox/
62
+ .nox/
63
+ .coverage
64
+ .coverage.*
65
+ .cache
66
+ nosetests.xml
67
+ coverage.xml
68
+ *.cover
69
+ *.py,cover
70
+ .hypothesis/
71
+ .pytest_cache/
72
+ cover/
73
+
74
+ # Translations
75
+ *.mo
76
+ *.pot
77
+
78
+ # Django stuff:
79
+ *.log
80
+ local_settings.py
81
+ db.sqlite3
82
+ db.sqlite3-journal
83
+
84
+ # Flask stuff:
85
+ instance/
86
+ .webassets-cache
87
+
88
+ # Scrapy stuff:
89
+ .scrapy
90
+
91
+ # Sphinx documentation
92
+ docs/_build/
93
+
94
+ # PyBuilder
95
+ .pybuilder/
96
+ target/
97
+
98
+ # Jupyter Notebook
99
+
100
+ # IPython
101
+
102
+ # pyenv
103
+ # For a library or package, you might want to ignore these files since the code is
104
+ # intended to run in multiple environments; otherwise, check them in:
105
+ # .python-version
106
+
107
+ # pipenv
108
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
+ # install all needed dependencies.
112
+ #Pipfile.lock
113
+
114
+ # poetry
115
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
116
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
117
+ # commonly ignored for libraries.
118
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
119
+ #poetry.lock
120
+
121
+ # pdm
122
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
123
+ #pdm.lock
124
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
125
+ # in version control.
126
+ # https://pdm.fming.dev/#use-with-ide
127
+ .pdm.toml
128
+
129
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130
+ __pypackages__/
131
+
132
+ # Celery stuff
133
+ celerybeat-schedule
134
+ celerybeat.pid
135
+
136
+ # SageMath parsed files
137
+ *.sage.py
138
+
139
+ # Environments
140
+ .env
141
+ .venv
142
+ env/
143
+ venv/
144
+ ENV/
145
+ env.bak/
146
+ venv.bak/
147
+
148
+ # Spyder project settings
149
+ .spyderproject
150
+ .spyproject
151
+
152
+ # Rope project settings
153
+ .ropeproject
154
+
155
+ # mkdocs documentation
156
+ /site
157
+
158
+ # mypy
159
+ .mypy_cache/
160
+ .dmypy.json
161
+ dmypy.json
162
+
163
+ # Pyre type checker
164
+ .pyre/
165
+
166
+ # pytype static type analyzer
167
+ .pytype/
168
+
169
+ # Cython debug symbols
170
+ cython_debug/
171
+
172
+ # PyCharm
173
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
174
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
175
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
176
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
177
+ #.idea/
178
+
179
+ ### Python Patch ###
180
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
181
+ poetry.toml
182
+
183
+ # ruff
184
+ .ruff_cache/
185
+
186
+ # LSP config files
187
+ pyrightconfig.json
188
+
189
+ # Created by https://www.toptal.com/developers/gitignore/api/linux
190
+ # Edit at https://www.toptal.com/developers/gitignore?templates=linux
191
+
192
+ ### Linux ###
193
+ *~
194
+
195
+ # temporary files which can be created if a process still has a handle open of a deleted file
196
+ .fuse_hidden*
197
+
198
+ # KDE directory preferences
199
+ .directory
200
+
201
+ # Linux trash folder which might appear on any partition or disk
202
+ .Trash-*
203
+
204
+ # .nfs files are created when an open file is removed but is still being accessed
205
+ .nfs*
206
+
207
+ # End of https://www.toptal.com/developers/gitignore/api/linux
cifar_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf87afeab91ee70d644d0f22733711e058dfa3d20edb740b01cb1abe0ddd2b4e
3
+ size 251167
pytorch_classifier_gen.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard PyTorch stack: core tensors, vision datasets/transforms, layers,
# functional ops, and optimizers.
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Load the CIFAR-10 dataset.
# ToTensor scales pixels to [0, 1]; Normalize with mean/std 0.5 per RGB
# channel maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable CIFAR-10 labels, indexed by target id (0..9).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
27
+
28
+ # 합성곱 신경망 만들기
29
class Net(nn.Module):
    """LeNet-style CNN classifier for 32x32 RGB images (CIFAR-10).

    Two conv + max-pool stages followed by three fully connected layers
    producing 10 raw class logits (no softmax; pair with CrossEntropyLoss).
    """

    def __init__(self):
        super().__init__()
        # 3 input channels -> 6 feature maps, 5x5 kernels (32x32 -> 28x28).
        self.conv1 = nn.Conv2d(3, 6, 5)
        # One shared 2x2 max-pool halves the spatial size after each conv.
        self.pool = nn.MaxPool2d(2, 2)
        # 6 -> 16 feature maps (14x14 -> 10x10, pooled to 5x5).
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head over the flattened 16 * 5 * 5 feature volume.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input images x."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = torch.flatten(out, 1)  # flatten every dimension except batch
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)
47
+
48
+
49
def main():
    """Train the CNN on CIFAR-10 for two epochs and save its weights.

    Uses the module-level ``trainloader`` defined above for training batches
    and writes the trained ``state_dict`` to ``./cifar_net.pth``.
    """
    net = Net()

    # Loss function and optimizer: cross-entropy with SGD + momentum.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # Training: loop over the dataset multiple times.
    for epoch in range(2):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            # Zero the parameter gradients left over from the previous step.
            optimizer.zero_grad()

            # Forward pass + backward pass + optimizer step.
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Report the average loss every 2000 mini-batches.
            running_loss += loss.item()
            if i % 2000 == 1999:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0

    print('Finished Training')

    # Save only the learned parameters (state_dict), not the whole module.
    PATH = './cifar_net.pth'
    torch.save(net.state_dict(), PATH)


# The guard is required, not cosmetic: the DataLoaders above use
# num_workers=2, and on spawn-based platforms (Windows/macOS) each worker
# re-imports this module. Without the guard, that re-import would re-enter
# the training loop and crash with a multiprocessing RuntimeError.
if __name__ == '__main__':
    main()