DeepMostInnovations committed on
Commit
51c0731
·
verified ·
1 Parent(s): 91eebb2

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +635 -0
train.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ %%writefile train.py
2
+ import os
3
+ import json
4
+ import pandas as pd
5
+ import numpy as np
6
+ from typing import List, Dict, Tuple, Optional, Any
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from sklearn.model_selection import train_test_split
12
+ from stable_baselines3 import PPO
13
+ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
14
+ from stable_baselines3.common.utils import set_random_seed
15
+ from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
16
+ from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
17
+ import gymnasium as gym
18
+ from gymnasium import spaces
19
+ from dataclasses import dataclass
20
+ import logging
21
+ import random
22
+ from tqdm import tqdm
23
+ import time
24
+ import matplotlib.pyplot as plt
25
+ import seaborn as sns
26
+ from datetime import datetime
27
+ import argparse
28
+ import psutil
29
+ import gc
30
+
31
# Logging setup: INFO and above goes to both a log file and the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler("sales_training.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

logger = logging.getLogger(__name__)

# Device selection: resolved once at import time and shared by the whole module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    logger.info("GPU not available, using CPU")
50
+
51
@dataclass
class ConversationState:
    """Represents the state of a sales conversation for the RL environment."""
    # Messages of the conversation revealed so far.
    conversation_history: List[Dict[str, str]]
    # Embedding vector describing the conversation at the current turn.
    embedding: np.ndarray
    # Named scalar metrics; their values are emitted in dict insertion order.
    conversation_metrics: Dict[str, float]
    # Zero-based index of the current turn.
    turn_number: int
    # Conversion probabilities recorded so far, oldest first.
    conversion_probabilities: List[float]

    @property
    def state_vector(self) -> np.ndarray:
        """Create a flat float32 vector representation of the conversation state.

        Layout: ``embedding | metric values | turn number | last 10 probabilities``.
        The probability window is right-padded with zeros when fewer than 10
        probabilities have been recorded.

        Returns:
            One-dimensional ``np.float32`` array.
        """
        window = 10

        metric_values = np.array(list(self.conversation_metrics.values()), dtype=np.float32)
        turn_info = np.array([self.turn_number], dtype=np.float32)

        # Keep only the most recent `window` probabilities, then zero-pad on the right.
        recent_probs = np.asarray(self.conversion_probabilities[-window:], dtype=np.float32)
        padded_probs = np.zeros(window, dtype=np.float32)
        padded_probs[:recent_probs.size] = recent_probs

        # Cast the embedding explicitly: a float64 embedding would otherwise
        # promote the whole vector to float64.
        return np.concatenate([
            np.asarray(self.embedding, dtype=np.float32),
            metric_values,
            turn_info,
            padded_probs
        ])
78
+
79
+ # Custom neural network for feature extraction - optimized for GPU
80
class CustomLN(BaseFeaturesExtractor):
    """Feature extractor that maps the flat observation through a ReLU MLP.

    The stack is ``input -> 512 -> 256 -> features_dim`` with a ReLU after
    every linear layer, and is moved to the module-level ``device``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)

        # Layer widths from the observation size down to the requested feature size.
        widths = [observation_space.shape[0], 512, 256, features_dim]

        # Linear layers are created in input-to-output order, each followed by a ReLU.
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())

        self.linear_network = nn.Sequential(*stack).to(device)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Return the extracted features for a batch of observations."""
        features = self.linear_network(observations)
        return features
