File size: 1,880 Bytes
012f016
 
 
d6783ca
 
 
 
e0a0101
 
 
012f016
d6783ca
 
 
 
 
 
 
 
 
012f016
2b9cd67
d6783ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
012f016
d6783ca
17b871f
d6783ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
import spaces
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# from huggingface_hub import HfFolder
# from huggingface_hub import Repository
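# HfFolder.save_token(os.environ["HF_TOKEN"])  # never hard-code credentials: a token literal was previously committed on this line; rotate it and load secrets from the environment instead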

class SimpleNet(nn.Module):
    """Single fully connected layer classifier.

    Flattens each input sample to ``in_features`` values and maps it to
    ``num_classes`` logits. The defaults (784 -> 10) match flattened 28x28
    MNIST digits, as used by ``train_model``.

    Args:
        in_features: Number of values per flattened input sample.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, num_classes=10):
        super().__init__()
        # Attribute name kept as ``fc`` so saved state_dict keys stay
        # ``fc.weight`` / ``fc.bias`` and existing checkpoints still load.
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for input ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``in_features``; it is flattened to ``(-1, in_features)``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g.
        # permuted or channels_last inputs, copying only when required.
        x = x.reshape(-1, self.fc.in_features)
        return self.fc(x)

@spaces.GPU(duration=300)
def train_model(epochs):
    """Train a ``SimpleNet`` on the MNIST training split and save its weights.

    Args:
        epochs: Number of passes over the training set. UI widgets may deliver
            this as a float (e.g. ``4.0``), so it is coerced to ``int``.

    Returns:
        A status string shown in the Gradio text output.

    Side effects:
        Downloads MNIST into ``./data`` if needed (``download=True``), prints
        the average loss of each epoch, and writes ``simple_net.pth``.
    """
    # NOTE(review): the decorator presumably reserves a Hugging Face ZeroGPU
    # device for at most 300 s; the UI slider allows up to 1000 epochs, which
    # likely cannot finish within that budget — confirm the intended limits.
    epochs = int(epochs)  # range() rejects floats such as 4.0

    # Prefer CUDA but fall back to CPU so the function also works on hosts
    # without a GPU (the original hard-coded .cuda() and crashed there).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MNIST training data.
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Model, loss, and optimizer.
    model = SimpleNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Nothing switches the model to eval mode, so training mode is set once.
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Goes to the server console only; the UI receives the final status.
        print(f"Epoch {epoch + 1}, Average Loss: {running_loss / len(train_loader)}")

    # Save the model checkpoint.
    torch.save(model.state_dict(), "simple_net.pth")
    return "Training completed and model saved."

# Gradio front end: one slider selects the epoch count passed to
# train_model, and the returned status string is displayed as text.
demo = gr.Interface(
    fn=train_model,
    inputs=gr.Slider(
        minimum=2,
        maximum=1000,
        value=4,
        label="Count",
        info="Choose epochs count",
    ),
    outputs="text",
)
demo.launch()