|
|
import warnings

import numpy as np
|
|
|
|
|
|
|
|
class FraudDetectionEnv(gym.Env):
    """
    A custom Gym environment for Fraud Detection using embeddings.

    Each episode walks once through every transaction in the dataset, in an
    order that is reshuffled on every ``reset()``; each ``step()`` classifies
    exactly one transaction.

    State: Embedding of a transaction (float32 vector of length
        ``embedding_dim``, 768 by default).
    Action: 0 (Declare Not Fraud), 1 (Declare Fraud).
    Reward: Based on correctly/incorrectly classifying fraud vs non-fraud,
        looked up in ``reward_config`` under the keys 'TP', 'FP', 'FN' and
        'TN'; a missing key yields a reward of 0.
    """

    def __init__(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
        reward_config: dict,
        *,
        embedding_dim: int = 768,
    ):
        """Build the environment from precomputed embeddings and labels.

        Args:
            embeddings: 2-D array of shape (num_instances, embedding_dim);
                stored as a float32 copy.
            labels: One label per instance, 1 for fraud and 0 for non-fraud;
                stored as an int64 copy.
            reward_config: Mapping with optional keys 'TP', 'FP', 'FN', 'TN'
                giving the reward for each outcome. Stored by reference.
            embedding_dim: Required embedding width; also the observation
                shape. Defaults to 768.

        Raises:
            ValueError: If ``embeddings`` is not 2-D, the instance counts of
                ``embeddings`` and ``labels`` differ, the embedding width is
                not ``embedding_dim``, the dataset is empty, or any label is
                not 0 or 1.
        """
        super().__init__()

        # Explicit exceptions rather than ``assert``: asserts are stripped
        # under ``python -O`` and the validation would silently disappear.
        if embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be a 2-D array, but got {embeddings.ndim} dimension(s)."
            )
        if embeddings.shape[0] != labels.shape[0]:
            raise ValueError("Embeddings and labels must have the same number of instances.")
        if embeddings.shape[1] != embedding_dim:
            raise ValueError(
                f"Embeddings must be {embedding_dim}-dimensional, but got {embeddings.shape[1]}"
            )
        if embeddings.shape[0] == 0:
            # reset() indexes the first instance, so an empty dataset can never run.
            raise ValueError("At least one instance is required.")

        self.embeddings = embeddings.astype(np.float32)
        self.labels = labels.astype(np.int64)

        # Any other label would fall through every reward case in step() and
        # silently earn 0, so reject it up front.
        if not np.isin(self.labels, (0, 1)).all():
            raise ValueError("Labels must be 0 (not fraud) or 1 (fraud).")

        self.num_instances = self.embeddings.shape[0]
        self.reward_config = reward_config

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(embedding_dim,), dtype=np.float32
        )

        self._current_index = 0
        self._order = np.arange(self.num_instances)
        # Shuffle with the env's own generator (the same one reset() uses)
        # instead of the global ``np.random`` state, so constructing an env
        # neither depends on nor perturbs unrelated code's random stream.
        self.np_random.shuffle(self._order)

    def step(self, action: int):
        """Classify the current transaction and advance to the next one.

        Args:
            action: 1 to declare the current transaction fraud, 0 otherwise.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``observation`` is the next transaction's embedding, or all zeros
            once the last transaction has been classified. ``truncated`` is
            always False. ``info`` holds the true label, the action taken,
            the dataset index of the classified instance ('instance_uid') and
            the terminated flag ('is_done').

        Raises:
            ValueError: If ``action`` is neither 0 nor 1.
        """
        if self._current_index >= self.num_instances:
            warnings.warn(
                "step() called when episode is already done.",
                RuntimeWarning,
                stacklevel=2,
            )
            # Build the zero observation directly; sampling the unbounded Box
            # and multiplying by zero only wasted work and consumed its RNG.
            zero_observation = np.zeros(self.observation_space.shape, dtype=np.float32)
            return zero_observation, 0, True, False, {}

        if action not in (0, 1):
            # Previously an unknown action silently earned a reward of 0.
            raise ValueError(f"Invalid action {action!r}; expected 0 (not fraud) or 1 (fraud).")

        actual_index = self._order[self._current_index]
        true_label = self.labels[actual_index]

        # Name the confusion-matrix cell for (prediction, truth); it doubles
        # as the reward_config key.
        if action == 1:
            outcome = 'TP' if true_label == 1 else 'FP'
        else:
            outcome = 'FN' if true_label == 1 else 'TN'
        reward = self.reward_config.get(outcome, 0)

        self._current_index += 1
        terminated = self._current_index >= self.num_instances
        truncated = False

        if terminated:
            next_observation = np.zeros(self.observation_space.shape, dtype=np.float32)
        else:
            next_observation = self.embeddings[self._order[self._current_index]]

        info = {
            'true_label': true_label,
            'predicted_action': action,
            'instance_uid': actual_index,
            'is_done': terminated,
        }

        return next_observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Start a new episode over a freshly shuffled ordering of the dataset.

        Args:
            seed: Optional seed passed to the base ``reset`` (which, in
                Gym >= 0.26 / Gymnasium, reseeds ``self.np_random``, the
                generator used for the shuffle below).
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observation, info)``: the first transaction's embedding and
            ``{'instance_uid': <its dataset index>}``.
        """
        super().reset(seed=seed)

        self._current_index = 0
        self._order = np.arange(self.num_instances)
        self.np_random.shuffle(self._order)

        first_index = self._order[self._current_index]
        initial_observation = self.embeddings[first_index]
        info = {'instance_uid': first_index}

        return initial_observation, info

    def close(self):
        """Release resources; this environment holds none, so it is a no-op."""