| | import numpy as np |
| |
|
class QLearningAgent:
    """Tabular Q-learning agent with three discrete actions: -1, 0 and +1.

    States are ``(rounded validation loss, rounded current lr)`` tuples built
    by :meth:`get_state`. Each known state maps to a list of three Q-values,
    indexed by ``action + 1``. The meaning of an action (presumably a
    learning-rate adjustment direction, given the state contents) is decided
    by the caller; verify against call sites.
    """

    # Discrete action set; the Q-row index of an action is ``action + 1``.
    ACTIONS = (-1, 0, 1)

    def __init__(self, lr=0.1, gamma=0.95, epsilon=0.1, epsilon_decay=0.99):
        """Create an agent with an empty Q-table.

        Args:
            lr: Step size used in the temporal-difference update.
            gamma: Discount factor applied to the next state's best Q-value.
            epsilon: Initial probability of taking a random action.
            epsilon_decay: Multiplicative factor applied to ``epsilon`` each
                time a random (exploratory) action is taken.
        """
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        # Starting exploration rate. Not read anywhere in this class;
        # presumably kept so callers can reset exploration (verify callers).
        self.initial_epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.q_table = {}

    def _ensure_state(self, state):
        """Return the Q-value row for ``state``, creating a zero row if unseen."""
        return self.q_table.setdefault(state, [0.0, 0.0, 0.0])

    def get_state(self, val_loss, current_lr):
        """Discretise the observation into a hashable Q-table key.

        Args:
            val_loss: Scalar validation loss. Objects exposing ``.item()``
                (e.g. a 0-d tensor or a NumPy scalar) are unwrapped first;
                plain Python numbers are used directly.
            current_lr: Current learning rate as a number.

        Returns:
            ``(loss rounded to 2 decimals, lr rounded to 5 decimals)``.
        """
        # The original required ``.item()``, which fails on plain floats.
        loss = val_loss.item() if hasattr(val_loss, "item") else val_loss
        return (round(float(loss), 2), round(float(current_lr), 5))

    def choose_action(self, state):
        """Pick an action for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the greedy action (highest Q-value) is returned, and an
        unseen ``state`` is first added to the table with zero Q-values.

        Returns:
            One of ``-1``, ``0`` or ``1`` as a plain ``int``.
        """
        if np.random.rand() < self.epsilon:
            # NOTE(review): epsilon decays only on exploratory steps, so the
            # decay slows as epsilon shrinks. Preserved as-is; confirm intent.
            self.epsilon *= self.epsilon_decay
            return int(np.random.choice(self.ACTIONS))
        q_values = self._ensure_state(state)
        # np.argmax breaks ties toward the first index, so an all-zero
        # (unseen) row always yields action -1.
        return int(np.argmax(q_values)) - 1

    def update_q_values(self, state, action, reward, next_state, done=False):
        """Apply one Q-learning update for the transition (state, action, reward, next_state).

        Args:
            state: Key of the state the action was taken in.
            action: The action taken; must be -1, 0 or 1.
            reward: Scalar reward observed for the transition.
            next_state: Key of the resulting state.
            done: If True, treat ``next_state`` as terminal and do not
                bootstrap from its Q-values. Defaults to False, which matches
                the previous behaviour.

        Raises:
            ValueError: If ``action`` is not one of -1, 0, 1. Without this
                check, -2 would silently update the wrong slot via negative
                indexing.
        """
        if action not in self.ACTIONS:
            raise ValueError(f"action must be one of {self.ACTIONS}, got {action!r}")
        q_values = self._ensure_state(state)
        next_q_values = self._ensure_state(next_state)
        bootstrap = 0.0 if done else self.gamma * max(next_q_values)
        td_target = reward + bootstrap
        index = int(action) + 1
        td_error = td_target - q_values[index]
        q_values[index] += self.lr * td_error
| |
|