# (Removed non-source residue from the Hugging Face file viewer:
#  "Lilli98's picture / Update app.py / f9bdcf7 verified / raw / history blame / 20.9 kB".
#  These lines are page chrome, not Python, and would break the module.)
# app.py
"""
Beer Game — Robust full Streamlit app (fixed pipeline/Retailer KeyError)
- Uses old openai SDK style (openai==0.28.0) to avoid proxies/new-client issues on Spaces
- Only uploads logs to HF at end of game
- Ensures missing keys are initialized for backward compatibility
- Unified lowercase role keys: 'retailer','wholesaler','distributor','factory'
"""
import os
import re
import time
import uuid
import random
import json
import traceback
from datetime import datetime
from pathlib import Path
import streamlit as st
import pandas as pd
import openai # expects openai==0.28.0 in requirements.txt
from huggingface_hub import upload_file, HfApi
# ---------------------------
# CONFIG
# ---------------------------
DEFAULT_WEEKS = 36  # game horizon in weeks; 24 or 36 selectable, default 36
# Lead times (in weeks)
ORDER_LEAD_TIME = 1  # Time for orders to reach supplier (NOTE(review): not referenced anywhere in this file — confirm intended use)
SHIPPING_LEAD_TIME = 2  # Time for shipments to arrive; also the initial length of each role's pipeline list
PRODUCTION_LEAD_TIME = 2  # Time for factory to produce goods (NOTE(review): not referenced anywhere in this file — confirm intended use)
INITIAL_INVENTORY = 12  # starting on-hand units per role; also the base-stock target in the LLM fallback heuristic
INITIAL_BACKLOG = 0  # starting unfilled orders per role
OPENAI_MODEL = "gpt-4o-mini"  # model used for every LLM agent decision
LOCAL_LOG_DIR = Path("logs")
LOCAL_LOG_DIR.mkdir(exist_ok=True)  # import-time side effect: ensure the local log directory exists
# HF settings (via Secrets)
HF_TOKEN = os.getenv("HF_TOKEN")
HF_REPO_ID = os.getenv("HF_REPO_ID")  # e.g. "Lilli98/beer-game-logs"
hf_api = HfApi()  # NOTE(review): instantiated but never used in this file
# OpenAI key (old SDK usage, openai==0.28.0)
openai.api_key = os.getenv("OPENAI_API_KEY")
# ---------------------------
# HELPERS
# ---------------------------
def now_iso():
    """Return the current UTC time as ISO-8601 with millisecond precision and a 'Z' suffix."""
    stamp = datetime.utcnow().isoformat(timespec="milliseconds")
    return stamp + "Z"
def make_classic_demand(weeks: int):
    """Classic Beer Game demand profile: 4 units/week for the first 4 weeks, then a step up to 8."""
    return [4 if week_idx < 4 else 8 for week_idx in range(weeks)]
def fmt(o):
    """Best-effort JSON serialization of *o*; falls back to str() when it cannot be serialized."""
    try:
        return json.dumps(o, ensure_ascii=False)
    except (TypeError, ValueError):
        # json.dumps raises TypeError for unsupported types (e.g. sets)
        # and ValueError for circular references; a bare except here would
        # also have swallowed KeyboardInterrupt/SystemExit.
        return str(o)
