| |
| """ |
| Real Perfusion Monitoring System - Hugging Face Spaces Deployment |
| Online DQN Agent Evaluation with Real-Time Trajectory Plotting |
| """ |
|
|
| import gradio as gr |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import time |
| import io |
| import base64 |
| from datetime import datetime |
| import threading |
| import queue |
| import os |
| import sys |
|
|
| |
# Make modules that sit next to this file importable even when the app is
# launched from another working directory.
sys.path.append(os.path.dirname(__file__))

# Prefer the full simulation stack and fall back to the demo stack. Either way
# the rest of the file uses the names ``config``, ``NewSimulationEnv`` and
# ``load_agent``; SIMULATION_AVAILABLE records whether one of the two stacks
# could be imported.
try:
    import config
    import init
    from operations import single_step
    from dqn_new_system import NewSimulationEnv, load_agent
except ImportError as full_import_error:
    print(f"Full simulation not available, using demo: {full_import_error}")
    try:
        import config_demo as config
        from demo_simulation import NewSimulationEnv, load_agent
    except ImportError as demo_import_error:
        print(f"Demo simulation also not available: {demo_import_error}")
        SIMULATION_AVAILABLE = False
    else:
        SIMULATION_AVAILABLE = True
        print("Using demo simulation system")
else:
    SIMULATION_AVAILABLE = True
    print("Using full simulation system")
|
|
| |
class SimulationState:
    """Mutable state shared by the Gradio callbacks and the simulation worker.

    Attributes:
        running: True while a run is active; the worker loop re-checks it and
            stops once it is cleared.
        agent, env: the DQN agent and environment of the current run (None
            until a run is started).
        trajectory_data: per-run recording consumed by the plotting code
            (hours, per-parameter value lists, decoded actions, cumulative
            rewards, scenario name, tracked parameter names and indices).
        messages: feed entries, each a dict with 'type', 'message' and
            'timestamp' keys.
        current_hour, total_reward: latest progress shown in the status line.
        message_queue: a ``queue.Queue``; not read anywhere in this file.
    """

    @staticmethod
    def _blank_trajectory():
        """Return a fresh, empty trajectory record (new lists every call)."""
        return {
            'hours': [],
            'parameters': {},
            'actions': [],
            'rewards': [],
            'scenario': None,
            'param_names': [],
            'param_indices': [],
        }

    def __init__(self):
        # Run control and the objects driving the current run.
        self.running = False
        self.agent = None
        self.env = None
        # Data recorded for the trajectory plot.
        self.trajectory_data = self._blank_trajectory()
        # Feed and progress figures shown in the UI.
        self.messages = []
        self.current_hour = 0
        self.total_reward = 0
        self.message_queue = queue.Queue()


# Single process-wide instance shared by every callback and worker thread.
sim_state = SimulationState()
|
|
def get_thresholds(scenario, param_idx):
    """Return the safety thresholds used to shade one parameter's plot.

    Args:
        scenario: scenario name; not used for the lookup (thresholds come from
            whichever ``config`` module was imported).
        param_idx: index of the parameter in the config threshold arrays.

    Returns:
        ``[critical_low, warning_low, warning_high, critical_high]`` read from
        ``config.criticalDepletion``, ``config.depletion``, ``config.excess``
        and ``config.criticalExcess``; or ``None`` when no simulation module
        is available, the index is out of range, or the config cannot supply
        the values.
    """
    if not SIMULATION_AVAILABLE:
        return None
    try:
        # Reject negative indices explicitly: Python indexing would silently
        # wrap them around to the end of the arrays.
        if 0 <= param_idx < len(config.criticalDepletion):
            return [
                config.criticalDepletion[param_idx],
                config.depletion[param_idx],
                config.excess[param_idx],
                config.criticalExcess[param_idx],
            ]
    except (AttributeError, IndexError, KeyError, TypeError):
        # Best effort: a config without (or with shorter) threshold arrays only
        # means no safety zones are drawn. Unlike the former bare ``except``,
        # this no longer swallows KeyboardInterrupt/SystemExit.
        pass
    return None
