# Provenance: author zakaria-narjis, commit eee716c ("Update dependencies").
import torch
def save_actor_head(actor_model, file_path):
    """Persist the actor's head layers to ``file_path`` via ``torch.save``.

    Only the ``fc1``, ``fc2``, ``fc_mean`` and ``fc_logstd`` submodules are
    written. Each one's ``state_dict()`` is stored under a key equal to the
    attribute name, in that order.

    Args:
        actor_model: Object whose ``fc1``, ``fc2``, ``fc_mean`` and
            ``fc_logstd`` attributes each provide ``state_dict()``.
        file_path: Destination handed unchanged to ``torch.save`` (a path or
            a writable file-like object).
    """
    head_layers = ('fc1', 'fc2', 'fc_mean', 'fc_logstd')
    snapshot = {layer: getattr(actor_model, layer).state_dict() for layer in head_layers}
    torch.save(snapshot, file_path)
def load_actor_head(actor_model, file_path, device):
    """Restore the actor's head layers from a checkpoint written by ``save_actor_head``.

    The file is read with ``torch.load(..., weights_only=True)`` and tensors
    are mapped onto ``device``. Then ``fc1``, ``fc2``, ``fc_mean`` and
    ``fc_logstd`` each receive their saved ``state_dict`` through
    ``load_state_dict`` (default, strict loading).

    Args:
        actor_model: Object whose ``fc1``, ``fc2``, ``fc_mean`` and
            ``fc_logstd`` attributes each provide ``load_state_dict()``.
        file_path: Source passed to ``torch.load`` (path or readable file-like).
        device: ``map_location`` forwarded to ``torch.load``.

    Raises:
        KeyError: If the checkpoint lacks any of the expected layer entries.
            This is checked before any layer is modified, so the model is
            left untouched. Errors raised by ``load_state_dict`` itself
            (e.g. shape mismatches) are not pre-checked.
    """
    head_layers = ('fc1', 'fc2', 'fc_mean', 'fc_logstd')
    head_states = torch.load(file_path, map_location=device, weights_only=True)
    # Validate every key up front. Otherwise a truncated/mismatched checkpoint
    # would overwrite the earlier layers before failing on a later one.
    missing = [layer for layer in head_layers if layer not in head_states]
    if missing:
        raise KeyError(f"actor head checkpoint is missing entries: {missing}")
    for layer in head_layers:
        getattr(actor_model, layer).load_state_dict(head_states[layer])
def save_critic_head(critic_model, file_path):
    """Persist the critic's head layers to ``file_path`` via ``torch.save``.

    The ``fc1``, ``fc2`` and ``fc3`` submodules are written, each one's
    ``state_dict()`` stored under a key equal to the attribute name, in that
    order.

    Args:
        critic_model: Object whose ``fc1``, ``fc2`` and ``fc3`` attributes
            each provide ``state_dict()``.
        file_path: Destination handed unchanged to ``torch.save`` (a path or
            a writable file-like object).
    """
    head_layers = ('fc1', 'fc2', 'fc3')
    snapshot = {layer: getattr(critic_model, layer).state_dict() for layer in head_layers}
    torch.save(snapshot, file_path)
def load_critic_head(critic_model, file_path, device):
    """Restore the critic's head layers from a checkpoint written by ``save_critic_head``.

    The file is read with ``torch.load(..., weights_only=True)`` and tensors
    are mapped onto ``device``. Then ``fc1``, ``fc2`` and ``fc3`` each receive
    their saved ``state_dict`` through ``load_state_dict`` (default, strict
    loading).

    Args:
        critic_model: Object whose ``fc1``, ``fc2`` and ``fc3`` attributes each
            provide ``load_state_dict()``.
        file_path: Source passed to ``torch.load`` (path or readable file-like).
        device: ``map_location`` forwarded to ``torch.load``.

    Raises:
        KeyError: If the checkpoint lacks any of the expected layer entries.
            This is checked before any layer is modified, so the model is
            left untouched. Errors raised by ``load_state_dict`` itself
            (e.g. shape mismatches) are not pre-checked.
    """
    head_layers = ('fc1', 'fc2', 'fc3')
    head_states = torch.load(file_path, map_location=device, weights_only=True)
    # Validate every key up front. Otherwise a truncated/mismatched checkpoint
    # would overwrite the earlier layers before failing on a later one.
    missing = [layer for layer in head_layers if layer not in head_states]
    if missing:
        raise KeyError(f"critic head checkpoint is missing entries: {missing}")
    for layer in head_layers:
        getattr(critic_model, layer).load_state_dict(head_states[layer])