# CrisisSim / app/env.py
# TanmaySK's picture
# Update app/env.py
# 0fd48cb verified
from typing import *
import random
import math
import math
from copy import deepcopy
from app.models import ActionEnum, EventEnum, Observation, Reward
class CrisisSimEnv:
    """Month-by-month household-finance simulation with random economic shocks.

    Every call to :meth:`step` advances the simulation by one month in a fixed
    order: apply the agent's action, draw a random event, apply the event,
    propagate the economy (prices -> inflation -> expenses, debt interest,
    savings), check bankruptcy, and compute a reward clamped to ``[0.15, 1.0]``.

    Config keys read:
        ``"max_months"``: episode length in months (default ``12``).
        ``"difficulty"``: ``"easy"``, ``"medium"`` (default) or ``"hard"``.
            Any other value keeps the baseline starting state, never draws an
            event other than ``EventEnum.none``, and uses the "easy"
            bankruptcy thresholds.

    Project types (``ActionEnum``, ``Observation``) appear as string
    annotations so they are only needed when the methods actually run.
    """

    # Nominal floor that expense-cutting actions cannot push expenses below.
    MIN_EXPENSES = 1000.0

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and start the first episode."""
        self.config = config
        self.max_months = config.get("max_months", 12)
        self.task_difficulty = config.get("difficulty", "medium")
        self.reset()

    def reset(self) -> "Observation":
        """Start a new episode and return its initial observation."""
        self.month = 0

        # Baseline starting state; difficulty-specific overrides follow.
        self.income = 3000.0
        self.expenses = 2000.0
        self.savings = 5000.0
        self.debt = 1000.0
        self.inflation = 0.02
        self.currency_value = 1.0  # Base 1.0
        self.food_price_index = 100.0
        self.fuel_price = 3.0  # per unit
        self.current_event = EventEnum.none
        self.bankrupt = False

        if self.task_difficulty == "easy":
            self.savings = 8000.0
            self.debt = 500.0
        elif self.task_difficulty == "medium":
            self.inflation = 0.05
        elif self.task_difficulty == "hard":
            self.inflation = 0.08
            self.savings = 2000.0
            self.debt = 3000.0

        # Metrics for reward calculation
        self.initial_savings = self.savings
        self.initial_debt = self.debt
        self.smart_decisions = 0
        self.bad_decisions = 0
        self.consecutive_negative_months = 0

        # No "previous month" exists yet. Clearing these here keeps the
        # momentum bonus of a new episode's first step from being computed
        # against the final state of the previous episode.
        self.previous_savings = None
        self.previous_debt = None
        self.previous_inflation = None

        return self.state()

    def state(self) -> "Observation":
        """Return the current simulation state as an ``Observation``."""
        return Observation(
            income=self.income,
            expenses=self.expenses,
            savings=self.savings,
            debt=self.debt,
            inflation=self.inflation,
            currency_value=self.currency_value,
            food_price_index=self.food_price_index,
            fuel_price=self.fuel_price,
            current_event=self.current_event.value,
            month=self.month,
            bankrupt=self.bankrupt,
        )

    def _reduce_expenses(self, amount: float) -> None:
        """Lower monthly expenses by ``amount``, stopping at ``MIN_EXPENSES``.

        Expenses that are already below the floor are left untouched, so an
        expense-cutting action can never increase expenses.
        """
        floored = max(self.MIN_EXPENSES, self.expenses - amount)
        self.expenses = min(self.expenses, floored)

    def _pay_down_debt(self) -> None:
        """Move as much of the positive savings as possible into debt repayment.

        The payment is clamped at zero: with overdrawn (negative) savings
        nothing is paid, instead of a negative payment that would shift the
        overdraft into debt. Only a real payment counts as a smart decision.
        """
        amount = max(0.0, min(self.savings, self.debt))
        self.savings -= amount
        self.debt -= amount
        if amount > 0:
            self.smart_decisions += 1

    def _apply_action(self, action: "ActionEnum") -> None:
        """Apply the direct financial effect of the agent's chosen action."""
        if action == ActionEnum.cut_expenses:
            self._reduce_expenses(300.0)
            self.smart_decisions += 1
        elif action == ActionEnum.stock_essentials:
            self.savings -= 500.0
            self.expenses += 100.0  # higher maintenance
            self.smart_decisions += 1
        elif action == ActionEnum.invest_gold:
            self.savings -= 1000.0
            self.smart_decisions += 1
        elif action == ActionEnum.hold_cash:
            pass  # No direct change, safe but vulnerable to inflation
        elif action == ActionEnum.convert_currency:
            self.savings -= 50.0  # fee
            self.smart_decisions += 1
        elif action == ActionEnum.take_loan:
            self.savings += 2000.0
            self.debt += 2200.0  # principal plus up-front interest
            self.bad_decisions += 1
        elif action == ActionEnum.pay_debt:
            self._pay_down_debt()
        elif action == ActionEnum.reduce_luxury:
            # Shares the floor used by cut_expenses so that repeated use
            # cannot drive expenses negative (which would act as income).
            self._reduce_expenses(500.0)
            self.smart_decisions += 1
        elif action == ActionEnum.build_emergency_fund:
            self.savings += 500.0
            # Modelled as a recurring contribution: expenses rise with it.
            self.expenses += 500.0
            self.smart_decisions += 1

    def _trigger_event(self) -> None:
        """Draw this month's event; the candidate set depends on difficulty.

        ``random.choices`` treats the weights as relative, so they do not
        have to sum to 1 (the "easy" list sums to 1.1).
        """
        events = [EventEnum.none]
        weights = [1.0]
        if self.task_difficulty == "easy":
            events.extend([EventEnum.job_loss, EventEnum.currency_crash])
            weights.extend([0.05, 0.05])
        elif self.task_difficulty == "medium":
            events.extend([EventEnum.oil_supply_shock, EventEnum.food_shortage])
            weights = [0.4, 0.3, 0.3]
        elif self.task_difficulty == "hard":
            events.extend([
                EventEnum.war_outbreak,
                EventEnum.job_loss,
                EventEnum.currency_crash,
                EventEnum.import_ban,
            ])
            weights = [0.1, 0.3, 0.2, 0.2, 0.2]
        self.current_event = random.choices(events, weights=weights, k=1)[0]

    def _apply_event(self) -> None:
        """Apply the price/income shock of ``self.current_event``."""
        # Soften severity for easy task
        severity_mult = 0.5 if self.task_difficulty == "easy" else 1.0
        event = self.current_event
        if event == EventEnum.war_outbreak:
            self.fuel_price *= (1.0 + 0.15 * severity_mult)
            self.inflation += (0.02 * severity_mult)
        elif event == EventEnum.oil_supply_shock:
            self.fuel_price *= (1.0 + 0.15 * severity_mult)
        elif event == EventEnum.currency_crash:
            self.currency_value *= (1.0 - 0.15 * severity_mult)
            self.inflation += (0.02 * severity_mult)
        elif event == EventEnum.food_shortage:
            self.food_price_index *= (1.0 + 0.20 * severity_mult)
        elif event == EventEnum.job_loss:
            # 30% of income retained normally, 65% on easy
            self.income *= (1.0 - 0.70 * severity_mult)
        elif event == EventEnum.import_ban:
            self.food_price_index *= (1.0 + 0.10 * severity_mult)
            self.inflation += (0.01 * severity_mult)

    def _update_economy(self) -> None:
        """Propagate prices into inflation, expenses, debt and savings.

        The order matters: fuel feeds food prices, food prices feed
        inflation, and the updated inflation is then applied to expenses
        before the month's cash flow is booked.
        """
        # Fuel price increases transport costs, making food more expensive
        if self.fuel_price > 3.5:
            self.food_price_index += (self.fuel_price - 3.5) * 5.0
        # Food price -> inflation
        if self.food_price_index > 120.0:
            self.inflation += 0.01 * (self.food_price_index / 100.0)
        # Apply inflation to expenses
        self.expenses *= (1.0 + self.inflation)
        # Debt accrues 5% interest per month
        self.debt *= 1.05
        # Book the month's cash flow
        self.savings = self.savings + self.income - self.expenses

    def _check_bankruptcy(self) -> None:
        """Update the bad-month counter and set ``bankrupt`` when it hits the limit.

        A month is "bad" when savings are below ``-grace_buffer``. A good
        month decrements the counter by one rather than resetting it, and
        ``bankrupt`` is never cleared once set.
        """
        if self.task_difficulty == "hard":
            consec_limit = 3
            grace_buffer = 10000.0
        elif self.task_difficulty == "medium":
            consec_limit = 4
            grace_buffer = 15000.0
        else:  # easy (and any unrecognised difficulty)
            consec_limit = 6
            grace_buffer = 20000.0

        if self.savings < -grace_buffer:
            self.consecutive_negative_months += 1
        else:
            self.consecutive_negative_months = max(
                0, self.consecutive_negative_months - 1
            )

        if self.consecutive_negative_months >= consec_limit:
            self.bankrupt = True

    def _compute_reward(self) -> float:
        """Return this month's reward, clamped to ``[0.15, 1.0]``.

        The reward blends survival, tanh-scaled savings and debt ratios,
        smart/bad decision counts, a per-month survival bonus, a small
        momentum term (at most +/- 0.05) based on the change since the
        previous month, and penalties for bad months and bankruptcy.
        """
        survival_score = 1.0 if not self.bankrupt else 0.0

        # Soft scaling: asymptotic curves instead of hard caps
        if self.savings > 0:
            savings_ratio = math.tanh(self.savings / max(1.0, self.initial_savings))
        else:
            savings_ratio = -math.tanh(abs(self.savings) / 5000.0)

        if self.initial_debt > 0:
            debt_ratio = 1.0 - math.tanh(self.debt / max(1.0, self.initial_debt))
        else:
            debt_ratio = math.exp(-self.debt / 2000.0)

        # Monthly deltas. With no previous month recorded (first step of an
        # episode) the current value is the baseline, so the delta is zero.
        prev_savings = getattr(self, "previous_savings", None)
        prev_debt = getattr(self, "previous_debt", None)
        prev_inflation = getattr(self, "previous_inflation", None)
        if prev_savings is None:
            prev_savings = self.savings
        if prev_debt is None:
            prev_debt = self.debt
        if prev_inflation is None:
            prev_inflation = self.inflation

        savings_delta = (self.savings - prev_savings) / 1000.0
        debt_delta = (prev_debt - self.debt) / 1000.0
        inflation_delta = (prev_inflation - self.inflation) * 20.0

        # Small dynamic variation based on monthly momentum (+/- 0.05)
        momentum_bonus = math.tanh(savings_delta + debt_delta + inflation_delta) * 0.05

        smart_bonus_ratio = math.tanh(self.smart_decisions / 5.0)
        bad_penalty_ratio = math.tanh(self.bad_decisions / 5.0)

        # Gradual penalty over steps instead of instant max punishment
        bad_state_penalty = self.consecutive_negative_months * 0.05
        bankruptcy_penalty = 0.1 if self.bankrupt else 0.0

        reward = (
            survival_score * 0.30
            + savings_ratio * 0.15
            + debt_ratio * 0.15
            + smart_bonus_ratio * 0.15
            - bad_penalty_ratio * 0.15
            - bad_state_penalty
            - bankruptcy_penalty
            + momentum_bonus
        )

        # Soft-scaled survival bonus per step
        reward += math.tanh(self.month / 12.0) * 0.10

        # Soft clamp near the top to prevent flatlining at 1.00 or 0.90
        if reward > 0.90:
            reward = 0.90 + 0.10 * math.tanh((reward - 0.90) * 5.0)

        # Floor prevents instant drops to 0.0
        min_reward_floor = 0.15
        return max(min_reward_floor, min(1.0, reward))

    def step(self, action: "ActionEnum") -> Tuple["Observation", float, bool, Dict[str, Any]]:
        """Advance one month.

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is true once
            ``max_months`` is reached or the household is bankrupt, and
            ``info`` carries the ``"bankrupt"`` flag and the ``"month"``.
        """
        self.month += 1

        self._apply_action(action)   # 1. agent action
        self._trigger_event()        # 2. random event
        self._apply_event()          # 3. event impact
        self._update_economy()       # 4. inflation, prices, savings
        self._check_bankruptcy()     # 5. bankruptcy condition
        reward = self._compute_reward()  # 6. reward

        # Store state variables to compare deltas next month
        self.previous_savings = self.savings
        self.previous_debt = self.debt
        self.previous_inflation = self.inflation

        done = self.month >= self.max_months or self.bankrupt
        return self.state(), reward, done, {"bankrupt": self.bankrupt, "month": self.month}