Update README.md
Browse files
README.md
CHANGED
|
@@ -41,7 +41,7 @@ This is a reinforcement learning model trained to predict real-time sales conver
|
|
| 41 |
- **Framework**: Stable Baselines3 (PPO)
|
| 42 |
- **State Representation**: Azure OpenAI embeddings
|
| 43 |
- **Action Space**: Continuous (conversion probability 0-1)
|
| 44 |
-
- **Feature Extractor**: Custom
|
| 45 |
|
| 46 |
## Quick Start
|
| 47 |
|
|
@@ -83,12 +83,12 @@ model_path = hf_hub_download(
|
|
| 83 |
# Check for GPU
|
| 84 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 85 |
|
| 86 |
-
# Custom
|
| 87 |
-
class
|
| 88 |
def __init__(self, observation_space, features_dim: int = 64):
|
| 89 |
super().__init__(observation_space, features_dim)
|
| 90 |
n_input_channels = observation_space.shape[0]
|
| 91 |
-
self.
|
| 92 |
nn.Linear(n_input_channels, 512),
|
| 93 |
nn.ReLU(),
|
| 94 |
nn.Linear(512, 256),
|
|
@@ -98,7 +98,7 @@ class CustomCNN(BaseFeaturesExtractor):
|
|
| 98 |
).to(device)
|
| 99 |
|
| 100 |
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
| 101 |
-
return self.
|
| 102 |
|
| 103 |
@dataclass
|
| 104 |
class SalesAgent:
|
|
@@ -112,7 +112,7 @@ class SalesAgent:
|
|
| 112 |
policy_kwargs = dict(
|
| 113 |
activation_fn=nn.ReLU,
|
| 114 |
net_arch=[dict(pi=[128, 64], vf=[128, 64])],
|
| 115 |
-
features_extractor_class=
|
| 116 |
features_extractor_kwargs=dict(features_dim=64)
|
| 117 |
)
|
| 118 |
|
|
|
|
| 41 |
- **Framework**: Stable Baselines3 (PPO)
|
| 42 |
- **State Representation**: Azure OpenAI embeddings
|
| 43 |
- **Action Space**: Continuous (conversion probability 0-1)
|
| 44 |
+
- **Feature Extractor**: Custom fully connected network (stacked linear layers with ReLU activations)
|
| 45 |
|
| 46 |
## Quick Start
|
| 47 |
|
|
|
|
| 83 |
# Check for GPU
|
| 84 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 85 |
|
| 86 |
+
# Custom features extractor built from stacked linear (fully connected) layers
|
| 87 |
+
class CustomLN(BaseFeaturesExtractor):
|
| 88 |
def __init__(self, observation_space, features_dim: int = 64):
|
| 89 |
super().__init__(observation_space, features_dim)
|
| 90 |
n_input_channels = observation_space.shape[0]
|
| 91 |
+
self.ln = nn.Sequential(
|
| 92 |
nn.Linear(n_input_channels, 512),
|
| 93 |
nn.ReLU(),
|
| 94 |
nn.Linear(512, 256),
|
|
|
|
| 98 |
).to(device)
|
| 99 |
|
| 100 |
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
return self.ln(observations)
|
| 102 |
|
| 103 |
@dataclass
|
| 104 |
class SalesAgent:
|
|
|
|
| 112 |
policy_kwargs = dict(
|
| 113 |
activation_fn=nn.ReLU,
|
| 114 |
net_arch=[dict(pi=[128, 64], vf=[128, 64])],
|
| 115 |
+
features_extractor_class=CustomLN,
|
| 116 |
features_extractor_kwargs=dict(features_dim=64)
|
| 117 |
)
|
| 118 |
|