File size: 1,021 Bytes
60fd15e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImprovedNet(nn.Module):
    """Fully-connected network: input -> 512 -> 256 -> 128 -> output.

    ReLU follows every hidden layer. A single shared ``nn.Dropout`` module is
    applied after the first two hidden activations only; the 128-wide layer
    feeds the output layer without dropout.

    Args:
        input_features: Size of the last dimension of the input tensor
            (``nn.Linear`` operates on the last dimension).
        output_features: Width of the final linear layer's output.
        dropout: Drop probability for the shared dropout module.
    """

    def __init__(self, input_features, output_features, dropout=0.30):
        super().__init__()
        first_width, second_width, third_width = 512, 256, 128
        # The attribute names below are the state_dict keys, and the creation
        # order determines the order parameters are randomly initialised, so
        # neither should be changed if saved weights / seeds must still match.
        self.layer1 = nn.Linear(input_features, first_width)
        self.layer2 = nn.Linear(first_width, second_width)
        self.layer3 = nn.Linear(second_width, third_width)
        self.output_layer = nn.Linear(third_width, output_features)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run the network and return the final linear layer's raw output.

        No activation is applied to the result (NOTE(review): presumably
        logits/regression outputs — confirm with the caller).
        """
        for hidden_layer in (self.layer1, self.layer2):
            x = self.dropout(F.relu(hidden_layer(x)))
        # NOTE(review): layer3 is not followed by dropout — confirm intentional.
        penultimate = F.relu(self.layer3(x))
        return self.output_layer(penultimate)
def load_model(weights_path, config_path="config.json", device="cpu", *, strict=True):
    """Build an ``ImprovedNet`` from a JSON config and load saved weights into it.

    Args:
        weights_path: File read with ``torch.load``; its contents are passed
            directly to ``load_state_dict`` (assumed to be a plain state_dict —
            TODO confirm against how checkpoints are saved).
        config_path: JSON file that must provide ``input_size`` and
            ``output_size`` and may provide ``dropout`` (0.30 when absent).
            A relative path is resolved against the current working directory.
        device: Passed as ``map_location`` while loading, and the model is
            moved to it afterwards.
        strict: Forwarded to ``load_state_dict``; ``True`` (the default)
            requires the checkpoint keys to match the model exactly.

    Returns:
        The ``ImprovedNet`` on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: If ``config_path`` or ``weights_path`` does not exist.
        json.JSONDecodeError: If the config file is not valid JSON.
        KeyError: If the config lacks ``input_size`` or ``output_size``.
    """
    import json

    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    # Fail fast with a readable message (still a KeyError, as before).
    missing = [key for key in ("input_size", "output_size") if key not in cfg]
    if missing:
        raise KeyError(
            f"{config_path}: missing required config key(s): {', '.join(missing)}"
        )

    # 0.30 mirrors ImprovedNet's own default dropout.
    model = ImprovedNet(
        cfg["input_size"], cfg["output_size"], dropout=cfg.get("dropout", 0.30)
    )

    # SECURITY: torch.load unpickles the file (unless the installed torch
    # defaults to weights_only=True), which can execute arbitrary code.
    # Only load checkpoints from trusted sources; consider passing
    # weights_only=True where the installed torch version supports it.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict, strict=strict)
    model.to(device)
    model.eval()  # disables dropout for inference
    return model
|