File size: 1,220 Bytes
6733993
 
 
 
 
2eebda5
6733993
 
 
 
 
 
 
 
0613d9c
 
6733993
 
 
2eebda5
 
6733993
2eebda5
 
6733993
0613d9c
6733993
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from MultiplicationNet import MultiplicationNet
from device import device

def generate_data(num_samples, min_val=0, max_val=100):
    """Build a random integer-multiplication dataset.

    Two independent columns of integers are drawn from NumPy's global RNG in
    the half-open range ``[min_val, max_val)`` (``randint``'s upper bound is
    exclusive), and the target is their element-wise product.

    Args:
        num_samples: Number of rows to generate.
        min_val: Inclusive lower bound for each factor.
        max_val: Exclusive upper bound for each factor.

    Returns:
        A tuple ``(features, products)``, where ``features`` has shape
        ``(num_samples, 2)`` holding the two factors and ``products`` has
        shape ``(num_samples, 1)`` holding their product.
    """
    column_shape = (num_samples, 1)
    # The first factor is drawn before the second, so a seeded RNG yields a
    # stable dataset.
    left = np.random.randint(min_val, max_val, size=column_shape)
    right = np.random.randint(min_val, max_val, size=column_shape)
    features = np.concatenate((left, right), axis=1)
    return features, left * right

def train(
    num_samples=10000,
    num_epochs=30000,
    learning_rate=0.01,
    *,
    lr_step_size=100,
    lr_gamma=0.95,
    log_every=1,
    model_path="model.pth",
    seed=None,
):
    """Train a MultiplicationNet to predict the product of two integers.

    A synthetic dataset is generated with ``generate_data``, then full-batch
    Adam optimisation is run on MSE loss under a StepLR learning-rate
    schedule. Finally the whole model object is saved with ``torch.save``.
    All defaults reproduce the original hard-coded configuration.

    Args:
        num_samples: Number of (x1, x2) training pairs to generate.
        num_epochs: Number of full-batch optimisation steps.
        learning_rate: Initial Adam learning rate.
        lr_step_size: StepLR period, in epochs.
        lr_gamma: Multiplicative LR decay applied every ``lr_step_size``
            epochs.
        log_every: Print the loss every this many epochs. The final epoch is
            always printed. ``1`` keeps the original per-epoch output.
        model_path: Destination file passed to ``torch.save``.
        seed: If not ``None``, seeds NumPy's global RNG (used by
            ``generate_data``) and torch's RNG (used for weight init) so runs
            are reproducible.

    Returns:
        The trained model (already saved to ``model_path``).

    Raises:
        ValueError: If ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    x, y = generate_data(num_samples)
    x_train = torch.tensor(x, dtype=torch.float).to(device)
    # NOTE(review): assumes MultiplicationNet outputs shape (num_samples, 1)
    # to match y_train. MSELoss broadcasts (with a warning) on mismatch.
    # TODO confirm.
    y_train = torch.tensor(y, dtype=torch.float).to(device)

    model = MultiplicationNet().to(device)
    # MSELoss has no parameters or buffers, so moving it to a device is a no-op.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): with the defaults the LR at epoch e is
    # 0.01 * 0.95 ** (e // 100). By epoch 30000 that is about 2e-9, so late
    # epochs barely move the weights. Revisit if training stalls.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=lr_step_size, gamma=lr_gamma
    )

    for epoch in range(num_epochs):
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        # loss.item() copies the scalar back to the host (a device sync on
        # GPU), so only pay for it on logging epochs.
        if epoch % log_every == 0 or epoch == num_epochs - 1:
            print(f"Epoch {epoch}, loss = {loss.item()}")

    # This pickles the entire module, not just its state_dict. Loading it
    # therefore requires MultiplicationNet to be importable.
    # NOTE(review): newer torch.load versions default to weights_only=True
    # and may need weights_only=False to read this file. Confirm against the
    # loader.
    torch.save(model, model_path)
    return model

if __name__ == "__main__":
    # Only kick off training when executed as a script, not on import.
    train()