Shoker2 committed on
Commit
e3963a4
·
1 Parent(s): 242c54b

refactor: переделана логика

Browse files
Files changed (3) hide show
  1. .gitignore +1 -1
  2. dataset_downloader.py +50 -0
  3. mnist.py +63 -49
.gitignore CHANGED
@@ -16,6 +16,6 @@ uv.lock
16
  .vscode
17
  test.drawio
18
  test.py
19
- models/*
20
  Fashion-MNIST/
21
  MNIST/
 
16
  .vscode
17
  test.drawio
18
  test.py
19
+ models/
20
  Fashion-MNIST/
21
  MNIST/
dataset_downloader.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+
4
+ import pandas as pd
5
+ from datasets import load_dataset
6
+
7
+
8
def export_mnist_splits(root_dir: str, dataset: str):
    """Download an MNIST-style Hugging Face dataset and export it to disk.

    Layout produced under ``root_dir``:
        img/<split>_<idx>.png   -- one image per example
        <split>.csv             -- columns: ``path`` (relative to root_dir), ``label``

    The relative ``path`` column is what ``ImagePathDataset`` in mnist.py
    joins against the CSV's directory, so it must name the saved file.

    Args:
        root_dir: Output directory (created if needed).
        dataset: Hugging Face dataset identifier, e.g. ``"mnist"``.
    """
    ds = load_dataset(dataset)

    img_dir = os.path.join(root_dir, "img")
    os.makedirs(img_dir, exist_ok=True)

    def save_split(split_name: str):
        # Missing splits are reported and skipped rather than raising.
        if split_name not in ds:
            print(f"Сплит '{split_name}' не найден в датасете {dataset}, пропускаю")
            return

        split = ds[split_name]
        rows = []

        for idx, example in enumerate(split):
            img = example["image"]
            label = example["label"]

            filename = f"{split_name}_{idx:05d}.png"
            # BUG FIX: the CSV must reference the actual saved file; the
            # previous value was a broken placeholder string.
            rel_path = f"img/{filename}"
            abs_path = os.path.join(img_dir, filename)

            img.save(abs_path)
            rows.append({"path": rel_path, "label": label})

        csv_path = os.path.join(root_dir, f"{split_name}.csv")
        df = pd.DataFrame(rows)
        df.to_csv(csv_path, index=False)

        print(f"{split_name}.csv сохранён в {csv_path}, изображений: {len(split)}")

    save_split("train")
    save_split("test")
41
+
42
+
43
if __name__ == "__main__":
    # CLI entry point: -f/--folder is the export root, -d/--dataset the
    # Hugging Face dataset identifier.
    cli = argparse.ArgumentParser()
    cli.add_argument("-f", "--folder", type=str, required=True)
    cli.add_argument("-d", "--dataset", type=str, required=True)

    ns = cli.parse_args()

    export_mnist_splits(ns.folder, ns.dataset)
mnist.py CHANGED
@@ -1,12 +1,44 @@
1
- import argparse
2
-
3
- import numpy as np
4
  import pandas as pd
 
5
 
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.optim as optim
9
- from torch.utils.data import DataLoader, TensorDataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  class ModelCNN(nn.Module):
@@ -15,10 +47,10 @@ class ModelCNN(nn.Module):
15
  INPUT (1x28x28) ->
16
  [CONV -> RELU -> CONV -> RELU -> POOL] * 3 ->
17
  [FC -> RELU] * 2 ->
18
- FC (num_classes)
19
  """
20
 
21
- def __init__(self, num_classes=10):
22
  super(ModelCNN, self).__init__()
23
 
24
  self.features = nn.Sequential(
@@ -47,7 +79,7 @@ class ModelCNN(nn.Module):
47
  nn.ReLU(inplace=True),
48
  nn.Linear(256, 128),
49
  nn.ReLU(inplace=True),
50
- nn.Linear(128, num_classes),
51
  )
52
 
53
  def forward(self, x):
@@ -61,24 +93,16 @@ def train_mode(args):
61
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
  print(f"Устройство: {device}")
63
 
64
- df = pd.read_csv(args.input)
65
-
66
- labels = df["label"].astype(np.int64).values
67
- pixels = df.drop(columns=["label"]).values.astype(np.float32) / 255.0
68
-
69
- images = torch.from_numpy(pixels.reshape(-1, 1, 28, 28))
70
- labels = torch.from_numpy(labels)
71
-
72
- dataset = TensorDataset(images, labels)
73
  dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
74
 
75
- model = ModelCNN(num_classes=args.num_classes).to(device)
76
  criterion = nn.CrossEntropyLoss()
77
  optimizer = optim.Adam(model.parameters(), lr=args.lr)
78
 
79
  model.train()
80
  for epoch in range(args.epochs):
81
- for i, (images, labels) in enumerate(dataloader):
82
  images = images.to(device)
83
  labels = labels.to(device)
84
 
@@ -94,50 +118,40 @@ def train_mode(args):
94
  f"Epoch [{epoch+1}/{args.epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}"
95
  )
96
 
97
- checkpoint = {
98
- "state_dict": model.state_dict(),
99
- "num_classes": args.num_classes,
100
- }
101
- torch.save(checkpoint, args.model)
102
 
103
 
104
  def inference_mode(args):
105
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
  print(f"Устройство: {device}")
107
 
108
- checkpoint = torch.load(args.model, map_location=device)
109
- num_classes = checkpoint.get("num_classes", 10)
110
 
111
- model = ModelCNN(num_classes=num_classes).to(device)
112
- model.load_state_dict(checkpoint["state_dict"])
113
  model.eval()
114
 
115
- df_test = pd.read_csv(args.input)
116
-
117
- has_label = "label" in df_test.columns
118
- if has_label:
119
- pixels = df_test.drop(columns=["label"]).values
120
- else:
121
- pixels = df_test.values
122
-
123
- pixels = pixels.astype(np.float32) / 255.0
124
- images = torch.from_numpy(pixels.reshape(-1, 1, 28, 28))
125
-
126
- dataset = TensorDataset(images)
127
  dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
128
 
129
  all_preds = []
 
130
 
131
  with torch.no_grad():
132
- for (batch_images,) in dataloader:
133
- batch_images = batch_images.to(device)
134
- outputs = model(batch_images)
135
  _, preds = torch.max(outputs, 1)
136
- all_preds.extend(preds.cpu().numpy().tolist())
137
-
138
- df_pred = df_test.copy()
139
- df_pred["label"] = all_preds
140
 
 
 
 
 
 
 
 
 
 
141
  df_pred.to_csv(args.output, index=False)
142
 
143
 
@@ -145,6 +159,7 @@ def parse_args():
145
  parser = argparse.ArgumentParser()
146
 
147
  parser.add_argument("--mode", choices=["train", "inference"], required=True)
 
148
  parser.add_argument("--input", type=str)
149
  parser.add_argument("--output", type=str)
150
  parser.add_argument("--model", type=str, required=True)
@@ -152,13 +167,12 @@ def parse_args():
152
  parser.add_argument("--epochs", type=int, default=5)
153
  parser.add_argument("--batch-size", type=int, default=64)
154
  parser.add_argument("--lr", type=float, default=0.001)
155
- parser.add_argument("--num-classes", type=int, default=10)
156
 
157
  args = parser.parse_args()
158
 
159
  if args.mode == "train":
160
- if args.input is None:
161
- parser.error("--input обязателен в режиме train")
162
  elif args.mode == "inference":
163
  if args.input is None or args.output is None:
164
  parser.error("--input и --output обязательны в режиме inference")
 
 
 
 
1
  import pandas as pd
2
+ import argparse
3
 
4
+ import os
5
  import torch
6
  import torch.nn as nn
7
  import torch.optim as optim
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import torchvision.transforms as transforms
10
+ from PIL import Image
11
+
12
+
13
class ImagePathDataset(Dataset):
    """Dataset backed by a CSV index of image files.

    The CSV must contain a ``path`` column (relative to the CSV's own
    directory) and an integer ``label`` column. Each item yields
    ``(image_tensor, label, relative_path)`` where the tensor is a
    1x28x28 grayscale image in [0, 1].
    """

    def __init__(self, csv_path):
        self.base_dir = os.path.dirname(csv_path)

        index = pd.read_csv(csv_path)

        self.paths = list(index["path"])
        self.labels = [int(v) for v in index["label"]]

        # Grayscale + Resize guard against inputs that are not already
        # single-channel 28x28; ToTensor scales pixels to [0, 1].
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        relative = self.paths[idx]
        image = Image.open(os.path.join(self.base_dir, relative)).convert("L")
        return self.transform(image), self.labels[idx], relative
42
 
43
 
44
  class ModelCNN(nn.Module):
 
47
  INPUT (1x28x28) ->
48
  [CONV -> RELU -> CONV -> RELU -> POOL] * 3 ->
49
  [FC -> RELU] * 2 ->
50
+ FC (10)
51
  """
52
 
53
+ def __init__(self):
54
  super(ModelCNN, self).__init__()
55
 
56
  self.features = nn.Sequential(
 
79
  nn.ReLU(inplace=True),
80
  nn.Linear(256, 128),
81
  nn.ReLU(inplace=True),
82
+ nn.Linear(128, 10),
83
  )
84
 
85
  def forward(self, x):
 
93
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
  print(f"Устройство: {device}")
95
 
96
+ dataset = ImagePathDataset(args.dataset)
 
 
 
 
 
 
 
 
97
  dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
98
 
99
+ model = ModelCNN().to(device)
100
  criterion = nn.CrossEntropyLoss()
101
  optimizer = optim.Adam(model.parameters(), lr=args.lr)
102
 
103
  model.train()
104
  for epoch in range(args.epochs):
105
+ for i, (images, labels, _) in enumerate(dataloader):
106
  images = images.to(device)
107
  labels = labels.to(device)
108
 
 
118
  f"Epoch [{epoch+1}/{args.epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}"
119
  )
120
 
121
+ torch.save(model.state_dict(), args.model)
 
 
 
 
122
 
123
 
124
def inference_mode(args):
    """Predict labels for every image listed in ``args.input``.

    Loads model weights from ``args.model``, runs batched inference over
    the CSV-indexed images, and writes a ``path,label`` CSV to
    ``args.output`` (one row per input image, in input order).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    # NOTE(review): torch.load unpickles arbitrary data — only load
    # checkpoint files from trusted sources.
    state_dict = torch.load(args.model, map_location=device)

    model = ModelCNN().to(device)
    model.load_state_dict(state_dict)
    model.eval()

    loader = DataLoader(
        ImagePathDataset(args.input),
        batch_size=args.batch_size,
        shuffle=False,  # preserve CSV order so paths align with predictions
    )

    predictions = []
    paths = []

    with torch.no_grad():
        for images, _, batch_paths in loader:
            logits = model(images.to(device))
            _, batch_preds = torch.max(logits, 1)

            predictions.extend(batch_preds.cpu().numpy().tolist())
            paths.extend(batch_paths)

    df_pred = pd.DataFrame({"path": paths, "label": predictions})
    df_pred.to_csv(args.output, index=False)
156
 
157
 
 
159
  parser = argparse.ArgumentParser()
160
 
161
  parser.add_argument("--mode", choices=["train", "inference"], required=True)
162
+ parser.add_argument("--dataset", type=str)
163
  parser.add_argument("--input", type=str)
164
  parser.add_argument("--output", type=str)
165
  parser.add_argument("--model", type=str, required=True)
 
167
  parser.add_argument("--epochs", type=int, default=5)
168
  parser.add_argument("--batch-size", type=int, default=64)
169
  parser.add_argument("--lr", type=float, default=0.001)
 
170
 
171
  args = parser.parse_args()
172
 
173
  if args.mode == "train":
174
+ if args.dataset is None:
175
+ parser.error("--dataset обязателен в режиме train")
176
  elif args.mode == "inference":
177
  if args.input is None or args.output is None:
178
  parser.error("--input и --output обязательны в режиме inference")