DeepMostInnovations committed on
Commit
527c66c
·
verified ·
1 Parent(s): 2d58934

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -41,7 +41,7 @@ This is a reinforcement learning model trained to predict real-time sales conver
41
  - **Framework**: Stable Baselines3 (PPO)
42
  - **State Representation**: Azure OpenAI embeddings
43
  - **Action Space**: Continuous (conversion probability 0-1)
44
- - **Feature Extractor**: Custom CNN layers
45
 
46
  ## Quick Start
47
 
@@ -83,12 +83,12 @@ model_path = hf_hub_download(
83
  # Check for GPU
84
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
 
86
- # Custom CNN class
87
- class CustomCNN(BaseFeaturesExtractor):
88
  def __init__(self, observation_space, features_dim: int = 64):
89
  super().__init__(observation_space, features_dim)
90
  n_input_channels = observation_space.shape[0]
91
- self.cnn = nn.Sequential(
92
  nn.Linear(n_input_channels, 512),
93
  nn.ReLU(),
94
  nn.Linear(512, 256),
@@ -98,7 +98,7 @@ class CustomCNN(BaseFeaturesExtractor):
98
  ).to(device)
99
 
100
  def forward(self, observations: torch.Tensor) -> torch.Tensor:
101
- return self.cnn(observations)
102
 
103
  @dataclass
104
  class SalesAgent:
@@ -112,7 +112,7 @@ class SalesAgent:
112
  policy_kwargs = dict(
113
  activation_fn=nn.ReLU,
114
  net_arch=[dict(pi=[128, 64], vf=[128, 64])],
115
- features_extractor_class=CustomCNN,
116
  features_extractor_kwargs=dict(features_dim=64)
117
  )
118
 
 
41
  - **Framework**: Stable Baselines3 (PPO)
42
  - **State Representation**: Azure OpenAI embeddings
43
  - **Action Space**: Continuous (conversion probability 0-1)
44
+ - **Feature Extractor**: Custom Linear layers
45
 
46
  ## Quick Start
47
 
 
83
  # Check for GPU
84
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
 
86
+ # Custom feature extractor built from linear (fully connected) layers
87
+ class CustomLN(BaseFeaturesExtractor):
88
  def __init__(self, observation_space, features_dim: int = 64):
89
  super().__init__(observation_space, features_dim)
90
  n_input_channels = observation_space.shape[0]
91
+ self.ln = nn.Sequential(
92
  nn.Linear(n_input_channels, 512),
93
  nn.ReLU(),
94
  nn.Linear(512, 256),
 
98
  ).to(device)
99
 
100
  def forward(self, observations: torch.Tensor) -> torch.Tensor:
101
+ return self.ln(observations)
102
 
103
  @dataclass
104
  class SalesAgent:
 
112
  policy_kwargs = dict(
113
  activation_fn=nn.ReLU,
114
  net_arch=[dict(pi=[128, 64], vf=[128, 64])],
115
+ features_extractor_class=CustomLN,
116
  features_extractor_kwargs=dict(features_dim=64)
117
  )
118