Spaces:
Running
Running
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import matplotlib.pyplot as plt | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
# Device configuration: use a CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Streamlit interface
st.title("CNN for Image Classification using CIFAR-10")

# Hyperparameters, chosen interactively from the sidebar
# (positional arguments are: label, min, max, default).
num_epochs = st.sidebar.slider("Number of epochs", 1, 20, 10)
batch_size = st.sidebar.slider("Batch size", 10, 200, 100, step=10)
# Streamlit formats float sliders with "%0.2f" unless told otherwise, which
# would render every value in this range as "0.00" or "0.01". Four decimals
# match the 0.0001 step so the selected learning rate is readable. `format`
# only affects what is displayed, not the value returned.
learning_rate = st.sidebar.slider(
    "Learning rate", 0.0001, 0.01, 0.001, step=0.0001, format="%.4f"
)

# CIFAR-10 dataset
# Convert images to tensors, then normalize each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Both splits are stored under ./data and downloaded on demand.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Shuffle only the training split; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Define a Convolutional Neural Network
class CNN(nn.Module):
    """Small convolutional classifier that maps 3-channel images to 10 logits.

    Architecture:
        layer1: Conv(3 -> 32, 3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2)
        layer2: Conv(32 -> 64, 3x3, no padding) -> BatchNorm -> ReLU -> MaxPool(2)
        head:   Linear(flat -> 600) -> Dropout(0.25) -> Linear(600 -> 100)
                -> Linear(100 -> 10)

    The size of the flattened convolutional output (``_to_linear``) is
    measured once at construction time by pushing a dummy 1x3x32x32 tensor
    through the convolutional stages, so ``fc1`` is sized for 32x32 inputs
    (64 * 7 * 7 = 3136 features).
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Automatically determine the size of the flattened features after
        # convolution and pooling. The probe runs in eval mode and without
        # autograd so that the random dummy input neither updates the
        # BatchNorm running statistics nor builds a gradient graph.
        self._to_linear = None
        self.eval()
        with torch.no_grad():
            self.convs(torch.randn(1, 3, 32, 32))
        self.train()

        self.fc1 = nn.Linear(self._to_linear, 600)
        # fc1's output is a flat (batch, features) tensor, so element-wise
        # nn.Dropout is the right layer; nn.Dropout2d is intended for
        # channel-wise dropout on 3-D/4-D feature maps and warns on 2-D input.
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(600, 100)
        self.fc3 = nn.Linear(100, 10)

    def convs(self, x):
        """Run the two convolutional stages and return the feature map.

        On the first call (when ``_to_linear`` is still ``None``) the number
        of features per sample in the output is recorded in ``_to_linear``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        if self._to_linear is None:
            # Number of elements in one sample's feature map (C * H * W).
            self._to_linear = x[0].numel()
        return x

    def forward(self, x):
        """Return raw class logits of shape (batch, 10) for the batch ``x``."""
        x = self.convs(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        # NOTE(review): there is no non-linearity between fc1, fc2 and fc3,
        # so in eval mode the head collapses to a single affine map. Left
        # unchanged here to keep outputs compatible with existing checkpoints.
        x = self.fc1(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
model = CNN().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Button to start training
if st.button("Start Training"):
    # Per-epoch average losses, used for the plot at the end.
    train_losses = []
    test_losses = []

    # Train the model
    for epoch in range(num_epochs):
        # Accumulate the summed loss and sample count so the epoch average is
        # a true per-sample mean, even when the last batch is smaller.
        train_loss_sum = 0.0
        train_seen = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; scale back to a batch sum.
            n_batch = labels.size(0)
            train_loss_sum += loss.item() * n_batch
            train_seen += n_batch

        train_loss = train_loss_sum / train_seen
        train_losses.append(train_loss)
        st.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

        # Test the model after every epoch. eval() switches dropout and
        # batch-norm to inference behaviour; no_grad() skips autograd.
        model.eval()
        with torch.no_grad():
            test_loss_sum = 0.0
            correct = 0
            total = 0
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                n_batch = labels.size(0)
                test_loss_sum += loss.item() * n_batch
                # Predicted class is the index of the largest logit.
                predicted = outputs.argmax(dim=1)
                total += n_batch
                correct += (predicted == labels).sum().item()

            test_loss = test_loss_sum / total
            test_losses.append(test_loss)
            accuracy = 100 * correct / total
            st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
        # Back to training mode for the next epoch.
        model.train()

    # Plotting the loss
    epochs = range(1, num_epochs + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_losses, label='Train Loss')
    ax.plot(epochs, test_losses, label='Test Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Test Loss')
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure it creates alive until it is closed; release
    # it so figures do not accumulate across script reruns.
    plt.close(fig)

    # Save the model checkpoint
    torch.save(model.state_dict(), 'cnn_model.pth')
    st.write("Model training completed and saved as 'cnn_model.pth'")