File size: 5,191 Bytes
0913c52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Workflow Monitor for Real-time Progress Tracking

This module provides callback hooks for monitoring workflow progress in real time.
It also exposes a process-wide monitor via ``get_monitor()`` and ``reset_monitor()``.
"""

import logging
import queue
import time
from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import Any, Callable


class PhaseType(Enum):
    """Workflow phase types.

    Members are grouped by workflow stage through their name prefix
    (``IDEATION_``, ``DATA_``, ``EXPERIMENT_``). Each value is the lowercase
    form of the member name.
    """

    # Ideation stage
    IDEATION_LITERATURE_SEARCH = "ideation_literature_search"
    IDEATION_ANALYZE_PAPERS = "ideation_analyze_papers"
    IDEATION_GENERATE_IDEAS = "ideation_generate_ideas"
    IDEATION_NOVELTY_CHECK = "ideation_novelty_check"
    IDEATION_REPORT = "ideation_report"

    # Data stage
    DATA_PLANNING = "data_planning"
    DATA_EXECUTION = "data_execution"
    DATA_PAPER_SEARCH = "data_paper_search"
    DATA_FINALIZE = "data_finalize"

    # Experiment stage
    EXPERIMENT_INIT = "experiment_init"
    EXPERIMENT_CODING = "experiment_coding"
    EXPERIMENT_EXEC = "experiment_exec"
    EXPERIMENT_SUMMARY = "experiment_summary"
    EXPERIMENT_ANALYSIS = "experiment_analysis"
    EXPERIMENT_REVISION = "experiment_revision"

    # Outcome markers (presumably terminal states — nothing here enforces that)
    COMPLETE = "complete"
    ERROR = "error"


@dataclass
class ProgressUpdate:
    """A single progress update.

    Instances are built by ``WorkflowMonitor.log_update``. Field order defines
    the positional constructor, so any new field should be appended at the
    end with a default value.
    """

    timestamp: float  # Wall-clock time; log_update fills this with time.time()
    phase: PhaseType  # Workflow phase this update belongs to
    status: str  # "started", "progress", "completed", "error"
    message: str  # Human-readable description of the update
    data: dict[str, Any] | None = None  # Free-form payload; log_update substitutes {} when none is given
    agent_name: str | None = None  # Name of the agent/subagent that generated this
    message_type: str = "status"  # "status", "thought", "action", "result", "error"
    node_name: str | None = None  # Name of the node that generated this
    intermediate_output: dict[str, Any] | None = None  # Node's intermediate output/state


class WorkflowMonitor:
    """Monitor workflow progress with real-time updates.

    Every logged update is appended to an in-memory history (``updates``),
    put on ``update_queue`` for consumers that poll, and passed to every
    registered callback. The history and the callback list are guarded by
    ``lock``.
    """

    def __init__(self):
        # Full history of updates, in the order they were recorded.
        self.updates: list[ProgressUpdate] = []
        # Unbounded FIFO receiving each update as it is recorded.
        self.update_queue: queue.Queue = queue.Queue()
        # Guards ``updates``, ``callbacks`` and the append/put pair below.
        # NOTE: ``Lock`` is not reentrant; never call user code while holding it.
        self.lock = Lock()
        self.callbacks: list[Callable[[ProgressUpdate], None]] = []

    def add_callback(self, callback: Callable[[ProgressUpdate], None]):
        """Register *callback* to be called with every subsequent update."""
        with self.lock:
            self.callbacks.append(callback)

    def log_update(
        self,
        phase: PhaseType,
        status: str,
        message: str,
        data: dict[str, Any] | None = None,
        agent_name: str | None = None,
        message_type: str = "status",
        node_name: str | None = None,
        intermediate_output: dict[str, Any] | None = None,
    ):
        """Record a progress update and notify all registered callbacks.

        The update is appended to ``updates`` and put on ``update_queue``
        atomically. It is then passed to each registered callback. Callbacks
        run outside the lock, so they may safely read from or log to this
        monitor. With concurrent loggers, callbacks may therefore observe
        updates in a different order than ``updates`` stores them.

        Exceptions raised by a callback are logged with a traceback and
        suppressed. One faulty listener cannot break the workflow or starve
        the remaining callbacks.

        Args:
            phase: Workflow phase the update belongs to.
            status: Lifecycle status, e.g. "started", "progress",
                "completed" or "error".
            message: Human-readable description.
            data: Optional payload; ``{}`` is stored when falsy.
            agent_name: Name of the agent/subagent producing the update.
            message_type: e.g. "status", "thought", "action", "result",
                "error".
            node_name: Name of the node producing the update.
            intermediate_output: The node's intermediate output/state.
        """
        update = ProgressUpdate(
            timestamp=time.time(),
            phase=phase,
            status=status,
            message=message,
            data=data or {},
            agent_name=agent_name,
            message_type=message_type,
            node_name=node_name,
            intermediate_output=intermediate_output,
        )

        with self.lock:
            self.updates.append(update)
            self.update_queue.put(update)
            # Snapshot so the callbacks can be invoked without the lock held.
            callbacks = list(self.callbacks)

        # Invoking callbacks under the (non-reentrant) lock would deadlock any
        # callback that calls get_updates()/log_update() on this monitor.
        for callback in callbacks:
            try:
                callback(update)
            except Exception as e:
                # Deliberate best effort: report with traceback, keep going.
                logging.getLogger(__name__).exception("Error in callback: %s", e)

    def log_node_update(
        self,
        phase: PhaseType,
        node_name: str,
        status: str,
        message: str,
        intermediate_output: dict[str, Any] | None = None,
        agent_name: str | None = None,
        message_type: str = "status",
        data: dict[str, Any] | None = None,
    ):
        """Log a node-level progress update with intermediate output.

        This is a thin wrapper over :meth:`log_update` that makes
        ``node_name`` required. ``data`` is optional and forwarded unchanged.
        """
        self.log_update(
            phase=phase,
            status=status,
            message=message,
            data=data,
            agent_name=agent_name,
            message_type=message_type,
            node_name=node_name,
            intermediate_output=intermediate_output,
        )

    def get_updates(self) -> list[ProgressUpdate]:
        """Return a snapshot copy of every update recorded so far."""
        with self.lock:
            return self.updates.copy()

    def get_latest_updates(self, count: int = 10) -> list[ProgressUpdate]:
        """Return (a copy of) the most recent *count* updates, oldest first.

        A non-positive *count* yields an empty list. Without this guard,
        ``updates[-0:]`` would return the entire history.
        """
        if count <= 0:
            return []
        with self.lock:
            return self.updates[-count:]

    def get_updates_by_phase(self, phase: PhaseType) -> list[ProgressUpdate]:
        """Return all recorded updates whose ``phase`` equals *phase*."""
        with self.lock:
            return [u for u in self.updates if u.phase == phase]

    def clear(self):
        """Discard the update history and drain any unconsumed queue items."""
        with self.lock:
            self.updates.clear()
            # get_nowait() raises queue.Empty once drained; this avoids the
            # racy empty()-then-get pattern.
            while True:
                try:
                    self.update_queue.get_nowait()
                except queue.Empty:
                    break


# Global monitor instance, created lazily by get_monitor().
_global_monitor: WorkflowMonitor | None = None
# Serializes creation/replacement of the global monitor. Without it, two
# threads calling get_monitor() concurrently for the first time could each
# build their own WorkflowMonitor, and updates logged to one would be lost.
_global_monitor_lock = Lock()


def get_monitor() -> WorkflowMonitor:
    """Return the global workflow monitor, creating it on first use."""
    global _global_monitor
    with _global_monitor_lock:
        if _global_monitor is None:
            _global_monitor = WorkflowMonitor()
        return _global_monitor


def reset_monitor():
    """Clear the current global monitor and replace it with a fresh one.

    Callbacks registered on the previous monitor are not carried over. Code
    that still holds a reference to the previous monitor sees it emptied.
    """
    global _global_monitor
    with _global_monitor_lock:
        if _global_monitor is not None:
            _global_monitor.clear()
        _global_monitor = WorkflowMonitor()