Kaushik Rajan commited on
Commit
7102d41
·
1 Parent(s): ce03581

Fix: Correct Gradio callback logic to prevent crash

Browse files
Files changed (1) hide show
  1. app.py +33 -28
app.py CHANGED
@@ -278,17 +278,25 @@ def create_interface():
278
  @spaces.GPU
279
  def game_step_and_update(env, mode, rd_raw, mkt_raw, sales_raw, rd_pct, mkt_pct, sales_pct):
280
  player_budget = env.player_stats["budget"]
281
-
 
 
 
 
 
 
 
 
 
 
282
  if mode == "Percentages":
283
  if rd_pct + mkt_pct + sales_pct != 100:
284
- status_text = "Error: Percentage allocations must sum to 100%."
285
- return env, status_text, env.ai_stats.get("last_reasoning", ""), *create_plots(env.history), gr.Label(f"Your Budget: ${player_budget}"), rd_slider_raw, mkt_slider_raw, sales_slider_raw, rd_slider_pct, mkt_slider_pct, sales_slider_pct
286
 
287
  rd_alloc_val = int(player_budget * rd_pct / 100)
288
  mkt_alloc_val = int(player_budget * mkt_pct / 100)
289
  sales_alloc_val = int(player_budget * sales_pct / 100)
290
 
291
- # Distribute rounding errors
292
  total = rd_alloc_val + mkt_alloc_val + sales_alloc_val
293
  sales_alloc_val += player_budget - total
294
 
@@ -296,40 +304,36 @@ def create_interface():
296
  rd_alloc_val, mkt_alloc_val, sales_alloc_val = rd_raw, mkt_raw, sales_raw
297
 
298
  if (rd_alloc_val + mkt_alloc_val + sales_alloc_val) > player_budget:
299
- status_text = f"Error: Allocation (${rd_alloc_val + mkt_alloc_val + sales_alloc_val}) exceeds budget (${player_budget})."
300
- # This part needs to return updates for all sliders to avoid errors
301
- return (env, status_text, env.ai_stats.get("last_reasoning", ""), *create_plots(env.history),
302
- gr.Label(f"Your Budget: ${player_budget}"),
303
- gr.Slider(maximum=player_budget), gr.Slider(maximum=player_budget), gr.Slider(maximum=player_budget),
304
- rd_slider_pct, mkt_slider_pct, sales_slider_pct)
305
-
306
 
307
  player_alloc = {"rd": rd_alloc_val, "marketing": mkt_alloc_val, "sales": sales_alloc_val}
308
  ai_alloc, ai_reasoning = ai_strategy(env.ai_stats, env.player_stats)
309
- env.ai_stats["last_reasoning"] = ai_reasoning # Store reasoning for error case
310
 
311
  env.step(player_alloc, ai_alloc)
312
  state = env.get_state()
313
 
314
  plots = create_plots(state["history"])
315
 
 
316
  if state["game_over"]:
317
  winner = env.get_winner()
318
  status_text = f"Game Over! Winner: {winner}. Final market share: You ({state['player_stats']['market_share']:.1f}%) vs AI ({state['ai_stats']['market_share']:.1f}%)."
319
- submit_btn.interactive = False
320
  else:
321
  status_text = f"End of Quarter {state['quarter']}. Your turn."
322
 
323
  new_budget = state["player_stats"]["budget"]
324
 
325
- # Return updates for all sliders
326
- return (state, status_text, ai_reasoning, *plots,
327
- gr.Label(f"Your Budget: ${new_budget}"),
328
- gr.Slider(maximum=new_budget, value=int(new_budget/3)),
329
- gr.Slider(maximum=new_budget, value=int(new_budget/3)),
330
- gr.Slider(maximum=new_budget, value=new_budget - 2 * int(new_budget/3)),
331
- gr.Slider(value=33), gr.Slider(value=33), gr.Slider(value=34)
332
- )
 
333
 
334
  def on_new_game():
335
  env = BusinessCompetitionEnv()
@@ -337,12 +341,12 @@ def create_interface():
337
  plots = create_plots(state["history"])
338
  return (
339
  env, f"Quarter 1 of {NUM_QUARTERS}. Your move.", "", *plots,
340
- gr.Label(f"Your Budget: ${INITIAL_BUDGET}"),
341
- gr.Slider(maximum=INITIAL_BUDGET, value=333),
342
- gr.Slider(maximum=INITIAL_BUDGET, value=333),
343
- gr.Slider(maximum=INITIAL_BUDGET, value=334),
344
- gr.Slider(value=33), gr.Slider(value=33), gr.Slider(value=34),
345
- gr.Button(interactive=True)
346
  )
347
 
348
  def update_total_raw_display(rd, mkt, sales):
