# Thermoforming / data_augment.py
# Last commit: 38a1679 by Rui Wan — "Remove unused model"
# File size: 895 bytes
import numpy as np
import torch
from Dataset import Dataset
from model import NeuralNetwork
from model_inverse import DEVICE
def load_model(model_path, device=None):
    """Load a trained ``NeuralNetwork`` from a checkpoint and put it in eval mode.

    Args:
        model_path: Path to a checkpoint dict readable by ``torch.load`` that
            holds ``'model_config'`` (itself providing ``'layer_sizes'`` and
            ``'dropout_rate'``) and ``'model_state_dict'``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved onto. ``None`` (the default) uses the module-level
            ``DEVICE`` imported from ``model_inverse``, which matches the
            previous hard-wired behaviour.

    Returns:
        The reconstructed model, moved to ``device`` and switched to
        evaluation mode.

    Raises:
        KeyError: If the checkpoint lacks one of the keys listed above.
    """
    if device is None:
        device = DEVICE

    # NOTE: torch.load unpickles the file — only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    # Rebuild the architecture from the saved config before loading weights,
    # so the state-dict keys and shapes line up.
    model_config = checkpoint['model_config']
    model = NeuralNetwork(model_config['layer_sizes'], dropout_rate=model_config['dropout_rate'])
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model loaded from {model_path}")

    model.to(device)
    # eval() puts dropout (configured via dropout_rate) into inference mode.
    model.eval()
    return model
def augment_data():
    """Set up the pieces needed for data augmentation.

    Constructs a ``Dataset``, loads the trained network stored at
    ``./model_ckpt.pth`` via ``load_model``, and packs the ``[lower, upper]``
    bounds of four process parameters (ply number, initial temperature,
    punch velocity, cooling time) into a float32 tensor of shape (4, 2).

    NOTE(review): the dataset, model and bounds are built but never used or
    returned, so this function currently returns None. Its only observable
    effects are those of ``Dataset()`` and ``load_model``. The augmentation
    step itself looks unimplemented — TODO confirm the intended behaviour.
    """
    dataset = Dataset()
    net = load_model('./model_ckpt.pth')

    # One [lower, upper] row per parameter; the row order defines the layout.
    bounds = torch.tensor(
        [
            [2., 8.],      # ply number
            [350., 450.],  # initial temperature
            [100., 500.],  # punch velocity
            [450., 550.],  # cooling time
        ],
        dtype=torch.float32,
    )