|
|
def _new_figure(figsize):
    """Create a matplotlib Figure that is not registered with pyplot.

    Figures made with ``plt.subplots`` stay referenced by pyplot's global
    registry until explicitly closed; the trajectory plot is rebuilt on every
    UI timer tick, so those would accumulate for the life of the process. A
    bare ``Figure`` is freed once the caller drops it and does not touch
    pyplot's shared "current figure" state from concurrent handler threads.
    """
    from matplotlib.figure import Figure
    return Figure(figsize=figsize)


def _message_figure(text, fontsize, color):
    """Return a 12x8 figure showing only ``text`` centred on a blank canvas."""
    fig = _new_figure((12, 8))
    ax = fig.subplots()
    ax.text(0.5, 0.5, text, ha='center', va='center', fontsize=fontsize, color=color)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
    return fig


def generate_trajectory_plot():
    """Render the recorded trajectories of the current run.

    Draws one subplot (2x3 grid, so at most six) per tracked parameter from
    ``sim_state.trajectory_data``, with danger/warning/safe bands and limit
    lines when ``get_thresholds`` returns thresholds for the parameter.

    Returns:
        matplotlib.figure.Figure: the trajectory figure; a placeholder figure
        when nothing has been recorded yet, or an error figure if drawing
        raised.
    """
    # Take one reference to the shared dict plus shallow snapshots: the
    # worker thread keeps appending while we draw, and start_simulation()
    # may replace the whole dict.
    data = sim_state.trajectory_data
    hours = list(data['hours'])
    parameters = dict(data['parameters'])

    if not hours or not parameters:
        return _message_figure(
            'π₯ Real-Time Parameter Trajectories\n\nStart simulation to see DQN agent performance\nwith live parameter evolution',
            fontsize=14,
            color='#666',
        )

    try:
        fig = _new_figure((16, 10))
        axes = fig.subplots(2, 3).flatten()

        agent_color = '#2E86DE'
        critical_color = '#E74C3C'
        warning_color = '#F39C12'
        safe_zone_color = '#D5F4E6'
        warning_zone_color = '#FCF3CF'
        danger_zone_color = '#FADBD8'

        scenario = data['scenario']
        param_names = list(data['param_names'])
        param_indices = list(data['param_indices'])

        # zip() with the axes stops after the available subplots.
        for ax, param_name, param_idx in zip(axes, param_names, param_indices):
            if param_name not in parameters:
                continue

            # The worker appends the hour and each parameter value in separate
            # statements, so the lists can briefly differ in length; plot only
            # the points present in both (a mismatch made ax.plot raise).
            values = list(parameters[param_name])
            count = min(len(hours), len(values))
            xs = hours[:count]
            values = values[:count]

            thresholds = get_thresholds(scenario, param_idx)

            if thresholds and values:
                critical_low, warning_low, warning_high, critical_high = thresholds

                # View extends 10% past the critical limits / 5% past the data.
                y_min = min(critical_low * 0.9, min(values) * 0.95)
                y_max = max(critical_high * 1.1, max(values) * 1.05)

                # Danger bands outside the critical limits, warning bands
                # between critical and warning limits, safe band in between.
                ax.axhspan(y_min, critical_low, alpha=0.15, color=danger_zone_color, zorder=0)
                ax.axhspan(critical_high, y_max, alpha=0.15, color=danger_zone_color, zorder=0)
                ax.axhspan(critical_low, warning_low, alpha=0.1, color=warning_zone_color, zorder=0)
                ax.axhspan(warning_high, critical_high, alpha=0.1, color=warning_zone_color, zorder=0)
                ax.axhspan(warning_low, warning_high, alpha=0.12, color=safe_zone_color, zorder=0)

                ax.axhline(y=critical_low, color=critical_color, linestyle='--', linewidth=2, alpha=0.8)
                ax.axhline(y=critical_high, color=critical_color, linestyle='--', linewidth=2, alpha=0.8)
                ax.axhline(y=warning_low, color=warning_color, linestyle=':', linewidth=1.5, alpha=0.7)
                ax.axhline(y=warning_high, color=warning_color, linestyle=':', linewidth=1.5, alpha=0.7)

            if len(xs) > 1:
                ax.plot(xs, values, color=agent_color, linewidth=3,
                        marker='o', markersize=6, markerfacecolor='white',
                        markeredgewidth=2, markeredgecolor=agent_color,
                        label='DQN Agent', zorder=4)
            elif len(xs) == 1:
                ax.plot(xs[0], values[0], color=agent_color, marker='o',
                        markersize=8, markerfacecolor='white', markeredgewidth=2,
                        markeredgecolor=agent_color, zorder=4)

            ax.set_title(f'{param_name}', fontsize=12, fontweight='bold')
            ax.set_xlabel('Time (hours)', fontsize=10)
            ax.set_ylabel('Value', fontsize=10)
            ax.grid(True, alpha=0.3)

            if values:
                if thresholds:
                    ax.set_ylim(y_min, y_max)
                else:
                    spread = max(values) - min(values)
                    # Use +/-1 for a single point or a flat series: a zero
                    # margin would request identical lower and upper limits.
                    margin = spread * 0.1 if len(values) > 1 and spread else 1
                    ax.set_ylim(min(values) - margin, max(values) + margin)

            ax.set_xlim(0, max(24, max(hours) + 1))

        # Hide grid cells that have no parameter assigned.
        for ax in axes[len(param_names):]:
            ax.set_visible(False)

        current_hour = max(hours)
        fig.suptitle(f'{scenario} DQN Agent Performance - Hour {current_hour}/24',
                     fontsize=14, fontweight='bold', color='#2C3E50')

        fig.tight_layout()
        return fig

    except Exception as e:
        print(f"Error generating plot: {e}")
        return _message_figure(f'Error generating plot: {str(e)}', fontsize=12, color='red')