@@ -363,7 +367,8 @@ def create_interface():
363
  plot_market_share, plot_budget, plot_quality,
364
  player_budget_display,
365
  rd_slider_raw, mkt_slider_raw, sales_slider_raw,
366
- rd_slider_pct, mkt_slider_pct, sales_slider_pct
 
367
  ]
368
  )
369
 
 
278
  @spaces.GPU
279
  def game_step_and_update(env, mode, rd_raw, mkt_raw, sales_raw, rd_pct, mkt_pct, sales_pct):
280
  player_budget = env.player_stats["budget"]
281
+
282
+ # Helper to create a return tuple for user input errors
283
+ def create_error_return(status_text):
284
+ return (
285
+ env, status_text, env.ai_stats.get("last_reasoning", ""), *create_plots(env.history),
286
+ gr.update(value=f"Your Budget: ${player_budget}"),
287
+ gr.update(), gr.update(), gr.update(), # Raw sliders
288
+ gr.update(), gr.update(), gr.update(), # Pct sliders
289
+ gr.update(interactive=True) # Submit button
290
+ )
291
+
292
  if mode == "Percentages":
293
  if rd_pct + mkt_pct + sales_pct != 100:
294
+ return create_error_return("Error: Percentage allocations must sum to 100%.")
 
295
 
296
  rd_alloc_val = int(player_budget * rd_pct / 100)
297
  mkt_alloc_val = int(player_budget * mkt_pct / 100)
298
  sales_alloc_val = int(player_budget * sales_pct / 100)
299
 
 
300
  total = rd_alloc_val + mkt_alloc_val + sales_alloc_val
301
  sales_alloc_val += player_budget - total
302
 
 
304
  rd_alloc_val, mkt_alloc_val, sales_alloc_val = rd_raw, mkt_raw, sales_raw
305
 
306
  if (rd_alloc_val + mkt_alloc_val + sales_alloc_val) > player_budget:
307
+ return create_error_return(f"Error: Allocation (${rd_alloc_val + mkt_alloc_val + sales_alloc_val}) exceeds budget (${player_budget}).")
 
 
 
 
 
 
308
 
309
  player_alloc = {"rd": rd_alloc_val, "marketing": mkt_alloc_val, "sales": sales_alloc_val}
310
  ai_alloc, ai_reasoning = ai_strategy(env.ai_stats, env.player_stats)
311
+ env.ai_stats["last_reasoning"] = ai_reasoning
312
 
313
  env.step(player_alloc, ai_alloc)
314
  state = env.get_state()
315
 
316
  plots = create_plots(state["history"])
317
 
318
+ submit_btn_update = gr.update(interactive=True)
319
  if state["game_over"]:
320
  winner = env.get_winner()
321
  status_text = f"Game Over! Winner: {winner}. Final market share: You ({state['player_stats']['market_share']:.1f}%) vs AI ({state['ai_stats']['market_share']:.1f}%)."
322
+ submit_btn_update = gr.update(interactive=False)
323
  else:
324
  status_text = f"End of Quarter {state['quarter']}. Your turn."
325
 
326
  new_budget = state["player_stats"]["budget"]
327
 
328
+ return (
329
+ state, status_text, ai_reasoning, *plots,
330
+ gr.update(value=f"Your Budget: ${new_budget}"),
331
+ gr.update(maximum=new_budget, value=int(new_budget/3)),
332
+ gr.update(maximum=new_budget, value=int(new_budget/3)),
333
+ gr.update(maximum=new_budget, value=new_budget - 2 * int(new_budget/3)),
334
+ gr.update(value=33), gr.update(value=33), gr.update(value=34),
335
+ submit_btn_update
336
+ )
337
 
338
  def on_new_game():
339
  env = BusinessCompetitionEnv()
 
341
  plots = create_plots(state["history"])
342
  return (
343
  env, f"Quarter 1 of {NUM_QUARTERS}. Your move.", "", *plots,
344
+ gr.update(value=f"Your Budget: ${INITIAL_BUDGET}"),
345
+ gr.update(maximum=INITIAL_BUDGET, value=333),
346
+ gr.update(maximum=INITIAL_BUDGET, value=333),
347
+ gr.update(maximum=INITIAL_BUDGET, value=334),
348
+ gr.update(value=33), gr.update(value=33), gr.update(value=34),
349
+ gr.update(interactive=True)
350
  )
351
 
352
  def update_total_raw_display(rd, mkt, sales):
 
367
  plot_market_share, plot_budget, plot_quality,
368
  player_budget_display,
369
  rd_slider_raw, mkt_slider_raw, sales_slider_raw,
370
+ rd_slider_pct, mkt_slider_pct, sales_slider_pct,
371
+ submit_btn
372
  ]
373
  )
374