File size: 717 Bytes
a329468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    """
    Meta-Controller using DQN (Deep Q-Network).

    A three-layer MLP (two LeakyReLU hidden layers) that maps a state vector
    to one Q-value per action.

    Input: [Volatility(1), Market_Regime(3 - OneHot), Global_PnL_Trend(1)] -> 5 Dim
    Output: Q-Values for Actions (3)
       0: FollowTrend Agent
       1: MeanReversion Agent
       2: Defensive Mode (Cash)

    Args:
        input_dim: Size of the state vector (the last dimension of ``x``).
        output_dim: Number of actions, i.e. Q-values produced per state.
        hidden_dim: Width of both hidden layers. Defaults to 64, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim=5, output_dim=3, hidden_dim=64):
        super().__init__()
        # Keep this Sequential layout stable: the module indices (0, 2, 4)
        # appear in state_dict keys, so changing it would break loading of
        # previously saved checkpoints.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Return Q-values with shape ``(*x.shape[:-1], output_dim)``.

        ``nn.Linear`` acts on the last dimension, so ``x`` may be a single
        state ``(input_dim,)`` or a batch ``(..., input_dim)``.
        """
        return self.net(x)