|
|
def format_messages():
    """Render the monitoring feed as Markdown.

    Uses at most the 20 newest entries of ``sim_state.messages``; each is
    shown as ``<icon> **[timestamp]** message`` with the icon chosen by the
    entry's ``type`` (defaulting to ``'info'``), entries separated by a blank
    line. An empty feed yields a welcome text instead.
    """
    if not sim_state.messages:
        return "π€ **Welcome to Real Perfusion Monitoring System!**\n\nSelect a scenario and click 'Start DQN Evaluation' to begin monitoring real AI-controlled perfusion.\n\nπ You'll see:\nβ’ Real-time parameter trajectories\nβ’ Hour-by-hour AI decisions\nβ’ Critical alerts and warnings\nβ’ Complete 24-hour simulation results"

    icons = {
        'system': 'π₯',
        'parameter': 'π',
        'action': 'π―',
        'info': 'π‘',
        'success': 'π',
        'error': 'β',
        'warning': 'β οΈ',
    }
    fallback_icon = 'π'

    def render(entry):
        icon = icons.get(entry.get('type', 'info'), fallback_icon)
        stamp = entry.get('timestamp', '')
        text = entry.get('message', '')
        return f"{icon} **[{stamp}]** {text}"

    return "\n\n".join(render(entry) for entry in sim_state.messages[-20:])
|
|
def _load_scenario_agent(scenario):
    """Load the trained DQN agent for ``scenario``, or the demo agent.

    Looks in ``./New_System_Results`` (or ``.`` when that folder does not
    exist) for ``best_dqn_agent_<scenario>.pth`` first, then
    ``final_dqn_agent_<scenario>.pth``. If neither file exists, or loading
    raises, ``load_agent("demo")`` is returned instead.
    """
    try:
        output_dir = "./New_System_Results"
        if not os.path.exists(output_dir):
            output_dir = "."

        best_agent_path = os.path.join(output_dir, f'best_dqn_agent_{scenario}.pth')
        final_agent_path = os.path.join(output_dir, f'final_dqn_agent_{scenario}.pth')

        if os.path.exists(best_agent_path):
            return load_agent(best_agent_path)
        if os.path.exists(final_agent_path):
            return load_agent(final_agent_path)

        print(f"No trained model found, using demo agent for {scenario}")
        return load_agent("demo")
    except Exception as agent_error:
        print(f"Agent loading error: {agent_error}, falling back to demo")
        return load_agent("demo")


