HamzaNaser commited on
Commit
278639e
·
verified ·
1 Parent(s): 6e48ac8

Upload 8 files

Browse files
Files changed (7) hide show
  1. data_setup.py +54 -0
  2. engine.py +55 -0
  3. model.pth +3 -0
  4. model_builder.py +42 -0
  5. requirements.txt +145 -0
  6. train.py +93 -0
  7. utils.py +9 -0
data_setup.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import zipfile
4
+ from pathlib import Path
5
+ from torch.utils.data import DataLoader
6
+ from torchvision.datasets import ImageFolder
7
+ from torchvision.transforms import ToTensor,Compose, Resize,Normalize
8
+
9
+ num_workers = os.cpu_count()
10
+
11
+ def data_installing(ROOT_PATH, DATA_FILE_ID = '1yIhmdZRwcvyWOl92PygSVGSualOxiwjg'):
12
+
13
+ url = f'https://docs.google.com/uc?export=download&id={DATA_FILE_ID}'
14
+ output_file = 'data.zip'
15
+ output_path = ROOT_PATH / 'Data' / output_file
16
+
17
+ command = ['wget', '--no-check-certificate', url, '-O', output_path]
18
+
19
+ result = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
20
+
21
+ with zipfile.ZipFile(output_path,'r') as zip:
22
+ zip.extractall(output_path.parent)
23
+
24
+ os.remove(output_path)
25
+ print('Data loaded..')
26
+
27
+
28
+
29
+ def data_loaders(ROOT_PATH,BATCH_SIZE, IMAGES_SIZE,P):
30
+
31
+ transform = Compose([
32
+ Resize(IMAGES_SIZE),
33
+ ToTensor(),
34
+ Normalize(mean=[0.485, 0.456, 0.406],
35
+ std=[0.229, 0.224, 0.225]),
36
+ ])
37
+
38
+ train_data = ImageFolder(ROOT_PATH / 'Data' / 'Training_data',
39
+ transform=transform)
40
+
41
+ test_data = ImageFolder(ROOT_PATH / 'Data' / 'Test_data',
42
+ transform=transform)
43
+
44
+ train_data_ = DataLoader(train_data,batch_size = BATCH_SIZE, shuffle=True, num_workers=num_workers)
45
+ test_data_ = DataLoader(test_data, batch_size=BATCH_SIZE,num_workers=num_workers)
46
+
47
+ class_names = train_data.classes
48
+
49
+
50
+ return train_data_,test_data_, class_names
51
+
52
+
53
+ if __name__=='__main__':
54
+ data_installing(Path('/home/hamza/Desktop/Study-Notes/Machine Learning/Pytourch/Modular'))
engine.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from sklearn.metrics import accuracy_score
3
+
4
+
5
+
6
+
7
+ def train_step(
8
+ epoch,
9
+ model,
10
+ loss_fn,
11
+ optimizer,
12
+ train_data,
13
+ device):
14
+
15
+ train_loss, train_acc = 0,0
16
+ for batch, (X,y) in enumerate(train_data):
17
+ X,y = X.to(device), y.to(device)
18
+ model.train()
19
+ optimizer.zero_grad()
20
+
21
+ y_pred = model(X)
22
+ loss = loss_fn(y_pred,y)
23
+ train_loss+= loss
24
+ train_acc += accuracy_score(torch.softmax(y_pred,dim=1).argmax(axis=1).cpu(),y.cpu())
25
+
26
+ loss.backward()
27
+ optimizer.step()
28
+
29
+ train_loss /= len(train_data)
30
+ train_acc /= len(train_data)
31
+ print(f'Epoch {epoch} | train_Loss {train_loss:.2f} | train_acc {train_acc:.2f}')
32
+
33
+
34
+ def test_step(
35
+ epoch,
36
+ model,
37
+ loss_fn,
38
+ test_data,
39
+ device):
40
+
41
+
42
+
43
+ model.eval()
44
+ with torch.inference_mode():
45
+ test_loss, test_acc = 0,0
46
+ for _, (X,y) in enumerate(test_data):
47
+ X,y = X.to(device), y.to(device)
48
+ y_pred = model(X)
49
+ test_loss += loss_fn(y_pred,y)
50
+ test_acc += accuracy_score(torch.softmax(y_pred,dim=1).argmax(axis=1).cpu(),y.cpu())
51
+
52
+ test_loss /= len(test_data)
53
+ test_acc /= len(test_data)
54
+ print(f'Epoch {epoch} | test_loss {test_loss:.2f} | test_acc {test_acc:.2f}')
55
+ return test_acc
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2750d2c9f962d54f8290829e70ed20a65f21984a20413d40bf2c1597f9b1ef0d
3
+ size 16438794
model_builder.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torchvision
3
+ from torch.nn.modules import Module, Sequential
4
+
5
+
6
+ class FullyDensed(Module):
7
+ def __init__(self,HIDDEN_UNITS):
8
+ super().__init__()
9
+
10
+ # weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
11
+ # model = torchvision.models.efficientnet_v2_s(weights=weights)
12
+
13
+ # for param in model.features.parameters():
14
+ # param.requires_grad = False
15
+ # model.classifier[1] = nn.Linear(1280,10)
16
+
17
+
18
+ weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
19
+ model = torchvision.models.efficientnet_b0(weights=weights)
20
+
21
+ for param in model.features.parameters():
22
+ param.requires_grad = False
23
+ model.classifier[1] = nn.Linear(1280,10)
24
+
25
+
26
+ self.seq = Sequential(
27
+ # nn.Conv2d(3,HIDDEN_UNITS,3),
28
+ # nn.Conv2d(HIDDEN_UNITS,HIDDEN_UNITS,3),
29
+ # nn.ReLU(),
30
+ # nn.MaxPool2d(2,2),
31
+ # nn.Conv2d(HIDDEN_UNITS,HIDDEN_UNITS,3),
32
+ # nn.Conv2d(HIDDEN_UNITS,50,3),
33
+ # nn.ReLU(),
34
+ # nn.MaxPool2d(2,2),
35
+ # nn.Flatten(),
36
+ # nn.Linear(800,10)
37
+ model,
38
+ )
39
+
40
+ def forward(self,x):
41
+ return self.seq(x)
42
+
requirements.txt ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ aiofiles==23.2.1
3
+ altair==5.3.0
4
+ annotated-types==0.7.0
5
+ anyio==4.4.0
6
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
7
+ astunparse==1.6.3
8
+ attrs==23.2.0
9
+ certifi==2024.2.2
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
13
+ contourpy==1.2.1
14
+ cycler==0.12.1
15
+ debugpy @ file:///croot/debugpy_1690905042057/work
16
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
17
+ dnspython==2.6.1
18
+ email_validator==2.1.1
19
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
20
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
21
+ fastapi==0.111.0
22
+ fastapi-cli==0.0.4
23
+ ffmpy==0.3.2
24
+ filelock==3.13.1
25
+ flatbuffers==24.3.25
26
+ fonttools==4.51.0
27
+ fsspec==2024.2.0
28
+ gast==0.5.4
29
+ google-pasta==0.2.0
30
+ gradio==4.36.0
31
+ gradio_client==1.0.1
32
+ grpcio==1.63.0
33
+ h11==0.14.0
34
+ h5py==3.11.0
35
+ httpcore==1.0.5
36
+ httptools==0.6.1
37
+ httpx==0.27.0
38
+ huggingface-hub==0.23.3
39
+ idna==3.7
40
+ importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1710971335535/work
41
+ importlib_resources==6.4.0
42
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work
43
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1715263367085/work
44
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
45
+ Jinja2==3.1.3
46
+ joblib==1.4.2
47
+ jsonschema==4.22.0
48
+ jsonschema-specifications==2023.12.1
49
+ jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1710255804825/work
50
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257359434/work
51
+ keras==3.3.3
52
+ kiwisolver==1.4.5
53
+ libclang==18.1.1
54
+ Markdown==3.6
55
+ markdown-it-py==3.0.0
56
+ MarkupSafe==2.1.5
57
+ matplotlib==3.9.0
58
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
59
+ mdurl==0.1.2
60
+ ml-dtypes==0.3.2
61
+ mpmath==1.3.0
62
+ namex==0.0.8
63
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
64
+ networkx==3.2.1
65
+ numpy==1.26.3
66
+ nvidia-cublas-cu11==11.11.3.6
67
+ nvidia-cuda-cupti-cu11==11.8.87
68
+ nvidia-cuda-nvrtc-cu11==11.8.89
69
+ nvidia-cuda-runtime-cu11==11.8.89
70
+ nvidia-cudnn-cu11==8.7.0.84
71
+ nvidia-cufft-cu11==10.9.0.58
72
+ nvidia-curand-cu11==10.3.0.86
73
+ nvidia-cusolver-cu11==11.4.1.48
74
+ nvidia-cusparse-cu11==11.7.5.86
75
+ nvidia-nccl-cu11==2.20.5
76
+ nvidia-nvtx-cu11==11.8.86
77
+ opt-einsum==3.3.0
78
+ optree==0.11.0
79
+ orjson==3.10.3
80
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1710075952259/work
81
+ pandas==2.2.2
82
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
83
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
84
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
85
+ pillow==10.2.0
86
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1713912794367/work
87
+ prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
88
+ protobuf==4.25.3
89
+ psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1705722403006/work
90
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
91
+ pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
92
+ pydantic==2.7.3
93
+ pydantic_core==2.18.4
94
+ pydub==0.25.1
95
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
96
+ pyparsing==3.1.2
97
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
98
+ python-dotenv==1.0.1
99
+ python-multipart==0.0.9
100
+ pytz==2024.1
101
+ PyYAML==6.0.1
102
+ pyzmq @ file:///croot/pyzmq_1705605076900/work
103
+ referencing==0.35.1
104
+ requests==2.31.0
105
+ rich==13.7.1
106
+ rpds-py==0.18.1
107
+ ruff==0.4.8
108
+ scikit-learn==1.4.2
109
+ scipy==1.13.0
110
+ semantic-version==2.10.0
111
+ shellingham==1.5.4
112
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
113
+ sniffio==1.3.1
114
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
115
+ starlette==0.37.2
116
+ sympy==1.12
117
+ tensorboard==2.16.2
118
+ tensorboard-data-server==0.7.2
119
+ tensorflow==2.16.1
120
+ tensorflow-io-gcs-filesystem==0.37.0
121
+ termcolor==2.4.0
122
+ threadpoolctl==3.5.0
123
+ tomlkit==0.12.0
124
+ toolz==0.12.1
125
+ torch==2.3.0+cu118
126
+ torch-summary==1.4.5
127
+ torchaudio==2.3.0+cu118
128
+ torchvision==0.18.0+cu118
129
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1708363099148/work
130
+ tqdm==4.66.4
131
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
132
+ triton==2.3.0
133
+ typer==0.12.3
134
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1712329955671/work
135
+ tzdata==2024.1
136
+ ujson==5.10.0
137
+ urllib3==2.2.1
138
+ uvicorn==0.30.1
139
+ uvloop==0.19.0
140
+ watchfiles==0.22.0
141
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
142
+ websockets==11.0.3
143
+ Werkzeug==3.0.3
144
+ wrapt==1.16.0
145
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work
train.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ from model_builder import FullyDensed
6
+ from engine import test_step, train_step
7
+ from data_setup import data_loaders
8
+ from utils import save_model
9
+
10
+
11
+
12
+
13
+
14
+ def model_training(
15
+ P,
16
+ EPOCHS,
17
+ BATCH_SIZE,
18
+ HIDDEN_UNITS,
19
+ IMAGES_SIZE,
20
+ MODEL_NAME,
21
+ ROOT_PATH
22
+ ):
23
+
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ device = torch.device('cpu')
26
+
27
+ train_data, test_data, class_names = data_loaders(ROOT_PATH=ROOT_PATH,BATCH_SIZE=BATCH_SIZE,IMAGES_SIZE=IMAGES_SIZE,P=P)
28
+
29
+
30
+ model = FullyDensed(HIDDEN_UNITS)
31
+ model = model.to(device)
32
+
33
+ loss_fn = torch.nn.CrossEntropyLoss()
34
+ optimizer = torch.optim.Adam(model.parameters())
35
+
36
+ for epoch in range(EPOCHS):
37
+ train_step(
38
+ epoch,
39
+ model,
40
+ loss_fn,
41
+ optimizer,
42
+ train_data,
43
+ device
44
+ )
45
+
46
+ acc = test_step(
47
+ epoch,
48
+ model,
49
+ loss_fn,
50
+ test_data,
51
+ device
52
+ )
53
+
54
+
55
+ save_model(model,path = ROOT_PATH ,MODEL_NAME = MODEL_NAME + f'{int(HIDDEN_UNITS)}-units {int(acc*100//1)}%')
56
+
57
+
58
+
59
+ if __name__=='__main__':
60
+
61
+ parser = argparse.ArgumentParser(description='Train a model with specified parameters.')
62
+
63
+
64
+ parser.add_argument('--P', type=int, default=15)
65
+
66
+ parser.add_argument('--epochs', type=int, default=3)
67
+ parser.add_argument('--batch_size', type=int, default=32)
68
+ parser.add_argument('--hidden_units', type=int, default=30)
69
+ parser.add_argument('--images_size', type=int, nargs=2, default=[300,300])
70
+ parser.add_argument('--model_name', type=str, default='Eff NetB0')
71
+ parser.add_argument('--root_path', type=str, default='/home/hamza/Desktop/Study-Notes/Machine Learning/Pytourch/Modular')
72
+
73
+ args = parser.parse_args()
74
+
75
+
76
+ P = args.P
77
+ EPOCHS = args.epochs
78
+ BATCH_SIZE = args.batch_size
79
+ HIDDEN_UNITS = args.hidden_units
80
+ IMAGES_SIZE = args.images_size
81
+ MODEL_NAME = args.model_name
82
+ ROOT_PATH = Path(args.root_path)
83
+
84
+
85
+ model_training(
86
+ P,
87
+ EPOCHS,
88
+ BATCH_SIZE,
89
+ HIDDEN_UNITS,
90
+ IMAGES_SIZE,
91
+ MODEL_NAME,
92
+ ROOT_PATH
93
+ )
utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def save_model(model,path,MODEL_NAME):
5
+ MODEL_NAME = MODEL_NAME + '.pth'
6
+ SAVED_MODEL_PATH = path / 'Models' / MODEL_NAME
7
+ torch.save(model,f=SAVED_MODEL_PATH)
8
+
9
+