101
+
102
class SalesConversionEnv(gym.Env):
    """Reinforcement learning environment for sales conversation prediction.

    Each episode replays one randomly chosen row of ``conversations_df``. At
    every turn the agent outputs a conversion probability in ``[0, 1]`` and is
    rewarded by how close it is to the reference probability for that turn.
    """

    def __init__(self, conversations_df: pd.DataFrame, use_miniembeddings=True):
        """
        Initialize the environment.

        Args:
            conversations_df: DataFrame containing sales conversations
            use_miniembeddings: If True, reduce embedding dimension to save memory
        """
        super().__init__()

        self.conversations_df = conversations_df
        self.current_conversation_idx = 0
        self.max_turns = 20
        self.use_miniembeddings = use_miniembeddings

        # Resolve the embedding columns once; they are needed on every step.
        self._embedding_cols = [col for col in conversations_df.columns if col.startswith('embedding_')]
        self.full_embedding_dim = len(self._embedding_cols)

        # Option to use reduced embedding dimension to save memory
        if use_miniembeddings:
            self.embedding_dim = min(1024, self.full_embedding_dim)
            logger.info(f"Using reduced embeddings: {self.full_embedding_dim} -> {self.embedding_dim}")
        else:
            self.embedding_dim = self.full_embedding_dim

        # Action space: Probability of conversion (0-1). Bounds are created as
        # float32 so they match the declared dtype without precision loss.
        self.action_space = spaces.Box(
            low=np.array([0.0], dtype=np.float32),
            high=np.array([1.0], dtype=np.float32),
            dtype=np.float32
        )

        # Observation space: embeddings + 5 metrics + turn number + 10-slot
        # probability history. The 5 metrics are the keys built in
        # _parse_conversation.
        # NOTE(review): one of those metrics is 'outcome', so the label the
        # agent is rewarded for predicting is part of its observation —
        # confirm this is intended.
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.embedding_dim + 5 + 1 + 10,),
            dtype=np.float32
        )

        self.current_turn = 0
        self.conversation_state = None
        self.true_probabilities = None
        # Messages of the current episode, cached by reset() so step() does
        # not re-parse the conversation JSON on every turn.
        self._messages = []

        logger.info(f"Initialized SalesConversionEnv with {len(conversations_df)} conversations")

    def _parse_conversation(self, conversation_idx: int) -> Tuple[List[Dict[str, str]], Dict[str, float], Dict[int, float]]:
        """Parse conversation data from the dataset.

        Args:
            conversation_idx: Positional row index into ``conversations_df``.

        Returns:
            Tuple ``(messages, metrics, probability_trajectory)``. ``messages``
            is always a non-empty list; a two-message placeholder is used when
            the stored conversation is missing, malformed or empty.
        """
        row = self.conversations_df.iloc[conversation_idx]

        # Parse messages
        try:
            messages = json.loads(row['conversation'])
        except (json.JSONDecodeError, TypeError):
            messages = None

        # Valid JSON that is not a non-empty list (e.g. "[]" or "{}") is
        # treated like a parse failure; an empty conversation would otherwise
        # produce a zero-length episode.
        if not isinstance(messages, list) or not messages:
            messages = [
                {"speaker": "customer", "message": "I'm interested in your product."},
                {"speaker": "sales_rep", "message": "Thank you for your interest. How can I help?"}
            ]

        # Parse metrics
        metrics = {
            'customer_engagement': float(row.get('customer_engagement', 0.5)),
            'sales_effectiveness': float(row.get('sales_effectiveness', 0.5)),
            'conversation_length': int(row.get('conversation_length', len(messages))),
            'outcome': float(row.get('outcome', 0.5)),
            'progress': 0.0  # Will be updated during stepping
        }

        # Parse probability trajectory
        try:
            raw_trajectory = json.loads(row['probability_trajectory'])
            # Convert string keys to integers
            probability_trajectory = {int(k): float(v) for k, v in raw_trajectory.items()}
        except (json.JSONDecodeError, TypeError, KeyError, ValueError, AttributeError):
            # No usable trajectory (missing column, bad JSON, non-dict payload,
            # non-numeric keys/values): synthesize a linear ramp that rises for
            # converted conversations and falls otherwise.
            if row.get('outcome', 0) == 1:
                probability_trajectory = {i: min(0.5 + i * 0.05, 0.95) for i in range(len(messages))}
            else:
                probability_trajectory = {i: max(0.5 - i * 0.05, 0.05) for i in range(len(messages))}

        return messages, metrics, probability_trajectory

    def _get_embedding_for_turn(self, conversation_idx: int, turn: int) -> np.ndarray:
        """Get the embedding for a specific conversation at a specific turn.

        Args:
            conversation_idx: Positional row index into ``conversations_df``.
            turn: Zero-based turn number used to scale the embedding.

        Returns:
            float32 vector of length ``embedding_dim``; all zeros when the
            stored embedding is non-numeric or contains NaN/Inf.
        """
        row = self.conversations_df.iloc[conversation_idx]
        n_cols = len(self._embedding_cols)

        try:
            embedding = row[self._embedding_cols].values.astype(np.float32)
        except (ValueError, TypeError):
            embedding = np.zeros(n_cols, dtype=np.float32)
        else:
            # NaN or Inf anywhere invalidates the whole vector.
            if not np.isfinite(embedding).all():
                embedding = np.zeros(n_cols, dtype=np.float32)

        # Use dimensionality reduction for very large embeddings to save memory
        if self.use_miniembeddings and len(embedding) > self.embedding_dim:
            # Average pooling over consecutive windows of `stride` values;
            # values past the first `embedding_dim` windows are dropped.
            stride = self.full_embedding_dim // self.embedding_dim
            embedding = (
                embedding[:stride * self.embedding_dim]
                .reshape(self.embedding_dim, stride)
                .mean(axis=1)
            )

        # Simple scaling based on turn progress to simulate evolving embeddings.
        # max(1, ...) keeps the division safe even if max_turns were ever 0.
        progress = min(1.0, turn / max(1, self.max_turns))
        scaled_embedding = embedding * (0.6 + 0.4 * progress)

        return scaled_embedding.astype(np.float32, copy=False)

    def reset(self, seed=None, options=None) -> Tuple[np.ndarray, Dict]:
        """Reset the environment to start a new episode.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it also
                determines which conversation is drawn.
            options: Unused.

        Returns:
            Tuple ``(observation, info)`` with an empty info dict.
        """
        super().reset(seed=seed)

        # Draw from the environment's own generator so that `seed` is honoured.
        self.current_conversation_idx = int(self.np_random.integers(0, len(self.conversations_df)))
        self.current_turn = 0

        # Parse conversation data
        messages, metrics, probability_trajectory = self._parse_conversation(self.current_conversation_idx)
        self._messages = messages
        self.true_probabilities = probability_trajectory
        # At least one turn, at most 20.
        self.max_turns = max(1, min(20, len(messages)))

        # Initialize state
        embedding = self._get_embedding_for_turn(self.current_conversation_idx, 0)
        metrics = metrics.copy()
        metrics['progress'] = 0.0

        self.conversation_state = ConversationState(
            conversation_history=messages[:1],
            embedding=embedding,
            conversation_metrics=metrics,
            turn_number=0,
            conversion_probabilities=[self.true_probabilities.get(0, 0.5)]
        )

        return self.conversation_state.state_vector, {}

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Step the environment forward by one turn.

        Args:
            action: Array whose first element is the predicted conversion
                probability; it is clipped to ``[0, 1]``.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always False and ``info['true_prob']`` holds the
            reference probability for the turn that was just scored.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.conversation_state is None or self.true_probabilities is None:
            raise RuntimeError("reset() must be called before step()")

        # Extract predicted probability and keep it inside the action bounds.
        raw_action = np.asarray(action, dtype=np.float32).reshape(-1)[0]
        predicted_prob = float(np.clip(raw_action, 0.0, 1.0))

        # Get true probability for current turn
        true_prob = self.true_probabilities.get(self.current_turn, 0.5)

        # Calculate reward based on prediction accuracy
        reward = 1.0 - abs(predicted_prob - true_prob)

        # On the final turn, add a penalty when the prediction is on the wrong
        # side of 0.5 for the known outcome; it grows with the distance from 0.5.
        if self.current_turn == self.max_turns - 1:
            outcome = self.conversation_state.conversation_metrics['outcome']
            if outcome == 1 and predicted_prob < 0.5:
                reward -= 1.0 * (0.5 - predicted_prob)
            elif outcome == 0 and predicted_prob > 0.5:
                reward -= 1.0 * (predicted_prob - 0.5)

        # Update turn
        self.current_turn += 1
        done = self.current_turn >= self.max_turns

        if not done:
            # Update state
            embedding = self._get_embedding_for_turn(self.current_conversation_idx, self.current_turn)
            metrics = self.conversation_state.conversation_metrics.copy()
            metrics['progress'] = self.current_turn / self.max_turns

            # Slicing past the end simply yields the whole conversation.
            history = self._messages[:self.current_turn + 1]

            # Add current prediction to history
            conv_probs = self.conversation_state.conversion_probabilities.copy()
            conv_probs.append(predicted_prob)

            self.conversation_state = ConversationState(
                conversation_history=history,
                embedding=embedding,
                conversation_metrics=metrics,
                turn_number=self.current_turn,
                conversion_probabilities=conv_probs
            )

        # When the episode has ended the last state is returned unchanged.
        return self.conversation_state.state_vector, reward, done, False, {'true_prob': true_prob}
