# Source: LLM-RL-Fraud / false / custom_env.py
# Author: djaloul — commit 84bef82 (verified): "Upload folder using huggingface_hub"
import warnings

import numpy as np
class FraudDetectionEnv(gym.Env):
    """
    A custom Gym environment for Fraud Detection using embeddings.

    Each episode visits every instance of the dataset exactly once, in an
    order that is reshuffled on every ``reset()``.

    State: Embedding of a transaction (float32 vector of length
        ``embedding_dim``, 768 by default).
    Action: 0 (Declare Not Fraud), 1 (Declare Fraud).
    Reward: Looked up in ``reward_config`` under 'TP', 'FP', 'FN' or 'TN'
        according to the (action, true label) pair; a missing key gives 0.

    An episode terminates once the last instance has been classified; it is
    never truncated.

    NOTE(review): ``gym`` and ``spaces`` are not imported anywhere visible in
    this file. The 5-tuple ``step`` / ``(obs, info)`` ``reset`` signatures
    match the Gymnasium API, so presumably ``import gymnasium as gym`` and
    ``from gymnasium import spaces`` are intended — confirm.
    """

    def __init__(self, embeddings: np.ndarray, labels: np.ndarray, reward_config: dict,
                 embedding_dim: int = 768):
        """
        Args:
            embeddings: 2-D array, one row per instance, ``embedding_dim``
                columns. Stored internally as a float32 copy.
            labels: One label per row of ``embeddings``; 1 means fraud and
                0 means not fraud. Stored internally as an int64 copy.
            reward_config: Mapping with optional keys 'TP', 'FP', 'FN', 'TN'
                giving the reward for each outcome.
            embedding_dim: Expected width of each embedding row. Defaults to
                768, the only width earlier versions accepted.

        Raises:
            AssertionError: if the inputs have inconsistent shapes or the
                dataset is empty. (These are ``assert`` checks, so they are
                skipped when Python runs with ``-O``.)
        """
        super().__init__()
        # Ensure data consistency
        assert embeddings.ndim == 2, f"Embeddings must be a 2-D array, but got {embeddings.ndim} dimension(s)"
        assert embeddings.shape[0] == labels.shape[0], "Embeddings and labels must have the same number of instances."
        assert embeddings.shape[1] == embedding_dim, f"Embeddings must be {embedding_dim}-dimensional, but got {embeddings.shape[1]}"
        # reset() indexes the first shuffled instance, so an empty dataset can never be used.
        assert embeddings.shape[0] > 0, "At least one instance is required."

        self.embeddings = embeddings.astype(np.float32)
        self.labels = labels.astype(np.int64)
        self.num_instances = self.embeddings.shape[0]
        self.reward_config = reward_config
        self.embedding_dim = embedding_dim

        # Action Space: Discrete(2) -> 0 for Not Fraud, 1 for Fraud
        self.action_space = spaces.Discrete(2)
        # Observation Space: unbounded float32 vector of length embedding_dim
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(embedding_dim,), dtype=np.float32)

        # Internal state: position within the current episode and the
        # shuffled visiting order of dataset rows.
        self._current_index = 0
        # Use the environment's own RNG (as reset() does) instead of NumPy's
        # global one, so construction neither perturbs nor depends on global
        # random state.
        self._order = self.np_random.permutation(self.num_instances)

    def _blank_observation(self) -> np.ndarray:
        """Return the all-zero observation emitted once the episode has ended."""
        return np.zeros(self.observation_space.shape, dtype=np.float32)

    def step(self, action: int):
        """
        Score the classification of the current instance and advance.

        Args:
            action: 0 (Declare Not Fraud) or 1 (Declare Fraud). Any other
                value matches none of the reward cases and earns 0.

        Returns:
            ``(next_observation, reward, terminated, truncated, info)``.
            After the last instance, ``next_observation`` is all zeros and
            ``terminated`` is True; ``truncated`` is always False.
        """
        if self._current_index >= self.num_instances:
            # Episode already over: warn through the warnings machinery rather
            # than stdout, and return an inert transition without sampling the
            # observation space just to zero the sample.
            warnings.warn("step() called when episode is already done.", RuntimeWarning, stacklevel=2)
            return self._blank_observation(), 0, True, False, {}

        # Get current instance data based on shuffled order
        actual_index = self._order[self._current_index]
        true_label = self.labels[actual_index]

        # Determine reward from the confusion-matrix cell of this decision.
        reward = 0
        if action == 1 and true_label == 1:
            reward = self.reward_config.get('TP', 0)
        elif action == 1 and true_label == 0:
            reward = self.reward_config.get('FP', 0)
        elif action == 0 and true_label == 1:
            reward = self.reward_config.get('FN', 0)
        elif action == 0 and true_label == 0:
            reward = self.reward_config.get('TN', 0)

        # Move to the next instance
        self._current_index += 1
        done = self._current_index >= self.num_instances
        truncated = False

        if done:
            next_observation = self._blank_observation()
        else:
            # Copy so callers cannot mutate the stored dataset through the
            # returned array.
            next_observation = self.embeddings[self._order[self._current_index]].copy()

        info = {
            'true_label': true_label,
            'predicted_action': action,
            'instance_uid': actual_index,
            'is_done': done
        }
        return next_observation, reward, done, truncated, info

    def reset(self, seed=None, options=None):
        """
        Start a new episode over the whole dataset in a freshly shuffled order.

        Args:
            seed: Forwarded to ``super().reset`` for seeding ``self.np_random``.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(initial_observation, info)``, where ``info['instance_uid']`` is
            the dataset row index of ``initial_observation``.
        """
        super().reset(seed=seed)  # Handles seeding
        self._current_index = 0
        # permutation(n) == shuffle(arange(n)) for the same generator state,
        # so seeded episode orders match the previous implementation.
        self._order = self.np_random.permutation(self.num_instances)
        first_index = self._order[self._current_index]
        # Copy so callers cannot mutate the stored dataset.
        initial_observation = self.embeddings[first_index].copy()
        info = {'instance_uid': first_index}
        return initial_observation, info

    def close(self):
        """Release resources; this environment holds none, so it is a no-op."""
        pass