import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from MultiplicationNet import MultiplicationNet
from device import device
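# MultiplicationNet and device come from local modules that are not part
# of this file. As a rough sketch (hypothetical -- the real architecture
# lives in MultiplicationNet.py), MultiplicationNet could be a small MLP
# mapping the two input factors to their product:
#
#   class MultiplicationNet(nn.Module):
#       def __init__(self, hidden=64):
#           super().__init__()
#           self.net = nn.Sequential(
#               nn.Linear(2, hidden), nn.ReLU(),
#               nn.Linear(hidden, hidden), nn.ReLU(),
#               nn.Linear(hidden, 1),
#           )
#
#       def forward(self, x):
#           return self.net(x)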

def generate_data(num_samples, min_val=0, max_val=100):
    """Sample pairs of random integers and their products as training data."""
    x1 = np.random.randint(min_val, max_val, size=(num_samples, 1))
    x2 = np.random.randint(min_val, max_val, size=(num_samples, 1))
    y = x1 * x2
    return np.hstack([x1, x2]), y


def train():
    num_samples = 10000
    num_epochs = 30000
    learning_rate = 0.01

    # Build the training set once and move it to the target device up front.
    x, y = generate_data(num_samples)
    x_train = torch.tensor(x, dtype=torch.float).to(device)
    y_train = torch.tensor(y, dtype=torch.float).to(device)

    model = MultiplicationNet().to(device)
    criterion = nn.MSELoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Decay the learning rate by 5% every 100 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)

    # Full-batch training: each epoch is one forward/backward pass over
    # the entire dataset.
    for epoch in range(num_epochs):
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f"Epoch {epoch}, loss = {loss.item()}")

    # Saves the whole model object (architecture + weights) rather than a
    # state_dict, so MultiplicationNet must be importable when loading.
    torch.save(model, "model.pth")


if __name__ == '__main__':
    train()
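
# Usage sketch (hypothetical; the names below are illustrative, not part of
# this script): load the saved model and query it on a new pair of factors.
# Note that on PyTorch >= 2.6, torch.load needs weights_only=False to
# unpickle a full model object saved with torch.save(model, ...).
#
#   model = torch.load("model.pth", weights_only=False).to(device)
#   model.eval()
#   with torch.no_grad():
#       pred = model(torch.tensor([[7.0, 6.0]], device=device))
#   print(pred.item())  # ideally close to 42.0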