Spaces:
Build error
Build error
| import gradio as gr | |
| import spaces | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torchvision import datasets, transforms | |
| from torch.utils.data import DataLoader | |
| # from huggingface_hub import HfFolder | |
| # from huggingface_hub import Repository | |
| # HfFolder.save_token("<REDACTED>")  # never hard-code tokens: revoke the previously committed one and load it from a Space secret (e.g. os.environ["HF_TOKEN"]) | |
class SimpleNet(nn.Module):
    """Single fully connected layer: flattened 28x28 MNIST image -> 10 logits."""

    def __init__(self):
        super().__init__()
        # 784 = 28 * 28 input pixels, 10 output classes (one per digit).
        # The attribute name ``fc`` determines the state_dict keys.
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        """Reshape ``x`` into rows of 784 values and return unnormalized class scores."""
        flat = x.view(-1, 784)
        return self.fc(flat)
def train_model(epochs, *, model=None, dataset=None, checkpoint_path="simple_net.pth",
                batch_size=32, lr=0.001):
    """Train a classifier (MNIST ``SimpleNet`` by default) and save its weights.

    Args:
        epochs: Number of passes over the training set. Gradio sliders deliver
            numbers as floats, so the value is converted with ``int()`` before
            it is used as a loop bound.
        model: Network to train. A fresh ``SimpleNet`` is built when ``None``.
        dataset: Training dataset yielding ``(input, label)`` pairs. When
            ``None``, the MNIST training split is downloaded to ``./data``.
        checkpoint_path: File that the trained ``state_dict`` is written to.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Adam learning rate.

    Returns:
        A status message shown in the Gradio text output.

    Raises:
        ValueError: If the dataset yields no batches; the per-epoch average
            loss would otherwise divide by zero.
    """
    num_epochs = int(epochs)  # range() rejects floats such as 4.0

    # Data
    if dataset is None:
        transform = transforms.Compose([transforms.ToTensor()])
        dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("training dataset is empty")

    # Model, loss, and optimizer. Fall back to CPU so the function still runs
    # when CUDA is unavailable instead of crashing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        model = SimpleNet()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, Average Loss: {running_loss / num_batches}")

    # Save the model checkpoint
    torch.save(model.state_dict(), checkpoint_path)
    return "Training completed and model saved."
# Define the Gradio interface.
# `spaces` is imported at the top of the file so that the training function can be
# wrapped with spaces.GPU: on ZeroGPU hardware this attaches a GPU for the
# duration of each call. Per the `spaces` package docs it is a no-op elsewhere.
# NOTE(review): long runs (the slider allows up to 1000 epochs) may exceed the
# default ZeroGPU time budget; consider spaces.GPU(duration=...) — confirm limits.
demo = gr.Interface(
    fn=spaces.GPU(train_model),
    inputs=gr.Slider(2, 1000, value=4, step=1, label="Count", info="Choose epochs count"),
    outputs="text",
)
demo.launch()