# ---------------------------
# STATE COMPATIBILITY (关键:保证 pipeline / orders 等键存在)
# ---------------------------
def ensure_state_compat(state: dict):
    """
    Ensure a state dict has all required keys and sensible defaults.

    Protects against old/incomplete session_state entries left over from
    earlier versions of the app. Mutates *state* in place and returns it.
    """
    roles = state.get("roles", ["retailer", "wholesaler", "distributor", "factory"])
    state.setdefault("roles", roles)
    # setdefault alone suffices: it only writes when the key is absent
    # (the original `setdefault(k, state.get(k, d))` was redundant).
    state.setdefault("weeks_total", DEFAULT_WEEKS)
    state.setdefault("week", 1)
    # inventories/backlogs
    state.setdefault("inventory", {r: INITIAL_INVENTORY for r in roles})
    state.setdefault("backlog", {r: INITIAL_BACKLOG for r in roles})
    # pipeline: every role needs a list at least SHIPPING_LEAD_TIME long
    pipeline = state.setdefault("pipeline", {})
    for r in roles:
        lane = pipeline.setdefault(r, [4] * SHIPPING_LEAD_TIME)
        shortfall = SHIPPING_LEAD_TIME - len(lane)
        if shortfall > 0:
            lane.extend([4] * shortfall)
    # incoming_orders, orders_history, shipments_history
    state.setdefault("incoming_orders", {r: 0 for r in roles})
    state.setdefault("orders_history", {r: [] for r in roles})
    state.setdefault("shipments_history", {r: [] for r in roles})
    state.setdefault("logs", [])
    state.setdefault("info_sharing", False)
    state.setdefault("info_history_weeks", 0)
    # demand: (re)generate when missing or too short for the configured horizon
    if "customer_demand" not in state or len(state["customer_demand"]) < state["weeks_total"]:
        state["customer_demand"] = make_classic_demand(state["weeks_total"])
    # clamp week into [1, weeks_total + 1]; weeks_total + 1 means "finished"
    state["week"] = max(1, min(state["week"], state["weeks_total"] + 1))
    return state
