# CrisisSim / inference.py
# Uploaded by TanmaySK — commit b96c7d5 ("uploading done", verified).
import json
import os
import re
import sys

import requests
from dotenv import load_dotenv
from openai import OpenAI
# Load variables from a local .env file (if present) before anything below
# reads os.environ.
load_dotenv()

# Root URL of the environment server that serves the /reset and /step
# endpoints. ENV_BASE_URL overrides the local default; a trailing slash is
# stripped so f"{BASE_URL}/reset" never produces a double slash.
BASE_URL = (os.getenv("ENV_BASE_URL") or "http://localhost:7860").rstrip("/")
# Actions accepted by the environment's /step endpoint.
VALID_ACTIONS = (
    "cut_expenses", "stock_essentials", "invest_gold", "hold_cash",
    "convert_currency", "take_loan", "pay_debt", "reduce_luxury", "build_emergency_fund",
)

# Action used when the model reply names no valid action or the API call fails.
FALLBACK_ACTION = "hold_cash"

# Seconds to wait on the environment server per request, so a stalled server
# cannot hang the run forever (requests has no default timeout).
HTTP_TIMEOUT = 30

# Candidate action tokens inside a free-form model reply (lowercased first).
_ACTION_TOKEN = re.compile(r"[a-z_]+")

SYSTEM_PROMPT = """You are an expert financial crisis manager operating in a simulated economy.
Your goals:
- Survive as long as possible (avoid bankruptcy)
- Maintain and grow savings
- Reduce debt strategically
- Adapt to inflation and economic shocks
Rules:
- Do NOT repeat the same action more than 2 times in a row
- Avoid passive strategies like always holding cash
- Balance short-term survival with long-term stability
- Consider consequences of each decision
Always return ONLY one valid action."""


def _parse_action(reply):
    """Return the first valid action named in *reply*, or None if it names none.

    Models often wrap the answer in backticks or quotes, change its case, or
    add punctuation or a prefix (e.g. "Action: `Invest_Gold`."). An exact
    string comparison would reject all of these, so the reply is tokenized.
    """
    if not reply:
        return None
    for token in _ACTION_TOKEN.findall(reply.lower()):
        if token in VALID_ACTIONS:
            return token
    return None


def _avoid_triple_repeat(action, history):
    """Return *action*, or a different valid action if it would be the 3rd in a row."""
    if len(history) >= 2 and history[-1] == history[-2] == action:
        return next(a for a in VALID_ACTIONS if a != action)
    return action


def _build_user_prompt(state, history):
    """Render the per-step user message from the observation and recent actions."""
    return "\n".join([
        "Current State:",
        json.dumps(state, indent=2),
        "Previous Actions (last 3):",
        str(history[-3:]),
        "Choose ONE action from:",
        "[" + ", ".join(VALID_ACTIONS) + "]",
        "Avoid repeating same action too often.",
        "Return ONLY the action name.",
    ])


def _query_model(client, model_name, state, history):
    """Ask the model for the next action.

    Returns:
        (action, error_val). error_val is "null" when the API call succeeded
        and "api_error" when it raised. On failure, FALLBACK_ACTION is used.
        An unparseable reply also yields FALLBACK_ACTION, with error_val "null".
    """
    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT.strip()},
                {"role": "user", "content": _build_user_prompt(state, history)},
            ],
            temperature=0.5,
            max_tokens=50,
        )
        # message.content is Optional in the OpenAI SDK; _parse_action handles None.
        reply = completion.choices[0].message.content
    except Exception as exc:  # API boundary: report, then fall back.
        print(f"model call failed: {exc!r}", file=sys.stderr, flush=True)
        return FALLBACK_ACTION, "api_error"
    return _parse_action(reply) or FALLBACK_ACTION, "null"


def _post(path, payload):
    """POST *payload* as JSON to the environment server and return the decoded body."""
    res = requests.post(f"{BASE_URL}{path}", json=payload, timeout=HTTP_TIMEOUT)
    res.raise_for_status()
    return res.json()


def run_inference(task_name: str, max_steps: int = 1000):
    """Run one CrisisSim episode for *task_name*, choosing each action with an LLM.

    Progress is reported on stdout as one [START] line, one [STEP] line per
    step, and a final [END] line. Diagnostics for failures go to stderr so
    they do not disturb that line format.

    Args:
        task_name: Task identifier sent to the environment's /reset endpoint.
        max_steps: Safety cap on episode length. If the environment has not
            reported done by then, the episode ends with success=false.

    Exits the process with status 1 if no API key is configured (HF_TOKEN or
    OPENAI_API_KEY) or the OpenAI client cannot be created.
    """
    env_name = "CrisisSim"
    api_base_url = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
    model_name = os.getenv("MODEL_NAME") or "gpt-4o-mini"
    api_key = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
    start_line = f"[START] task={task_name} env={env_name} model={model_name}"

    if not api_key:
        print(start_line, flush=True)
        print("[STEP] step=1 action=null reward=0.00 done=true error=API_KEY or HF_TOKEN environment variable is not set", flush=True)
        print("[END] success=false steps=0 rewards=", flush=True)
        sys.exit(1)

    print(start_line, flush=True)
    try:
        client = OpenAI(base_url=api_base_url, api_key=api_key)
    except Exception as exc:
        print(f"failed to create OpenAI client: {exc!r}", file=sys.stderr, flush=True)
        print("[STEP] step=1 action=null reward=0.00 done=true error=api_error", flush=True)
        print("[END] success=false steps=0 rewards=", flush=True)
        sys.exit(1)

    steps = 0
    rewards = []
    last_actions = []
    try:
        state = _post("/reset", {"task_name": task_name})
        success = False
        while steps < max_steps:
            steps += 1
            action, error_val = _query_model(client, model_name, state, last_actions)
            # Enforce the no-3-in-a-row rule on every path, including the
            # fallback used after an API error.
            action = _avoid_triple_repeat(action, last_actions)
            last_actions.append(action)

            data = _post("/step", {"action": action})
            state = data["observation"]
            reward = float(data["reward"])
            done = data["done"]
            rewards.append(reward)
            print(f"[STEP] step={steps} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
            if done:
                # A missing "bankrupt" flag is treated as failure.
                success = not state.get("bankrupt", True)
                break
        rewards_str = ",".join(f"{r:.2f}" for r in rewards)
        print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
    except Exception as exc:  # Top-level boundary: always emit a closing [END] line.
        print(f"episode aborted: {exc!r}", file=sys.stderr, flush=True)
        print(f"[STEP] step={steps} action=null reward=0.00 done=true error=api_error", flush=True)
        rewards_str = ",".join(f"{r:.2f}" for r in rewards)
        print(f"[END] success=false steps={steps} rewards={rewards_str}", flush=True)
if __name__ == "__main__":
    # Usage: python inference.py [task_name]   (task_name defaults to "easy")
    cli_args = sys.argv[1:]
    run_inference(cli_args[0] if cli_args else "easy")