Ben commited on
Commit
b13b7c1
·
1 Parent(s): c533f9b

Show video output only

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -43,7 +43,7 @@ def simulate_agent(stage_selection):
43
  try:
44
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
45
  except Exception as e:
46
- return None, f"Weight download failed. Error: {str(e)}"
47
 
48
  # Initialize env
49
  env = gym.make("LunarLander-v3", render_mode="rgb_array")
@@ -58,12 +58,11 @@ def simulate_agent(stage_selection):
58
  actor.eval()
59
  except Exception as e:
60
  env.close()
61
- return None, f"Architecture mismatch. Error: {str(e)}"
62
 
63
  state, _ = env.reset(seed=32)
64
  done = False
65
  frames = []
66
- total_reward = 0.0
67
  step_count = 0
68
 
69
  while not done and step_count < 600:
@@ -73,15 +72,14 @@ def simulate_agent(stage_selection):
73
  frames.append(frame)
74
  except Exception as e:
75
  env.close()
76
- return None, f"Render failed: {str(e)}"
77
 
78
  state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
79
  with torch.no_grad():
80
  action_logits = actor(state_tensor)
81
  action = torch.argmax(action_logits, dim=1).item()
82
 
83
- state, reward, terminated, truncated, _ = env.step(action)
84
- total_reward += reward
85
  step_count += 1
86
  done = terminated or truncated
87
 
@@ -93,14 +91,9 @@ def simulate_agent(stage_selection):
93
  try:
94
  imageio.mimsave(video_filename, frames, fps=fps, codec='libx264', pixelformat='yuv420p')
95
  except Exception as e:
96
- return None, f"Video encoding failed: {str(e)}"
97
 
98
- logs = (f"Status: Inference complete\n"
99
- f"Stage: {stage_selection}\n"
100
- f"Total Reward: {total_reward:.2f}\n"
101
- f"Steps: {step_count}")
102
-
103
- return video_filename, logs
104
 
105
  # 3. Gradio Web UI
106
  with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
@@ -122,13 +115,12 @@ with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as d
122
  run_button = gr.Button("Run Inference", variant="primary")
123
 
124
  with gr.Column(scale=2):
125
- video_output = gr.Video(label="Environment Render")
126
- text_output = gr.Textbox(label="Execution Logs", lines=4)
127
 
128
  run_button.click(
129
  fn=simulate_agent,
130
  inputs=[model_dropdown],
131
- outputs=[video_output, text_output]
132
  )
133
 
134
  if __name__ == "__main__":
 
43
  try:
44
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
45
  except Exception as e:
46
+ raise gr.Error(f"Weight download failed. Error: {str(e)}")
47
 
48
  # Initialize env
49
  env = gym.make("LunarLander-v3", render_mode="rgb_array")
 
58
  actor.eval()
59
  except Exception as e:
60
  env.close()
61
+ raise gr.Error(f"Architecture mismatch. Error: {str(e)}")
62
 
63
  state, _ = env.reset(seed=32)
64
  done = False
65
  frames = []
 
66
  step_count = 0
67
 
68
  while not done and step_count < 600:
 
72
  frames.append(frame)
73
  except Exception as e:
74
  env.close()
75
+ raise gr.Error(f"Render failed: {str(e)}")
76
 
77
  state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
78
  with torch.no_grad():
79
  action_logits = actor(state_tensor)
80
  action = torch.argmax(action_logits, dim=1).item()
81
 
82
+ state, _, terminated, truncated, _ = env.step(action)
 
83
  step_count += 1
84
  done = terminated or truncated
85
 
 
91
  try:
92
  imageio.mimsave(video_filename, frames, fps=fps, codec='libx264', pixelformat='yuv420p')
93
  except Exception as e:
94
+ raise gr.Error(f"Video encoding failed: {str(e)}")
95
 
96
+ return video_filename
 
 
 
 
 
97
 
98
  # 3. Gradio Web UI
99
  with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
 
115
  run_button = gr.Button("Run Inference", variant="primary")
116
 
117
  with gr.Column(scale=2):
118
+ video_output = gr.Video(label="Environment Render", autoplay=True)
 
119
 
120
  run_button.click(
121
  fn=simulate_agent,
122
  inputs=[model_dropdown],
123
+ outputs=[video_output]
124
  )
125
 
126
  if __name__ == "__main__":