def start_simulation(scenario):
    """Start a DQN evaluation run for ``scenario`` in a background thread.

    Resets ``sim_state``, creates a ``NewSimulationEnv`` for the scenario,
    loads an agent and launches ``run_simulation_thread`` as a daemon thread.

    Args:
        scenario: scenario name selected in the UI (the dropdown offers
            "EYE" and "VCA").

    Returns:
        tuple: (feed markdown, trajectory figure, status line) for the
        components wired to the Start button.
    """
    if not SIMULATION_AVAILABLE:
        return "β **Error**: Simulation modules not available in this environment.", generate_trajectory_plot(), "Scenario: Not Available | Status: Error | Hour: 0 | Reward: 0"

    if sim_state.running:
        return "β οΈ **Warning**: Simulation already running!", generate_trajectory_plot(), f"Scenario: {sim_state.trajectory_data['scenario']} | Status: Running | Hour: {sim_state.current_hour} | Reward: {sim_state.total_reward:.1f}"

    # stop_simulation() only clears the running flag; the previous worker
    # notices it after its current sleep/step. If we set running=True before
    # then, that worker would resume, drive the new environment and, on exit,
    # clear the new run's flag. Wait (bounded) for it to finish first.
    previous_worker = getattr(sim_state, 'worker', None)
    if previous_worker is not None and previous_worker.is_alive():
        previous_worker.join(timeout=10)
        if previous_worker.is_alive():
            return (
                "β οΈ **Warning**: Previous simulation is still stopping, please try again.",
                generate_trajectory_plot(),
                f"Scenario: {sim_state.trajectory_data['scenario']} | Status: Stopping | Hour: {sim_state.current_hour} | Reward: {sim_state.total_reward:.1f}",
            )

    try:
        sim_state.running = True
        sim_state.messages = []
        sim_state.current_hour = 0
        sim_state.total_reward = 0
        sim_state.trajectory_data = {
            'hours': [],
            'parameters': {},
            'actions': [],
            'rewards': [],
            'scenario': scenario,
            'param_names': [],
            'param_indices': []
        }

        sim_state.env = NewSimulationEnv(scenario=scenario)
        sim_state.agent = _load_scenario_agent(scenario)

        sim_state.messages.append({
            'type': 'system',
            'message': f'π₯ **Starting Real {scenario} DQN Evaluation**',
            'timestamp': datetime.now().strftime("%H:%M:%S")
        })

        worker = threading.Thread(target=run_simulation_thread, args=(scenario,), daemon=True)
        # Remember the thread so the next start can wait for it to exit.
        sim_state.worker = worker
        worker.start()

        return format_messages(), generate_trajectory_plot(), f"Scenario: {scenario} | Status: Starting | Hour: 0 | Reward: 0"

    except Exception as e:
        sim_state.running = False
        error_msg = f"β **Error starting simulation**: {str(e)}"
        sim_state.messages.append({
            'type': 'error',
            'message': error_msg,
            'timestamp': datetime.now().strftime("%H:%M:%S")
        })
        return format_messages(), generate_trajectory_plot(), "Scenario: Error | Status: Failed | Hour: 0 | Reward: 0"
|
|
def stop_simulation():
    """Ask the active simulation to stop.

    Only clears ``sim_state.running`` and posts a feed entry; the worker
    thread sees the cleared flag on its next loop check. When nothing is
    running, no state is changed.

    Returns:
        tuple: (feed markdown, trajectory figure, status line).
    """
    was_running = sim_state.running
    if was_running:
        sim_state.running = False
        sim_state.messages.append({
            'type': 'warning',
            'message': 'βΉοΈ **Simulation stopped by user**',
            'timestamp': datetime.now().strftime("%H:%M:%S"),
        })

    feed = format_messages()
    figure = generate_trajectory_plot()

    data = sim_state.trajectory_data
    if was_running:
        scenario_label, state_label = data['scenario'], 'Stopped'
    else:
        scenario_label, state_label = data.get('scenario', 'None'), 'Not Running'

    status_line = (
        f"Scenario: {scenario_label} | Status: {state_label} | "
        f"Hour: {sim_state.current_hour} | Reward: {sim_state.total_reward:.1f}"
    )
    return feed, figure, status_line
|
|
def _post_message(msg_type, message):
    """Append a feed entry of ``msg_type`` stamped with the current HH:MM:SS."""
    sim_state.messages.append({
        'type': msg_type,
        'message': message,
        'timestamp': datetime.now().strftime("%H:%M:%S"),
    })


