Spaces:
Sleeping
Sleeping
| import torch | |
def save_actor_head(actor_model, file_path):
    """Write the actor's head layers to ``file_path`` as one checkpoint.

    The ``state_dict()`` of ``actor_model.fc1``, ``fc2``, ``fc_mean`` and
    ``fc_logstd`` are gathered, in that order, into a single dict keyed by
    the attribute names. That dict is written with ``torch.save``.

    Args:
        actor_model: Object exposing ``fc1``, ``fc2``, ``fc_mean`` and
            ``fc_logstd`` attributes that each provide ``state_dict()``.
        file_path: Destination forwarded unchanged to ``torch.save``.
    """
    head_layers = ('fc1', 'fc2', 'fc_mean', 'fc_logstd')
    snapshot = {
        layer: getattr(actor_model, layer).state_dict()
        for layer in head_layers
    }
    torch.save(snapshot, file_path)
def load_actor_head(actor_model, file_path, device):
    """Restore the actor head layers from a checkpoint written by ``save_actor_head``.

    The file is read with ``torch.load(..., map_location=device,
    weights_only=True)``. The entries ``'fc1'``, ``'fc2'``, ``'fc_mean'`` and
    ``'fc_logstd'`` are then loaded into the same-named attributes of
    ``actor_model`` via ``load_state_dict``.

    Args:
        actor_model: Object exposing ``fc1``, ``fc2``, ``fc_mean`` and
            ``fc_logstd`` attributes that each provide ``load_state_dict()``.
        file_path: Checkpoint location forwarded to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Raises:
        KeyError: If the checkpoint lacks any of the four expected entries.
            This is checked before any layer is modified, so the model is
            left untouched in that case.
    """
    head_state = torch.load(file_path, map_location=device, weights_only=True)
    layer_names = ('fc1', 'fc2', 'fc_mean', 'fc_logstd')
    # Validate up front: loading layer by layer and only then hitting a
    # missing key would leave the earlier layers already overwritten.
    missing = [name for name in layer_names if name not in head_state]
    if missing:
        raise KeyError(
            f"checkpoint {file_path!r} is missing actor head entries: {missing}"
        )
    # Errors raised by load_state_dict itself (e.g. shape mismatches) still
    # surface here and may occur after earlier layers were loaded.
    for name in layer_names:
        getattr(actor_model, name).load_state_dict(head_state[name])
def save_critic_head(critic_model, file_path):
    """Write the critic's head layers to ``file_path`` as one checkpoint.

    The ``state_dict()`` of ``critic_model.fc1``, ``fc2`` and ``fc3`` are
    gathered, in that order, into a dict keyed ``'fc1'``, ``'fc2'`` and
    ``'fc3'``. That dict is written with ``torch.save``.

    Args:
        critic_model: Object exposing ``fc1``, ``fc2`` and ``fc3`` attributes
            that each provide ``state_dict()``.
        file_path: Destination forwarded unchanged to ``torch.save``.
    """
    checkpoint = dict(
        fc1=critic_model.fc1.state_dict(),
        fc2=critic_model.fc2.state_dict(),
        fc3=critic_model.fc3.state_dict(),
    )
    torch.save(checkpoint, file_path)
def load_critic_head(critic_model, file_path, device):
    """Restore the critic head layers from a checkpoint written by ``save_critic_head``.

    The file is read with ``torch.load(..., map_location=device,
    weights_only=True)``. The entries ``'fc1'``, ``'fc2'`` and ``'fc3'`` are
    then loaded into the same-named attributes of ``critic_model`` via
    ``load_state_dict``.

    Args:
        critic_model: Object exposing ``fc1``, ``fc2`` and ``fc3`` attributes
            that each provide ``load_state_dict()``.
        file_path: Checkpoint location forwarded to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Raises:
        KeyError: If the checkpoint lacks any of the three expected entries.
            This is checked before any layer is modified, so the model is
            left untouched in that case.
    """
    head_state = torch.load(file_path, map_location=device, weights_only=True)
    layer_names = ('fc1', 'fc2', 'fc3')
    # Validate up front: loading layer by layer and only then hitting a
    # missing key would leave the earlier layers already overwritten.
    missing = [name for name in layer_names if name not in head_state]
    if missing:
        raise KeyError(
            f"checkpoint {file_path!r} is missing critic head entries: {missing}"
        )
    # Errors raised by load_state_dict itself (e.g. shape mismatches) still
    # surface here and may occur after earlier layers were loaded.
    for name in layer_names:
        getattr(critic_model, name).load_state_dict(head_state[name])