292
+
293
class SalesRLTrainer:
    """Trainer for the sales conversion prediction RL model.

    Loads a CSV of conversations, splits it into training and validation
    frames, trains a PPO agent on ``SalesConversionEnv`` and evaluates it.
    """

    def __init__(self, dataset_path: str, model_save_path: str = "sales_conversion_model",
                 use_miniembeddings: bool = True, batch_size: int = 64):
        """
        Initialize the trainer.

        Args:
            dataset_path: Path to the sales conversation dataset
            model_save_path: Path to save trained model
            use_miniembeddings: Whether to use reduced embeddings to save memory
            batch_size: Batch size for training
        """
        self.dataset_path = dataset_path
        self.model_save_path = model_save_path
        self.use_miniembeddings = use_miniembeddings
        self.batch_size = batch_size
        self.df = None
        self.model = None
        self.train_df = None
        self.val_df = None

        # Create directory for models and logs ("." when the path has no directory part)
        os.makedirs(os.path.dirname(model_save_path) or ".", exist_ok=True)
        os.makedirs("logs", exist_ok=True)

        logger.info(f"Initialized SalesRLTrainer with dataset: {dataset_path}")

        # Monitor memory usage
        self._log_memory_usage("Initial")

    def _log_memory_usage(self, step=""):
        """Log the process RSS and the CUDA memory allocated by torch, in MB."""
        process = psutil.Process(os.getpid())
        cpu_mem = process.memory_info().rss / 1024 / 1024  # MB

        gpu_mem = 0
        if torch.cuda.is_available():
            gpu_mem = torch.cuda.memory_allocated() / 1024 / 1024  # MB

        logger.info(f"Memory usage [{step}] - CPU: {cpu_mem:.2f} MB, GPU: {gpu_mem:.2f} MB")

    def load_dataset(self, validation_split=0.1, sample_size=None):
        """
        Load and preprocess the sales conversation dataset.

        Populates ``self.df``, ``self.train_df`` and ``self.val_df``.

        Args:
            validation_split: Proportion of data for validation
            sample_size: Optional limit on dataset size to save memory

        Raises:
            ValueError: If the dataset has no ``embedding_*`` columns.
            Exception: Any error raised while reading or converting the data
                is logged and re-raised.
        """
        logger.info(f"Loading dataset from {self.dataset_path}")
        try:
            # Read dataset in chunks to reduce memory usage. The reader is used
            # as a context manager so the file is closed even on early exit.
            chunks = []
            rows_read = 0
            with pd.read_csv(self.dataset_path, chunksize=10000) as reader:
                for chunk in reader:
                    chunks.append(chunk)
                    rows_read += len(chunk)

                    # If sample size specified, stop once enough rows are read
                    if sample_size and rows_read >= sample_size:
                        break

            self.df = pd.concat(chunks)

            # If sample size specified, limit the dataset
            if sample_size and len(self.df) > sample_size:
                self.df = self.df.sample(sample_size, random_state=42)

            logger.info(f"Loaded dataset with shape: {self.df.shape}")

            # Validate embedding columns
            embedding_cols = [col for col in self.df.columns if col.startswith('embedding_')]
            if not embedding_cols:
                raise ValueError("No embedding columns found in the dataset")

            logger.info(f"Found {len(embedding_cols)} embedding dimensions")

            # Downcast in a single astype call to reduce memory usage:
            # embeddings and score columns to float32, length to int32.
            dtype_map = {col: np.float32 for col in embedding_cols}
            for col in ('outcome', 'customer_engagement', 'sales_effectiveness'):
                if col in self.df.columns:
                    dtype_map[col] = np.float32
            if 'conversation_length' in self.df.columns:
                dtype_map['conversation_length'] = np.int32
            self.df = self.df.astype(dtype_map)

            # Split into train and validation sets
            train_idx, val_idx = train_test_split(
                np.arange(len(self.df)),
                test_size=validation_split,
                random_state=42
            )

            self.train_df = self.df.iloc[train_idx].reset_index(drop=True)
            self.val_df = self.df.iloc[val_idx].reset_index(drop=True)

            logger.info(f"Split dataset: {len(self.train_df)} training samples, {len(self.val_df)} validation samples")

            # Monitor memory
            self._log_memory_usage("After dataset load")

            # Free up memory
            gc.collect()

        except Exception as e:
            logger.error(f"Error loading dataset: {str(e)}")
            raise

    def train(self, total_timesteps: int = 100000, learning_rate: float = 0.0003, n_envs: int = 1):
        """
        Train the RL model with GPU acceleration.

        Args:
            total_timesteps: Total timesteps for training
            learning_rate: Learning rate for the optimizer
            n_envs: Number of parallel environments
        """
        if self.train_df is None:
            self.load_dataset()

        # Use only 1 environment with GPU for better memory efficiency
        n_envs = 1 if torch.cuda.is_available() else n_envs
        # Never create more environments than there are training rows,
        # otherwise some shards would be empty.
        n_envs = max(1, min(n_envs, len(self.train_df)))

        # Bind the flag locally so the factory closure does not capture `self`
        # (and with it every DataFrame) when it is sent to a subprocess.
        use_mini = self.use_miniembeddings

        def make_env(df_subset):
            """Create environment with a subset of data."""
            def _init():
                return SalesConversionEnv(df_subset, use_miniembeddings=use_mini)
            return _init

        # Create subsets of data for each environment
        if n_envs > 1:
            subset_size = len(self.train_df) // n_envs
            env_makers = []
            for i in range(n_envs):
                start = i * subset_size
                # The last shard also takes the remainder rows.
                stop = (i + 1) * subset_size if i < n_envs - 1 else len(self.train_df)
                env_makers.append(make_env(self.train_df.iloc[start:stop]))
            env = SubprocVecEnv(env_makers)
        else:
            env = DummyVecEnv([make_env(self.train_df)])

        val_env = None
        try:
            # Create validation environment
            val_env = DummyVecEnv([make_env(self.val_df)])

            # Configure policy network. net_arch is passed as a plain dict;
            # the list-of-dict form is deprecated in Stable-Baselines3.
            policy_kwargs = dict(
                activation_fn=nn.ReLU,
                net_arch=dict(pi=[128, 64], vf=[128, 64]),  # Smaller network to save memory
                features_extractor_class=CustomLN,
                features_extractor_kwargs=dict(features_dim=64)
            )

            # Initialize model with GPU support
            self.model = PPO(
                "MlpPolicy",
                env,
                policy_kwargs=policy_kwargs,
                learning_rate=learning_rate,
                n_steps=512,  # Smaller n_steps to save memory
                batch_size=self.batch_size,
                n_epochs=5,  # Fewer epochs to speed up training
                gamma=0.99,
                gae_lambda=0.95,
                clip_range=0.2,
                clip_range_vf=0.2,
                ent_coef=0.01,
                vf_coef=0.5,
                max_grad_norm=0.5,
                tensorboard_log="./logs/",
                verbose=1,
                device=device  # Use GPU if available
            )

            # Resolve the best-model directory next to the model file; fall
            # back to "." so a bare file name never resolves to "/best_model".
            best_model_dir = os.path.join(os.path.dirname(self.model_save_path) or ".", "best_model")

            # Set up callbacks
            eval_callback = EvalCallback(
                val_env,
                best_model_save_path=best_model_dir,
                log_path="./logs/",
                eval_freq=max(2000, total_timesteps // 20),  # Evaluate less frequently to save time
                deterministic=True,
                render=False
            )

            checkpoint_callback = CheckpointCallback(
                save_freq=max(5000, total_timesteps // 10),  # Save less frequently to reduce I/O
                save_path="./logs/checkpoints/",
                name_prefix="sales_model",
                save_replay_buffer=False,
                save_vecnormalize=False
            )

            # Monitor memory before training
            self._log_memory_usage("Before training")

            logger.info(f"Starting training for {total_timesteps} timesteps with {n_envs} environments on {device}")
            self.model.learn(
                total_timesteps=total_timesteps,
                callback=[eval_callback, checkpoint_callback],
                progress_bar=True
            )

            # Save final model
            self.model.save(self.model_save_path)
            logger.info(f"Model saved to {self.model_save_path}")

            # Monitor memory after training
            self._log_memory_usage("After training")
        finally:
            # Always release environments (and their subprocesses) and cached
            # memory, even when training fails part-way.
            env.close()
            if val_env is not None:
                val_env.close()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    @staticmethod
    def _classification_metrics(predictions: List[float], true_outcomes: List[int],
                                threshold: float = 0.5) -> Dict[str, float]:
        """Compute accuracy, precision, recall and F1 for thresholded predictions.

        Args:
            predictions: Predicted probabilities, one per episode.
            true_outcomes: Binary labels (1 = converted), aligned with ``predictions``.
            threshold: Predictions at or above this value count as positive.

        Returns:
            Dict with keys ``accuracy``, ``precision``, ``recall`` and
            ``f1_score``; any ratio with an empty denominator is 0.0.
        """
        tp = fp = tn = fn = 0
        for prob, label in zip(predictions, true_outcomes):
            predicted_positive = prob >= threshold
            if predicted_positive and label == 1:
                tp += 1
            elif predicted_positive:
                fp += 1
            elif label == 1:
                fn += 1
            else:
                tn += 1

        total = tp + fp + tn + fn
        accuracy = (tp + tn) / total if total else 0.0
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1_score)
        }

    def evaluate(self, num_episodes: int = 100):
        """Evaluate the trained model.

        Loads the model from ``model_save_path`` and the dataset if they are
        not already in memory, then runs ``num_episodes`` deterministic
        episodes on the validation split.

        Args:
            num_episodes: Number of validation episodes to run.

        Returns:
            Dict with ``mean_reward``, ``accuracy``, ``precision``, ``recall``
            and ``f1_score``. Classification metrics use the last prediction
            of each episode and treat ``outcome >= 0.5`` as a conversion.
        """
        if self.model is None:
            logger.info(f"Loading model from {self.model_save_path}")
            self.model = PPO.load(self.model_save_path, device=device)

        if self.val_df is None:
            self.load_dataset()

        # Create environment
        env = SalesConversionEnv(self.val_df, use_miniembeddings=self.use_miniembeddings)

        logger.info(f"Evaluating model on {num_episodes} episodes")

        rewards = []
        predictions = []
        true_outcomes = []

        for _episode in tqdm(range(num_episodes), desc="Evaluating"):
            obs, _info = env.reset()
            done = False
            episode_reward = 0
            final_pred = None

            while not done:
                action, _state = self.model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, _info = env.step(action)
                done = terminated or truncated

                episode_reward += reward
                final_pred = float(action[0])

            rewards.append(episode_reward)

            # The same thresholded label feeds accuracy, precision and recall.
            outcome = env.conversation_state.conversation_metrics['outcome']
            predictions.append(final_pred)
            true_outcomes.append(1 if outcome >= 0.5 else 0)

        mean_reward = float(np.mean(rewards)) if rewards else 0.0
        metrics = self._classification_metrics(predictions, true_outcomes)

        logger.info(f"Evaluation results:")
        logger.info(f"- Mean reward: {mean_reward:.4f}")
        logger.info(f"- Prediction accuracy: {metrics['accuracy']:.4f}")
        logger.info(f"- Precision: {metrics['precision']:.4f}")
        logger.info(f"- Recall: {metrics['recall']:.4f}")
        logger.info(f"- F1 Score: {metrics['f1_score']:.4f}")

        return {'mean_reward': mean_reward, **metrics}