def run_simulation_thread(scenario, step_delay=3.0, max_steps=24):
    """Run one greedy DQN evaluation episode; used as a thread target.

    Uses the environment and agent that start_simulation() stored on
    ``sim_state``. Before each step it sleeps ``step_delay`` seconds, then
    chooses an action with exploration disabled (epsilon 0), steps the
    environment, records hour/reward/action/parameter values into the run's
    ``trajectory_data`` and posts feed messages (including alerts for
    parameters outside their warning/critical limits). It stops when the
    episode is done, after ``max_steps`` steps, when ``sim_state.running`` is
    cleared, or when a newer run has replaced ``sim_state.trajectory_data``.

    Args:
        scenario: scenario name, used in messages and threshold lookups.
        step_delay: seconds to wait before each step (pacing for the live UI).
        max_steps: maximum number of environment steps.
    """
    # Capture this run's objects: a later start_simulation() rebinds the
    # attributes on sim_state, and this thread must not drive the new run's
    # environment or write into its data.
    env = sim_state.env
    agent = sim_state.agent
    trajectory = sim_state.trajectory_data

    def is_current_run():
        # start_simulation() installs a fresh dict for every run, so identity
        # tells whether this thread still owns the visible run.
        return sim_state.trajectory_data is trajectory

    # Indices into env.big_state; currently the same for every scenario.
    param_names = ["Temperature", "VR", "pH", "pvO2", "Glucose", "Insulin"]
    param_indices = [0, 3, 4, 6, 9, 10]
    action_names = ["Temp", "Press", "FiO2", "Glucose", "Insulin", "Bicarb", "Vasodil", "Dial_In", "Dial_Out"]

    epsilon_saved = False
    original_epsilon = None
    # Defined up front so the summary works even if stopped before step one.
    hours_survived = 0
    total_reward = 0

    try:
        trajectory['param_names'] = param_names
        trajectory['param_indices'] = param_indices
        for name in param_names:
            trajectory['parameters'][name] = []

        state = env.reset()

        # Hour 0: parameter values before any action.
        initial_values = [env.big_state[idx] for idx in param_indices]
        trajectory['hours'].append(0)
        trajectory['rewards'].append(0)
        for name, value in zip(param_names, initial_values):
            trajectory['parameters'][name].append(value)

        _post_message('system', f'π **Initial {scenario} Parameters Recorded**')

        # Evaluate greedily; epsilon is restored in the finally clause.
        agent.policy_net.eval()
        original_epsilon = agent.epsilon
        epsilon_saved = True
        agent.epsilon = 0.0

        step_count = 0
        done = False
        while not done and step_count < max_steps and sim_state.running and is_current_run():
            time.sleep(step_delay)

            if not sim_state.running or not is_current_run():
                break

            action = agent.choose_action(state)
            action_decoded = env.decode_action(action)

            next_state, reward, done, info = env.step(action, train=False)

            step_count += 1
            total_reward += reward
            hours_survived = info.get("hours_survived", step_count)
            hour = int(hours_survived)

            sim_state.current_hour = hour
            sim_state.total_reward = total_reward

            # Append the hour and every parameter value back to back so the
            # plot (rendered from another thread) rarely sees lists of
            # different lengths.
            values = [env.big_state[idx] for idx in param_indices]
            trajectory['hours'].append(hour)
            trajectory['rewards'].append(total_reward)
            trajectory['actions'].append(action_decoded.copy())
            for name, value in zip(param_names, values):
                trajectory['parameters'][name].append(value)

            _post_message('system', f'β° **Hour {hour}** - DQN Agent Decision Made')

            # Non-zero decoded entries are reported: 1 as "increase", any
            # other non-zero value as "decrease".
            active_actions = [
                f"{action_name}: {'increase' if action_value == 1 else 'decrease'}"
                for action_name, action_value in zip(action_names, action_decoded)
                if action_value != 0
            ]
            if active_actions:
                _post_message('action', f'π― **DQN Actions**: {", ".join(active_actions)}')
            else:
                _post_message('action', 'π― **DQN Decision**: Maintain all parameters')

            # Report parameters outside their limits (the per-parameter status
            # used to be computed and then discarded).
            alerts = []
            for name, idx, value in zip(param_names, param_indices, values):
                limits = get_thresholds(scenario, idx)
                if not limits:
                    continue
                critical_low, warning_low, warning_high, critical_high = limits
                if value <= critical_low or value >= critical_high:
                    alerts.append(f"{name} = {value:.2f} β οΈ CRITICAL")
                elif value <= warning_low or value >= warning_high:
                    alerts.append(f"{name} = {value:.2f} β οΈ Warning")
            if alerts:
                _post_message('warning', f'**Parameter Alerts**: {", ".join(alerts)}')

            if reward != 0:
                _post_message('info', f'π° **Reward**: {reward:.1f} (Total: {total_reward:.1f})')

            state = next_state

            if done:
                if hours_survived >= 24:
                    _post_message('success', f'π **SUCCESS!** {scenario} perfusion completed! Survived {hours_survived:.1f} hours.')
                else:
                    _post_message('error', f'π **Early Termination** - {scenario} ended at {hours_survived:.1f} hours. Total reward: {total_reward:.1f}')
                break

        if not is_current_run():
            # Superseded by a newer run: don't write into its feed.
            return

        _post_message('system', f'π **Evaluation Complete** - Duration: {hours_survived:.1f}h | Reward: {total_reward:.1f} | Status: {"Success" if hours_survived >= 24 else "Early termination"}')

    except Exception as e:
        if is_current_run():
            _post_message('error', f'β **Simulation Error**: {str(e)}')
        else:
            print(f"Superseded simulation thread failed: {e}")

    finally:
        if epsilon_saved:
            agent.epsilon = original_epsilon
        # Only the worker that owns the visible run may clear the flag.
        if is_current_run():
            sim_state.running = False