# ---------------------------
# LLM call (old openai SDK)
# ---------------------------
def call_llm_for_order(role: str, local_state: dict, info_sharing_visible: bool, demand_history: list, max_tokens=40, temperature=0.7):
    """
    Ask the LLM for this role's weekly order quantity.

    role must be a lowercase key matching the state dicts (e.g. 'retailer').
    Returns (order_int, raw_text). On API failure or unparseable output the
    function falls back to a base-stock heuristic and tags raw_text with
    "PARSE_FALLBACK".
    """
    # safety: tolerate missing pipeline/inventory keys in partial snapshots.
    # step_game passes a snapshot from state_snapshot_for_prompt(), which has
    # no "pipeline" key — it pre-computes "incoming_shipments_next_week"
    # instead, so check both (previously the prompt always reported 0 here).
    lane = local_state.get("pipeline", {}).get(role)
    if lane:
        pipeline_next = lane[0]
    else:
        pipeline_next = local_state.get("incoming_shipments_next_week", {}).get(role, 0)
    inventory = local_state.get("inventory", {}).get(role, 0)
    backlog = local_state.get("backlog", {}).get(role, 0)
    incoming_order = local_state.get("incoming_orders", {}).get(role, 0)
    visible_history = demand_history if info_sharing_visible else []
    # build prompt (concise)
    prompt = (
        f"You are the {role.title()} in a 4-player Beer Game (Retailer -> Wholesaler -> Distributor -> Factory).\n"
        f"Week: {local_state.get('week')} / {local_state.get('weeks_total')}\n"
        f"- Inventory: {inventory}\n"
        f"- Backlog: {backlog}\n"
        f"- Incoming shipment next week: {pipeline_next}\n"
        f"- Incoming order this week: {incoming_order}\n"
    )
    if visible_history:
        prompt += f"- Customer demand history (visible): {visible_history}\n"
    prompt += "\nDecide a **non-negative integer** order quantity to place to your upstream supplier this week. Reply with an integer only."
    llm_failed = False
    try:
        resp = openai.ChatCompletion.create(
            model=OPENAI_MODEL,
            messages=[
                {"role": "system", "content": "You are an automated Beer Game agent."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            n=1
        )
        raw = resp.choices[0].message.get("content", "").strip()
    except Exception as e:
        raw = f"OPENAI_ERROR: {e}"
        llm_failed = True  # never parse numbers out of error text (e.g. a "429" rate-limit message)
    order = None
    if not llm_failed:
        m = re.search(r"(-?\d+)", raw or "")
        if m:
            # clamp negative replies to zero
            order = max(0, int(m.group(1)))
    if order is None:
        # fallback heuristic: order up to a base-stock level of
        # INITIAL_INVENTORY plus this week's incoming order
        target = INITIAL_INVENTORY + (incoming_order or 0)
        order = max(0, target - inventory)
        raw = (raw + " | PARSE_FALLBACK").strip()
    return int(order), raw
# ---------------------------
# GAME LOGIC (uses lowercase role keys)
# ---------------------------
def init_game(weeks=DEFAULT_WEEKS):
    """Create a fresh game state for one participant (lowercase role keys throughout)."""
    roles = ["retailer", "wholesaler", "distributor", "factory"]
    state = {
        "participant_id": None,
        "week": 1,
        "weeks_total": weeks,
        "roles": roles,
        "logs": [],
        "info_sharing": False,
        "info_history_weeks": 0,
        "customer_demand": make_classic_demand(weeks),
    }
    # per-role numeric state (ints are safe to share via fromkeys)
    state["inventory"] = dict.fromkeys(roles, INITIAL_INVENTORY)
    state["backlog"] = dict.fromkeys(roles, INITIAL_BACKLOG)
    state["incoming_orders"] = dict.fromkeys(roles, 0)
    # per-role list state (comprehensions so each role gets its own list)
    state["pipeline"] = {r: [4] * SHIPPING_LEAD_TIME for r in roles}
    state["orders_history"] = {r: [] for r in roles}
    state["shipments_history"] = {r: [] for r in roles}
    return state
def state_snapshot_for_prompt(state: dict):
    """Build a compact, copy-safe view of *state* suitable for embedding in an LLM prompt."""
    pipelines = state.get("pipeline", {})
    next_arrivals = {}
    for r in state.get("roles", []):
        lane = pipelines.get(r)
        next_arrivals[r] = lane[0] if lane else 0
    return {
        "week": state.get("week"),
        "weeks_total": state.get("weeks_total"),
        "inventory": dict(state.get("inventory", {})),
        "backlog": dict(state.get("backlog", {})),
        "incoming_orders": dict(state.get("incoming_orders", {})),
        "incoming_shipments_next_week": next_arrivals,
    }
def step_game(state: dict, distributor_order: int):
    """
    Advance the game by one week.

    Sequence: customer demand hits the retailer; in-transit shipments arrive;
    orders are fulfilled (shortfall accumulates as backlog); the human
    distributor's order is recorded; LLM agents decide their orders; new
    orders enter the shipping pipelines; costs accrue; the week is logged.
    Mutates *state* in place and returns it.
    """
    # defensive: ensure compatible keys
    ensure_state_compat(state)
    week = state["week"]
    roles = state["roles"]
    if week > state["weeks_total"]:
        # already finished; do not advance further
        return state
    # 1) customer demand hits retailer
    demand = state["customer_demand"][week - 1]
    state["incoming_orders"]["retailer"] = demand
    # 2) shipments arrive (front of each pipeline)
    arriving = {}
    for r in roles:
        arr = 0
        if state.get("pipeline", {}).get(r):
            # pop front safely
            try:
                arr = state["pipeline"][r].pop(0)
            except Exception:
                arr = 0
        state["inventory"][r] = state["inventory"].get(r, 0) + (arr or 0)
        arriving[r] = arr
    # 3) fulfill incoming orders: ship what inventory allows, backlog the rest
    shipments_out = {}
    for r in roles:
        incoming = state.get("incoming_orders", {}).get(r, 0) or 0
        inv = state.get("inventory", {}).get(r, 0) or 0
        shipped = min(inv, incoming)
        state["inventory"][r] = inv - shipped
        unfilled = incoming - shipped
        if unfilled > 0:
            state["backlog"][r] = state.get("backlog", {}).get(r, 0) + unfilled
        shipments_out[r] = shipped
        state["shipments_history"].setdefault(r, []).append(shipped)
    # 4) record the human distributor's order exactly once.
    #    BUGFIX: a second append/assignment after the LLM loop used to
    #    duplicate every distributor history entry each week.
    state["orders_history"]["distributor"].append(int(distributor_order))
    state["incoming_orders"]["wholesaler"] = int(distributor_order)
    # 5) LLM decisions. Loop order matters: the retailer runs first and sets
    #    incoming_orders["distributor"]; the wholesaler sets
    #    incoming_orders["factory"] before the factory's own call, so each
    #    later role sees the fresh values in its snapshot.
    demand_history_visible = []
    if state.get("info_sharing") and state.get("info_history_weeks", 0) > 0:
        start_idx = max(0, (week - 1) - state["info_history_weeks"])
        demand_history_visible = state["customer_demand"][start_idx:(week - 1)]
    llm_outputs = {}
    for role in ["retailer", "wholesaler", "factory", "distributor"]:
        order_val, raw = call_llm_for_order(
            role,
            state_snapshot_for_prompt(state),
            state.get("info_sharing", False),
            demand_history_visible
        )
        order_val = max(0, int(order_val))
        llm_outputs[role] = {"order": order_val, "raw": raw}
        if role != "distributor":  # AI decisions take effect directly
            state["orders_history"][role].append(order_val)
            if role == "retailer":
                state["incoming_orders"]["distributor"] = order_val
            elif role == "wholesaler":
                state["incoming_orders"]["factory"] = order_val
        # the distributor's LLM output is kept only as a UI suggestion;
        # the human's submitted order (recorded in step 4) is authoritative
    # 6) place orders into downstream pipelines (arrive after SHIPPING_LEAD_TIME)
    downstream_map = {"factory": "wholesaler", "wholesaler": "distributor", "distributor": "retailer", "retailer": None}
    for role in roles:
        placed_order = state["orders_history"][role][-1] if state["orders_history"].get(role) else 0
        if role == "distributor":
            placed_order = int(distributor_order)
        downstream = downstream_map.get(role)
        if downstream:
            state["pipeline"].setdefault(downstream, [0] * SHIPPING_LEAD_TIME)
            state["pipeline"][downstream].append(placed_order)
    # 6.5) cost accrual: holding cost 0.5 per unit of inventory,
    #      backlog penalty 1.0 per unit, cumulative per role
    if "cost" not in state:
        state["cost"] = {r: 0.0 for r in roles}
    for r in roles:
        inv = state["inventory"].get(r, 0)
        backlog = state["backlog"].get(r, 0)
        inv_cost = inv * 0.5
        back_cost = backlog * 1.0
        state["cost"][r] = state["cost"].get(r, 0) + inv_cost + back_cost
    # 7) logging
    log_entry = {
        "timestamp": now_iso(),
        "week": week,
        "demand": demand,
        "arriving": arriving,
        "shipments_out": shipments_out,
        "orders_submitted": {r: (state["orders_history"].get(r, [None])[-1] if state["orders_history"].get(r) else None) for r in roles},
        "inventory": {r: state["inventory"].get(r, 0) for r in roles},
        "backlog": {r: state["backlog"].get(r, 0) for r in roles},
        "cost": {r: state["cost"].get(r, 0) for r in roles},
        "info_sharing": state.get("info_sharing", False),
        "info_history_weeks": state.get("info_history_weeks", 0),
        "llm_raw": {k: v["raw"] for k, v in llm_outputs.items()}
    }
    state["logs"].append(log_entry)
    # 8) advance week
    state["week"] = state.get("week", 1) + 1
    return state
# ---------------------------
# Persistence helpers
# ---------------------------
def save_logs_local(state: dict, participant_id: str):
    """Flatten the game's log entries to a CSV under LOCAL_LOG_DIR and return its path."""
    frame = pd.json_normalize(state.get("logs", []))
    out_path = LOCAL_LOG_DIR / f"logs_{participant_id}_{int(time.time())}.csv"
    frame.to_csv(out_path, index=False)
    return out_path
def upload_log_to_hf_at_end(local_file: Path, participant_id: str):
    """
    Upload the final CSV to the HF dataset repo. Intended to be called only
    once, at the end of the game.

    Returns the resolve URL on success, or None when HF credentials are
    missing or the upload fails (the error is surfaced in the Streamlit UI).
    """
    if not (HF_TOKEN and HF_REPO_ID):
        return None
    dest = f"logs/{participant_id}/{local_file.name}"
    try:
        upload_file(
            path_or_fileobj=str(local_file),
            path_in_repo=dest,
            repo_id=HF_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except Exception as e:
        st.error(f"HF upload failed: {e}")
        return None
    return f"https://huggingface.co/datasets/{HF_REPO_ID}/resolve/main/{dest}"
# ---------------------------
# STREAMLIT UI & session mgmt
# ---------------------------
# Streamlit page setup (flat script: runs top-to-bottom on every interaction)
st.set_page_config(page_title="Beer Game (Distributor + LLMs)", layout="wide")
st.title("🍺 Beer Game — Human Distributor vs LLM agents")
# participant id via query param or input
qp = st.query_params
# NOTE(review): st.query_params (new API) yields plain strings, but the
# `.get(..., [None])[0]` pattern below assumes the old list-valued API —
# with the new API this takes the first *character* of the id. Confirm the
# Streamlit version this runs on.
pid_from_q = qp.get("participant_id", [None])[0] if qp else None
pid_input = st.text_input("Participant ID (leave blank to auto-generate or use ?participant_id=ID)", value=pid_from_q or "")
# auto-generated id is cached in session_state so it is stable across reruns
participant_id = pid_input.strip() if pid_input else st.session_state.setdefault("auto_pid", str(uuid.uuid4())[:8])
st.sidebar.markdown(f"**Participant ID:** `{participant_id}`")
# sessions container: one independent game state per participant id
if "sessions" not in st.session_state:
    st.session_state["sessions"] = {}
# reset button for debugging / clearing old sessions
if st.sidebar.button("Reset session (clear saved state)"):
    if participant_id in st.session_state["sessions"]:
        del st.session_state["sessions"][participant_id]
    # NOTE(review): st.experimental_rerun() was removed in recent Streamlit
    # releases (st.rerun() is the successor) — confirm the pinned version.
    st.experimental_rerun()
# create or ensure session state
if participant_id not in st.session_state["sessions"]:
    st.session_state["sessions"][participant_id] = init_game(DEFAULT_WEEKS)
    st.session_state["sessions"][participant_id]["participant_id"] = participant_id
# retrieve and ensure compatibility immediately (fills keys missing from
# states saved by older app versions)
state = st.session_state["sessions"][participant_id]
state = ensure_state_compat(state)
st.session_state["sessions"][participant_id] = state  # write back
# sidebar controls: experiment condition toggles, written straight into state
st.sidebar.header("Experiment controls")
state["info_sharing"] = st.sidebar.checkbox("Enable Information Sharing (share demand)", value=state.get("info_sharing", False))
state["info_history_weeks"] = st.sidebar.slider("Weeks of demand history to share (0 = none)", 0, 8, value=state.get("info_history_weeks", 0))
st.sidebar.markdown("---")
st.sidebar.write("Model for LLM agents:")
st.sidebar.write(OPENAI_MODEL)
st.sidebar.markdown("---")
st.sidebar.write("HF upload settings:")
st.sidebar.write(f"- HF_REPO_ID: {HF_REPO_ID or 'NOT SET'}")
st.sidebar.write(f"- HF_TOKEN: {'SET' if HF_TOKEN else 'NOT SET'}")
# main UI: game board (left) and info/admin panel (right)
col_main, col_side = st.columns([3,1])
with col_main:
    st.header(f"Week {state['week']} / {state['weeks_total']}")
    # guard the index: week can be weeks_total + 1 once the game is finished
    demand_display = state["customer_demand"][state["week"] - 1] if 0 <= (state["week"] - 1) < len(state["customer_demand"]) else None
    st.subheader(f"Customer demand (retailer receives this week): {demand_display}")
    # role panels: one column per supply-chain role
    roles = state["roles"]
    panels = st.columns(len(roles))
    for i, role in enumerate(roles):
        with panels[i]:
            st.markdown(f"### {role.title()}")
            st.metric("Inventory", state["inventory"].get(role, 0))
            st.metric("Backlog", state["backlog"].get(role, 0))
            incoming = state["incoming_orders"].get(role, 0)
            st.write(f"Incoming order (this week): **{incoming}**")
            # front of the pipeline list is the shipment arriving next week
            next_ship = state["pipeline"].get(role, [0])[0] if state["pipeline"].get(role) else 0
            st.write(f"Incoming shipment next week: **{next_ship}**")
    st.markdown("---")
    # Distributor form: the human's decision is collected here but only
    # applied when "Next Week" is pressed below
    with st.form(key=f"order_form_{participant_id}", clear_on_submit=False):
        st.write("### Your (Distributor) decision this week")
        # surface the LLM's suggestion from the last processed week, if any
        last_log = state["logs"][-1] if state.get("logs") else None
        if last_log and "llm_raw" in last_log and "distributor" in last_log["llm_raw"]:
            # NOTE(review): this is the raw LLM reply text, not the parsed
            # integer order — may display error/fallback text verbatim.
            suggestion = last_log["llm_raw"]["distributor"]
            st.info(f"💡 AI suggests you order: **{suggestion}** units (you can follow or override)")
        else:
            st.info("💡 AI suggestion will appear after the first processed week.")
        # `or 4` also maps a stored 0 back to the default of 4
        default_val = state["incoming_orders"].get("distributor", 4) or 4
        distributor_order = st.number_input("Order to place to upstream (Wholesaler):", min_value=0, step=1, value=default_val)
        submitted = st.form_submit_button("Submit Order (locks your decision)")
        if submitted:
            # park the order in pending_orders; step_game runs on "Next Week"
            st.session_state.setdefault("pending_orders", {})
            st.session_state["pending_orders"][participant_id] = int(distributor_order)
            st.success(f"Order submitted: {distributor_order}. Now click 'Next Week' to process the week.")
    st.markdown("---")
    pending = st.session_state.get("pending_orders", {}).get(participant_id, None)
    if pending is None:
        st.info("Please submit your order first to enable Next Week processing.")
    else:
        if st.button("Next Week — process week and invoke LLM agents"):
            # Guard: don't step if game finished
            if state["week"] > state["weeks_total"]:
                st.info("Game already finished for this participant.")
            else:
                try:
                    state = step_game(state, pending)
                    # write back
                    st.session_state["sessions"][participant_id] = state
                    # remove pending so the next week requires a fresh submit
                    del st.session_state["pending_orders"][participant_id]
                    st.success(f"Week processed. Advanced to week {state['week']}.")
                except Exception as e:
                    # show traceback for debugging
                    tb = traceback.format_exc()
                    st.error(f"Error during Next Week processing: {e}")
                    st.text_area("Traceback", tb, height=300)
    st.markdown("### Recent logs")
    if state.get("logs"):
        df = pd.json_normalize(state["logs"][-6:])
        st.dataframe(df, use_container_width=True)
    else:
        st.write("No logs yet. Submit your first order and press Next Week.")
with col_side:
    st.subheader("Information Sharing (preview)")
    st.write(f"Sharing {state.get('info_history_weeks', 0)} weeks of history (0 = only current).")
    if state.get("info_sharing"):
        # preview of the demand slice the LLM agents will see this week
        h = state.get("info_history_weeks", 0)
        start = max(0, (state["week"] - 1) - h)
        hist = state["customer_demand"][start: state["week"]]
        st.write("Demand visible to agents:", hist)
    st.markdown("---")
    st.subheader("Admin / Debug")
    if st.button("Test LLM connection"):
        if not openai.api_key:
            st.error("OpenAI API key missing (set OPENAI_API_KEY in secrets).")
        else:
            try:
                test_prompt = "Reply with 42."
                resp = openai.ChatCompletion.create(model=OPENAI_MODEL, messages=[{"role":"user","content":test_prompt}], max_tokens=10)
                st.write("LLM raw:", resp.choices[0].message.get("content"))
            except Exception as e:
                st.error(f"LLM test failed: {e}")
    if st.button("Save logs now (manual)"):
        if not state.get("logs"):
            st.info("No logs to save.")
        else:
            local_file = save_logs_local(state, participant_id)
            st.success(f"Saved local file: {local_file}")
# ---------------------------
# End-of-game upload (only when finished)
# ---------------------------
# Note: check strictly greater than weeks_total (week is advanced past the
# final week once it has been processed)
if state.get("week", 1) > state.get("weeks_total", DEFAULT_WEEKS):
    st.success("Game completed for this participant.")
    final_csv = save_logs_local(state, participant_id)
    with open(final_csv, "rb") as f:
        st.download_button("Download final logs CSV", data=f, file_name=final_csv.name, mime="text/csv")
    if HF_TOKEN and HF_REPO_ID:
        url = upload_log_to_hf_at_end(final_csv, participant_id)
        if url:
            st.write(f"Final logs uploaded to HF Hub: {url}")