# laba2/mnist.py
# Author: Shoker2
# refactor: logic reworked (commit e3963a4)
import pandas as pd
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
class ImagePathDataset(Dataset):
    """Dataset that reads (path, label) rows from a CSV file.

    The CSV must contain a ``path`` column (image paths relative to the
    CSV file's directory) and a ``label`` column with integer class ids.
    Each item is ``(image_tensor, label, relative_path)``, where the
    image is a 1x28x28 grayscale float tensor in [0, 1].
    """

    def __init__(self, csv_path):
        # Image paths in the CSV are resolved relative to the CSV itself.
        self.base_dir = os.path.dirname(csv_path)
        df = pd.read_csv(csv_path)
        # Fail fast with a clear message instead of a KeyError below.
        missing = {"path", "label"} - set(df.columns)
        if missing:
            raise ValueError(f"CSV {csv_path!r} is missing columns: {sorted(missing)}")
        self.paths = df["path"].tolist()
        self.labels = df["label"].astype(int).tolist()
        self.transform = transforms.Compose(
            [
                transforms.Grayscale(num_output_channels=1),
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        rel_path = self.paths[idx]
        full_path = os.path.join(self.base_dir, rel_path)
        # PIL opens files lazily; use a context manager so the file
        # descriptor is closed promptly instead of leaking per item.
        with Image.open(full_path) as img:
            tensor = self.transform(img.convert("L"))
        label = self.labels[idx]
        return tensor, label, rel_path
class ModelCNN(nn.Module):
    """Small VGG-style CNN for 1x28x28 grayscale digit images.

    Architecture:
        INPUT (1x28x28) ->
        [CONV -> RELU -> CONV -> RELU -> POOL] * 3 ->
        [FC -> RELU] * 2 ->
        FC (10)
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_ch, out_ch):
            # One CONV-RELU-CONV-RELU-POOL stage; the 2x2 max-pool halves
            # the spatial size (rounding down for odd inputs: 7 -> 3).
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]

        # Spatial size across the three stages: 28 -> 14 -> 7 -> 3.
        # Module order matches the original Sequential exactly, so
        # state_dict keys (features.0 ... features.14) are unchanged.
        self.features = nn.Sequential(
            *conv_stage(1, 32),
            *conv_stage(32, 64),
            *conv_stage(64, 128),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 3 * 3, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return raw class logits of shape (N, 10)."""
        feats = self.features(x)
        flat = torch.flatten(feats, start_dim=1)  # (N, 128*3*3)
        return self.classifier(flat)
def train_mode(args):
    """Train ModelCNN on the CSV dataset named by ``args.dataset``.

    Uses Adam with ``args.lr`` for ``args.epochs`` epochs, prints a
    progress line every 100 steps, and saves the trained weights
    (state_dict) to ``args.model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    loader = DataLoader(
        ImagePathDataset(args.dataset),
        batch_size=args.batch_size,
        shuffle=True,
    )
    total_steps = len(loader)

    model = ModelCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model.train()
    for epoch in range(args.epochs):
        # enumerate from 1 so `step` is the human-facing step number.
        for step, (images, labels, _) in enumerate(loader, start=1):
            images, labels = images.to(device), labels.to(device)

            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{args.epochs}], Step [{step}/{total_steps}], Loss: {loss.item():.4f}"
                )

    torch.save(model.state_dict(), args.model)
def inference_mode(args):
    """Predict labels for every image listed in the ``args.input`` CSV.

    Loads weights from ``args.model`` and writes a CSV to ``args.output``
    with the original ``path`` column and the predicted ``label``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    model = ModelCNN().to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.model, map_location=device))
    model.eval()

    loader = DataLoader(
        ImagePathDataset(args.input),
        batch_size=args.batch_size,
        shuffle=False,
    )

    predictions, rel_paths = [], []
    with torch.no_grad():
        for images, _, batch_paths in loader:
            logits = model(images.to(device))
            batch_preds = logits.argmax(dim=1)
            predictions.extend(batch_preds.cpu().numpy().tolist())
            rel_paths.extend(batch_paths)

    pd.DataFrame({"path": rel_paths, "label": predictions}).to_csv(
        args.output, index=False
    )
def parse_args():
    """Parse CLI arguments and enforce mode-specific required flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "inference"], required=True)
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--input", type=str)
    parser.add_argument("--output", type=str)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    args = parser.parse_args()

    # Cross-argument validation: each mode has its own required flags.
    # `--mode` is restricted to these two choices, so plain ifs suffice.
    if args.mode == "train" and args.dataset is None:
        parser.error("--dataset обязателен в режиме train")
    if args.mode == "inference" and (args.input is None or args.output is None):
        parser.error("--input и --output обязательны в режиме inference")
    return args
def main():
    """Dispatch to training or inference based on the ``--mode`` flag."""
    args = parse_args()
    # `--mode` is constrained by argparse choices, so the lookup is safe.
    handlers = {"train": train_mode, "inference": inference_mode}
    handlers[args.mode](args)


if __name__ == "__main__":
    main()