File size: 1,255 Bytes
f71c233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np

class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy policy over three actions.

    States are arbitrary hashable keys (see :meth:`get_state`). Each state maps
    to a list of three Q-values stored in ``q_table``; the value for action
    ``a`` (one of -1, 0, 1) lives at index ``a + 1``. What the actions mean is
    up to the caller.
    """

    # Discrete action set. Kept in ascending order so that ``index - 1`` maps
    # a Q-table column back to its action.
    _ACTIONS = (-1, 0, 1)

    def __init__(self, lr=0.1, gamma=0.95, epsilon=0.1, epsilon_decay=0.99):
        """Create an agent with an empty Q-table.

        Args:
            lr: Step size applied to the TD error in :meth:`update_q_values`.
            gamma: Discount factor applied to the best next-state Q-value.
            epsilon: Initial probability of picking a random action.
            epsilon_decay: Factor ``epsilon`` is multiplied by every time a
                random (exploratory) action is taken.
        """
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.initial_epsilon = epsilon  # remembered so callers can reset epsilon
        self.epsilon_decay = epsilon_decay
        self.q_table = {}

    @staticmethod
    def _to_scalar(value):
        """Unwrap tensor/NumPy scalars via ``.item()``; pass plain numbers through."""
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def _q_values(self, state):
        """Return the Q-value row for ``state``, creating a zero row if unseen."""
        return self.q_table.setdefault(state, [0, 0, 0])

    def get_state(self, val_loss, current_lr):
        """Discretise a (validation loss, learning rate) pair into a state key.

        Args:
            val_loss: Loss as a plain number or any object exposing ``.item()``
                (e.g. a 0-d tensor or NumPy scalar).
            current_lr: Learning rate, same accepted forms as ``val_loss``.

        Returns:
            ``(loss rounded to 2 decimals, lr rounded to 5 decimals)``.
        """
        loss_value = self._to_scalar(val_loss)
        lr_value = self._to_scalar(current_lr)
        return (round(loss_value, 2), round(lr_value, 5))

    def choose_action(self, state):
        """Pick an action for ``state`` using an epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action is returned and
        ``epsilon`` is decayed. Otherwise the action with the highest Q-value
        is returned; ties resolve to the lowest index, so an unseen state
        yields -1.

        Returns:
            One of -1, 0, 1 (as a NumPy integer).
        """
        if np.random.rand() < self.epsilon:
            # NOTE: epsilon only decays on exploratory draws, not on every call.
            self.epsilon *= self.epsilon_decay
            return np.random.choice(self._ACTIONS)
        # Shift the column index back into the -1..1 action range.
        return np.argmax(self._q_values(state)) - 1

    def update_q_values(self, state, action, reward, next_state):
        """Apply one Q-learning update for the observed transition.

        ``Q(s, a) += lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))``.
        Rows for ``state`` and ``next_state`` are created on demand.

        Args:
            state: Key of the state the action was taken in.
            action: Action taken; expected to be -1, 0 or 1.
            reward: Reward observed after taking ``action``.
            next_state: Key of the resulting state.
        """
        current_row = self._q_values(state)
        next_row = self._q_values(next_state)
        column = action + 1  # actions -1/0/1 are stored at indices 0/1/2
        td_target = reward + self.gamma * max(next_row)
        td_error = td_target - current_row[column]
        current_row[column] += self.lr * td_error