# photo-enhancer / src / sac / sac_inference.py
# (author: zakaria-narjis — commit eee716c "Update dependencies")
from .sac_networks import Actor, SoftQNetwork, ResNETBackbone, SemanticBackbone, SemanticBackboneOC
import torch
from .utils import *
class InferenceAgent:
    """Inference-time SAC agent.

    Loads a frozen feature backbone plus actor/twin-critic heads, and selects
    actions either deterministically (actor mean) or by sampling several
    candidate actions and keeping, per example, the one the critics score
    highest.
    """

    def __init__(self, inference_env, inference_args):
        """Pick and build the feature backbone for the environment.

        Args:
            inference_env: environment exposing `use_txt_features`,
                `discretize` and `discretize_step` (project type).
            inference_args: config with `device` and `n_actions_samples`.
        """
        self.args = inference_args
        self.device = inference_args.device
        # Backbone must match how the policy was trained: text-embedding
        # semantic, one-hot semantic, or plain ResNet features.
        if inference_env.use_txt_features == 'embedded':
            self.backbone = SemanticBackbone().to(self.device)
        elif inference_env.use_txt_features == 'one_hot':
            self.backbone = SemanticBackboneOC().to(self.device)
        else:
            self.backbone = ResNETBackbone().to(self.device)
        self.env = inference_env
        self.discretize = self.env.discretize
        self.discretize_step = self.env.discretize_step

    def discretize_actions(self, batch_actions):
        """Snap actions in [-1, 1] to the nearest multiple of
        `discretize_step` on the grid anchored at -1."""
        return torch.round((batch_actions + 1) / self.discretize_step) * self.discretize_step - 1

    def load_backbone(self, backbone_path):
        """Load frozen backbone weights, then build the actor and twin
        critics that share this backbone."""
        self.backbone.load_state_dict(
            torch.load(backbone_path, map_location=self.device, weights_only=True))
        self.actor = Actor(self.env, self.backbone).to(self.device)
        self.qf1 = SoftQNetwork(self.env, self.backbone).to(self.device)
        self.qf2 = SoftQNetwork(self.env, self.backbone).to(self.device)
        self.backbone.eval().requires_grad_(False)

    def load_actor_weights(self, actor_path):
        """Load the actor head weights and switch the actor to eval mode."""
        load_actor_head(self.actor, actor_path, self.device)
        self.actor.eval()

    def load_critics_weights(self, qf1_path, qf2_path):
        """Load both critic heads, then freeze them."""
        load_critic_head(self.qf1, qf1_path, self.device)
        load_critic_head(self.qf2, qf2_path, self.device)
        self.qf1.eval().requires_grad_(False)
        self.qf2.eval().requires_grad_(False)

    def act(self, obs, deterministic=True, n_samples=None):
        """Choose actions for a batch of observations.

        With `deterministic=True`, returns the actor's mean action.
        Otherwise draws one initial sample, then evaluates the mean action
        plus `n_samples - 1` further samples, keeping per example the
        candidate with the highest twin-critic value.

        Args:
            obs: batched observation exposing `.to(device)` that returns a
                mapping unpacked into the networks (project type).
            deterministic: use the actor mean instead of sampled candidates.
            n_samples: candidate-loop iterations; defaults to
                `args.n_actions_samples`.

        Returns:
            Tensor of selected actions (discretized if `self.discretize`).
        """
        if n_samples is None:  # fixed: was `== None` (PEP 8: identity check)
            n_samples = self.args.n_actions_samples
        with torch.inference_mode():
            if deterministic:
                # get_action(...)[2] is the mean action.
                best_actions = self.actor.get_action(**obs.to(self.device))[2]
            else:
                # get_action(...)[0] is a stochastic sample; start from one.
                best_actions = self.actor.get_action(**obs.to(self.device))[0]
                best_values = self.critic(obs, best_actions).view(-1)
                for sample in range(n_samples):
                    if sample == 0:
                        # First loop candidate is the mean action ([2]),
                        # so the mean always competes with the samples.
                        actions = self.actor.get_action(**obs.to(self.device))[2]
                    else:
                        actions = self.actor.get_action(**obs.to(self.device))[0]
                    values = self.critic(obs, actions).view(-1)
                    # Per-example improvement mask, computed once per iteration.
                    better = values > best_values
                    if better.any():
                        best_actions[better] = actions[better]
                        best_values[better] = values[better]
        if self.discretize:
            best_actions = self.discretize_actions(best_actions)
        return best_actions

    def critic(self, obs, actions):
        """Conservative action value: elementwise min of the twin critics."""
        with torch.inference_mode():
            qf1_pi = self.qf1(**obs.to(self.device), actions=actions.to(self.device))
            qf2_pi = self.qf2(**obs.to(self.device), actions=actions.to(self.device))
            value = torch.min(qf1_pi, qf2_pi)
        return value