|
|
def get_live_updates():
    """Timer callback: return the current (feed, figure, status line) triple."""
    feed = format_messages()
    figure = generate_trajectory_plot()

    run_label = 'Running' if sim_state.running else 'Stopped'
    status_line = (
        f"Scenario: {sim_state.trajectory_data.get('scenario', 'None')} | "
        f"Status: {run_label} | "
        f"Hour: {sim_state.current_hour} | "
        f"Reward: {sim_state.total_reward:.1f}"
    )
    return feed, figure, status_line
|
|
| |
# Static HTML fragments used by the layout below.
_HEADER_HTML = """
<div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; margin-bottom: 20px;">
<h1 style="color: white; margin: 0; font-size: 2rem;">π₯ Real Perfusion Monitoring System</h1>
<p style="color: rgba(255,255,255,0.9); margin: 10px 0 0 0; font-size: 1.1rem;">Live DQN Agent Evaluation with Real-Time Trajectory Plotting</p>
</div>
"""

_INFO_HTML = """
<div style="margin: 15px 0; padding: 10px; background: #f8f9fa; border-radius: 8px; font-size: 0.85rem; color: #666;">
<strong>Real DQN Integration:</strong><br>
β’ Uses trained DQN models<br>
β’ Shows actual perfusion parameters<br>
β’ Real AI decision making<br>
β’ Live 24-hour simulation
</div>
"""

with gr.Blocks(title="Real Perfusion Monitoring System", theme=gr.themes.Soft()) as iface:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left (wider) column: the live trajectory plot.
        with gr.Column(scale=2):
            plot_output = gr.Plot(
                value=generate_trajectory_plot(),
                label="π Real-Time Parameter Trajectories",
            )

        # Right column: controls, info box and the message feed.
        with gr.Column(scale=1):
            gr.HTML("<h3>βοΈ DQN Control Panel</h3>")

            status_display = gr.HTML("Status: Ready")

            scenario_input = gr.Dropdown(
                label="Perfusion Scenario",
                choices=["EYE", "VCA"],
                value="EYE",
            )

            with gr.Row():
                start_btn = gr.Button("π Start DQN Evaluation", variant="primary")
                stop_btn = gr.Button("βΉοΈ Stop", variant="secondary")

            gr.HTML(_INFO_HTML)

            gr.HTML("<h4>π¬ Live Monitoring Feed</h4>")
            message_output = gr.Markdown(
                value="π€ **Welcome!** Select a scenario and start evaluation to see real-time DQN performance.",
                label="Messages",
                height=300,
            )

    # Every callback refreshes the same three components, in this order.
    live_outputs = [message_output, plot_output, status_display]

    start_btn.click(fn=start_simulation, inputs=[scenario_input], outputs=live_outputs)
    stop_btn.click(fn=stop_simulation, outputs=live_outputs)

    # Poll the shared state every 3 seconds while the worker thread runs.
    timer = gr.Timer(3)
    timer.tick(fn=get_live_updates, outputs=live_outputs)
|
|
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, request a public share link and
    # surface errors in the browser UI.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )
    iface.launch(**launch_options)
|
|