Spaces:
Build error
Build error
File size: 2,046 Bytes
012f016 d6783ca d5de549 d4c8be9 3a1f1e7 d4c8be9 012f016 d6783ca 012f016 d6783ca 6635d9e d5de549 d6783ca 012f016 d6783ca 17b871f d6783ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | import gradio as gr
import spaces
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from huggingface_hub import notebook_login
from huggingface_hub import HfFolder
from huggingface_hub import Repository
import os

# Authenticate to the Hugging Face Hub with a token supplied at runtime
# (e.g. an ``HF_TOKEN`` Space secret) instead of a literal in source code:
# the files of a public Space are readable by anyone. Hugging Face user
# tokens start with "hf_"; a "ghp_" value is a GitHub token and does not
# belong here. When no token is configured, nothing is saved.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    HfFolder.save_token(_hf_token)
class SimpleNet(nn.Module):
    """Single fully-connected layer mapping flattened images to class logits.

    The defaults match MNIST: 784 input features (a flattened 28x28 image)
    and 10 output classes.

    Args:
        in_features: Number of features after flattening each sample.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, num_classes=10):
        super().__init__()
        # Plain int attribute: not part of state_dict, so saved checkpoints
        # keep exactly the keys "fc.weight" and "fc.bias".
        self.in_features = in_features
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the logits.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. transposed or strided tensors) are accepted; it only copies
        when a view is impossible.
        """
        x = x.reshape(-1, self.in_features)
        return self.fc(x)
@spaces.GPU
def train_model(epochs):
    """Train ``SimpleNet`` on MNIST, save the weights, and push them via git.

    Args:
        epochs: Number of passes over the MNIST training split. The value
            comes from a Gradio slider and is coerced to ``int`` because it
            is passed to ``range``.

    Returns:
        A status message shown in the Gradio text output.
    """
    epochs = int(epochs)
    # Fall back to CPU so the function still runs where CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MNIST (downloaded to ./data on first use) as [0, 1] tensors.
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    model = SimpleNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch + 1}, Average Loss: {avg_loss}")

    _save_and_push_checkpoint(model, "simple_net.pth")
    return "Training completed and model saved."


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _save_and_push_checkpoint(model, path):
    """Save ``model``'s weights to ``path`` and commit/push them with git."""
    # Store CPU tensors so the checkpoint loads on machines without a GPU
    # without needing ``map_location``.
    cpu_state = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    torch.save(cpu_state, path)
    # ``path`` is relative to the working directory, so the git repository
    # must be rooted there as well (not at the filesystem root "/").
    repo = Repository(local_dir=".")
    # ``Repository.git_add`` takes ``pattern``; ``auto_lfs_track`` puts the
    # binary checkpoint under git-lfs (the Hub may reject plain binary pushes).
    repo.git_add(pattern=path, auto_lfs_track=True)
    repo.git_commit("model checkpoint")
    repo.git_push()
# Gradio UI: a slider chooses the epoch count passed to train_model, and the
# returned status string is displayed as text.
# NOTE(review): @spaces.GPU uses a limited default time budget; epoch counts
# near the 1000 maximum likely exceed it — consider a lower maximum or an
# explicit ``duration``. Confirm against the ZeroGPU limits.
demo = gr.Interface(
    fn=train_model,
    # step=1 keeps the value on whole epochs, since it is fed to range().
    inputs=gr.Slider(2, 1000, value=4, step=1, label="Count", info="Choose epochs count"),
    outputs="text",
)

# Start the server only when executed as a script, so importing this module
# (tests, Gradio reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|