import argparse
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader


def _ensure_parent_dir(path):
    """Create the directory that will contain *path*, if it names one."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


class ImagePathDataset(Dataset):
    """Dataset of images listed in a CSV file with ``path`` and ``label`` columns.

    Image paths in the CSV are resolved relative to the directory that
    contains the CSV file. Every image is converted to a single-channel
    28x28 float tensor.

    Args:
        csv_path: Path to the CSV index file.
        require_labels: When true (the default), a missing ``label`` column
            raises ``KeyError``. When false, a missing column is tolerated
            and every sample gets the placeholder label ``-1``; this lets
            unlabeled data be used for inference.
    """

    def __init__(self, csv_path, require_labels=True):
        self.base_dir = os.path.dirname(csv_path)

        df = pd.read_csv(csv_path)
        self.paths = df["path"].tolist()

        if "label" in df.columns:
            self.labels = df["label"].astype(int).tolist()
        elif require_labels:
            raise KeyError(f"column 'label' not found in {csv_path}")
        else:
            # Placeholder so __getitem__ keeps returning a 3-tuple.
            self.labels = [-1] * len(self.paths)

        self.transform = transforms.Compose(
            [
                transforms.Grayscale(num_output_channels=1),
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label, relative_path)`` for sample *idx*."""
        rel_path = self.paths[idx]
        full_path = os.path.join(self.base_dir, rel_path)

        # The context manager closes the underlying file handle; convert()
        # returns an independent in-memory image, so it is safe to use it
        # after the file has been closed.
        with Image.open(full_path) as raw:
            img = raw.convert("L")

        img = self.transform(img)
        label = self.labels[idx]
        return img, label, rel_path


class ModelCNN(nn.Module):
    """
    Architecture:
    INPUT (1x28x28)
    -> [CONV -> RELU -> CONV -> RELU -> POOL] * 3
    -> [FC -> RELU] * 2
    -> FC (num_classes, 10 by default)

    Args:
        num_classes: Size of the output layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.features = nn.Sequential(
            # block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 28 -> 14
            # block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 14 -> 7
            # block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 7 -> 3
        )

        self.classifier = nn.Sequential(
            nn.Linear(128 * 3 * 3, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a batch of (N, 1, 28, 28) images to (N, num_classes) logits."""
        x = self.features(x)
        # flatten() also works for non-contiguous tensors, unlike view().
        x = torch.flatten(x, 1)  # (N, 128*3*3)
        return self.classifier(x)


def train_mode(args):
    """Train ``ModelCNN`` on ``args.dataset`` and save its weights.

    Reads ``args.dataset``, ``args.batch_size``, ``args.lr``, ``args.epochs``
    and writes the model ``state_dict`` to ``args.model``.

    Raises:
        ValueError: If the dataset CSV lists no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    dataset = ImagePathDataset(args.dataset)
    if len(dataset) == 0:
        # A shuffling DataLoader rejects an empty dataset with a far less
        # helpful message; fail early with the same exception type.
        raise ValueError(f"dataset {args.dataset} contains no samples")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    model = ModelCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model.train()
    for epoch in range(args.epochs):
        for i, (images, labels, _) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{args.epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}"
                )

    _ensure_parent_dir(args.model)
    torch.save(model.state_dict(), args.model)


def inference_mode(args):
    """Predict labels for the images in ``args.input`` and write a CSV.

    Loads weights from ``args.model``, runs the model over every image
    listed in ``args.input`` (a ``label`` column is optional there) and
    writes ``path,label`` rows with the predicted class to ``args.output``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Устройство: {device}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # PyTorch version supports it).
    state_dict = torch.load(args.model, map_location=device)
    model = ModelCNN().to(device)
    model.load_state_dict(state_dict)
    model.eval()

    # Labels are not needed for prediction, so accept unlabeled CSV files.
    dataset = ImagePathDataset(args.input, require_labels=False)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    all_preds = []
    all_paths = []

    with torch.no_grad():
        for images, _, paths in dataloader:
            outputs = model(images.to(device))
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().tolist())
            all_paths.extend(paths)

    df_pred = pd.DataFrame(
        {
            "path": all_paths,
            "label": all_preds,
        }
    )
    _ensure_parent_dir(args.output)
    df_pred.to_csv(args.output, index=False)


def parse_args():
    """Parse and validate command-line arguments.

    ``--mode`` and ``--model`` are always required; ``--dataset`` is required
    for ``train`` and ``--input``/``--output`` for ``inference``. Invalid
    combinations terminate the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "inference"], required=True)
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--input", type=str)
    parser.add_argument("--output", type=str)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)

    args = parser.parse_args()

    if args.mode == "train":
        if args.dataset is None:
            parser.error("--dataset обязателен в режиме train")
    elif args.mode == "inference":
        if args.input is None or args.output is None:
            parser.error("--input и --output обязательны в режиме inference")

    return args


def main():
    """Entry point: dispatch to training or inference based on ``--mode``."""
    args = parse_args()
    if args.mode == "train":
        train_mode(args)
    else:
        inference_mode(args)


if __name__ == "__main__":
    main()