583
+
584
def main(argv: Optional[List[str]] = None):
    """Main function to run the training pipeline.

    Parses the command line, loads the dataset, trains the model unless
    ``--evaluate_only`` is given, then evaluates it and prints the metrics.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.
    """
    parser = argparse.ArgumentParser(description="Train a sales conversion prediction model")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset CSV file")
    parser.add_argument("--model_path", type=str, default="models/sales_conversion_model",
                        help="Path to save the trained model")
    parser.add_argument("--timesteps", type=int, default=50000,
                        help="Number of timesteps to train for")
    parser.add_argument("--learning_rate", type=float, default=0.0003,
                        help="Learning rate for training")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size for training")
    parser.add_argument("--sample_size", type=int, default=None,
                        help="Limit dataset size to save memory (e.g., 10000)")
    parser.add_argument("--evaluate_only", action="store_true",
                        help="Only evaluate an existing model without training")
    parser.add_argument("--num_eval_episodes", type=int, default=50,
                        help="Number of episodes for evaluation")
    parser.add_argument("--use_small_embedding", action="store_true",
                        help="Use reduced embedding dimension to save memory")

    args = parser.parse_args(argv)

    # Initialize trainer
    trainer = SalesRLTrainer(
        dataset_path=args.dataset,
        model_save_path=args.model_path,
        use_miniembeddings=args.use_small_embedding,
        batch_size=args.batch_size
    )

    # Load dataset with optional sample limit
    trainer.load_dataset(sample_size=args.sample_size)

    # Train unless the caller only wants to score an existing model
    if not args.evaluate_only:
        trainer.train(
            total_timesteps=args.timesteps,
            learning_rate=args.learning_rate
        )

    # Evaluate
    eval_results = trainer.evaluate(num_episodes=args.num_eval_episodes)

    # Print evaluation results
    print("\nEvaluation Results:")
    for metric, value in eval_results.items():
        print(f"- {metric}: {value:.4f}")


if __name__ == "__main__":
    main()