File size: 8,121 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import os
import signal
import socket
import time
from contextlib import contextmanager
from threading import Thread

from internlm.core.context import global_context as gpc
from internlm.monitor.alert import send_feishu_msg_with_webhook
from internlm.utils.common import SingletonMeta

from .utils import get_job_key, set_env_var


def send_alert_message(address: str = None, title: str = None, message: str = None):
    """
    Forward an alert to a Feishu webhook, but only from the rank used for logging.

    Nothing is sent when no address is given or when the current rank is not the log rank.

    Args:
        address (str): Feishu webhook address that receives the alert, defaults to None.
        title (str): Title of the alert; when empty or None the job key is used instead.
        message (str): Body of the alert, defaults to None.
    """

    # No webhook configured: alerting is disabled.
    if address is None:
        return

    # Only the logging rank reports, so an alert is not sent once per rank.
    if not gpc.is_rank_for_log():
        return

    alert_title = title or get_job_key()
    send_feishu_msg_with_webhook(webhook=address, title=alert_title, message=message)


class MonitorTracker(Thread):
    """
    Track job status and alert to Feishu during job training.

    The tracker is a thread that starts itself on construction. On every cycle it reads the
    ``LAST_ACTIVE_TIMESTAMP``, ``LOSS`` and ``STEP_ID`` environment variables and sends an alert
    when the timestamp has not advanced, or when the loss grew by more than ``loss_spike_limit``
    times since the previous cycle.

    Args:
        alert_address (str): The Feishu webhook address for sending alerting messages.
        check_interval (float): The interval in seconds for monitoring checks. Defaults to 300.
        loss_spike_limit (float): The threshold for detecting loss value spikes. Defaults to 1.5.
    """

    def __init__(
        self,
        alert_address: str,
        check_interval: float = 300,
        loss_spike_limit: float = 1.5,
    ):
        super().__init__()
        self.alert_address = alert_address
        self.check_interval = check_interval
        self.loss_spike_limit = loss_spike_limit
        # -1 is the "nothing observed yet" sentinel for both tracked values.
        self.last_active_time = -1
        self.last_loss_value = -1
        self.stopped = False
        self.start()

    def run(self):
        """
        Run the periodic checks until ``stop`` is called.

        The checks are best effort: an exception raised by a check is swallowed so that one
        failing cycle does not end the monitor thread.
        """

        while not self.stopped:
            try:
                self._check_stuck()
                self._check_loss_spike()
            except Exception:  # deliberately broad, see docstring
                # Fall through to the sleep below. Jumping straight to the next iteration
                # would turn a persistent failure into a busy loop that never waits.
                pass
            time.sleep(self.check_interval)

    def _check_stuck(self):
        """
        Check training status for potential stuck condition.

        An alert is sent when ``LAST_ACTIVE_TIMESTAMP`` is set and is not greater than the value
        seen on the previous check.
        """

        raw_active_time = os.getenv("LAST_ACTIVE_TIMESTAMP")
        if raw_active_time is None:
            # Nothing has been reported: fall back to the sentinel and skip the comparison.
            self.last_active_time = -1
            return

        new_active_time = int(raw_active_time)
        if new_active_time <= int(self.last_active_time):
            self._send_alert("Training may be in stuck status, please check it.")
        self.last_active_time = new_active_time

    @staticmethod
    def _is_loss_spike(last_loss, new_loss, spike_limit) -> bool:
        """
        Tell whether ``new_loss`` is more than ``spike_limit`` times ``last_loss``.

        Args:
            last_loss: Previously observed loss (number or numeric string).
            new_loss: Currently observed loss (number or numeric string).
            spike_limit (float): Ratio above which the change counts as a spike.

        Returns:
            bool: True when the ratio exceeds ``spike_limit``; always False when the previous
                loss is not positive.

        Raises:
            ValueError: If either loss cannot be converted to float.
        """

        previous = float(last_loss)
        # A ratio is only meaningful against a positive previous loss. This also covers the
        # -1 "no value yet" sentinel and avoids dividing by zero.
        if previous <= 0:
            return False
        return float(new_loss) / previous > spike_limit

    def _check_loss_spike(self):
        """
        Check for loss value spikes.

        Only runs on the rank used for logging. Compares the ``LOSS`` environment variable with
        the value seen on the previous check and alerts when it spiked.
        """

        if not gpc.is_rank_for_log():
            return

        new_loss_value = os.getenv("LOSS")
        if new_loss_value is None:
            # No loss reported: reset to the sentinel so the next real value is not compared
            # against a stale one.
            self.last_loss_value = -1
            return

        # The step id is only used for the message, so a missing value must not block the alert.
        new_step_id = os.getenv("STEP_ID", "-1")

        if self._is_loss_spike(self.last_loss_value, new_loss_value, self.loss_spike_limit):
            self._send_alert(
                f"Checking periodically: Loss spike may be happened in step {new_step_id}, "
                f"loss value from {self.last_loss_value} to {new_loss_value}, please check it."
            )

        self.last_loss_value = new_loss_value

    def _send_alert(self, message):
        """
        Send alerting message to the Feishu webhook address.

        Args:
            message (str): The alerting message to be sent.
        """

        send_alert_message(
            address=self.alert_address,
            message=message,
        )

    def stop(self):
        """
        Stop the monitor tracker.

        Only sets a flag that ``run`` reads at the top of each cycle, so the thread may keep
        sleeping for up to ``check_interval`` seconds before it actually exits.
        """

        self.stopped = True


class MonitorManager(metaclass=SingletonMeta):
    """
    Monitor Manager for managing monitor thread and monitoring training status.

    Args:
        loss_spike_limit (float): Multiple of the previous step loss above which the current
            step loss is reported as a spike. Defaults to 1.5.
    """

    def __init__(self, loss_spike_limit: float = 1.5) -> None:
        self.monitor_thread = None
        self.loss_spike_limit = loss_spike_limit
        # -1 means that no step loss has been recorded yet.
        self.last_step_loss = -1

    def monitor_loss_spike(self, alert_address: str = None, step_count: int = 0, cur_step_loss: float = 0.0):
        """
        Check loss value, if loss spike occurs, send alert message to Feishu.

        The current loss and step are also published through the ``LOSS`` and ``STEP_ID``
        environment variables.

        Args:
            alert_address (str): The Feishu webhook address, defaults to None.
            step_count (int): The current training step, defaults to 0.
            cur_step_loss (float): The loss of the current step, defaults to 0.0.
        """
        set_env_var(key="LOSS", value=cur_step_loss)
        set_env_var(key="STEP_ID", value=step_count)

        if self.last_step_loss != -1 and cur_step_loss > self.loss_spike_limit * self.last_step_loss:
            send_alert_message(
                address=alert_address,
                message=(
                    f"Checking step by step: Loss spike may be happened in step {step_count}, "
                    f"loss value from {self.last_step_loss} to {cur_step_loss}, please check it."
                ),
            )
        self.last_step_loss = cur_step_loss

    def monitor_exception(self, alert_address: str = None, excp_info: str = None):
        """
        Catch and format exception information, send alert message to Feishu.

        Only the last 10 lines of ``excp_info`` are included in the alert.

        Args:
            alert_address (str): The Feishu webhook address, defaults to None.
            excp_info (str): The exception text (e.g. a formatted traceback), defaults to None.
        """
        # ``excp_info`` defaults to None; treat that as an empty trace instead of failing
        # on ``None.split`` while reporting an error.
        filtered_trace = (excp_info or "").split("\n")[-10:]
        format_trace = "".join(f"\n{line}" for line in filtered_trace)
        send_alert_message(
            address=alert_address,
            message=f"Catch Exception from {socket.gethostname()} with rank id {gpc.get_global_rank()}:{format_trace}",
        )

    def handle_sigterm(self, alert_address: str = None):
        """
        Catch SIGTERM signal, and send alert message to Feishu.

        Args:
            alert_address (str): The Feishu webhook address, defaults to None.

        Raises:
            ValueError: Raised by ``signal.signal`` when not called from the main thread.
        """

        def sigterm_handler(sys_signal, frame):
            print("receive frame: ", frame)
            print("receive signal: ", sys_signal)
            # Report the received signal number, not the ``signal`` module object.
            send_alert_message(
                address=alert_address,
                message=f"Process received signal {sys_signal} and exited.",
            )
            # NOTE(review): this handler replaces the default SIGTERM action and returns
            # normally, so it does not itself terminate the process - confirm that the
            # "and exited" wording matches what callers expect.

        signal.signal(signal.SIGTERM, sigterm_handler)

    def start_monitor(
        self,
        job_name: str,
        alert_address: str,
        monitor_interval_seconds: int = 300,
        loss_spike_limit: float = 1.5,
    ):
        """
        Initialize and start monitor thread for checking training job status, loss spike and so on.

        Args:
            job_name (str): The training job name.
            alert_address (str): The Feishu webhook address for sending alert messages.
            monitor_interval_seconds (int): The time of monitor interval in seconds, defaults to 300.
            loss_spike_limit (float): The limit multiple of current loss to previous loss value, which means loss spike
                may be occurs, defaults to 1.5.
        """

        # initialize some variables for monitoring
        set_env_var(key="JOB_NAME", value=job_name)

        # Ask a tracker from an earlier call to stop before its reference is replaced;
        # otherwise it could never be stopped again.
        self.stop_monitor()

        # start a monitor thread, periodically check the training status
        self.monitor_thread = MonitorTracker(
            alert_address=alert_address,
            check_interval=monitor_interval_seconds,
            loss_spike_limit=loss_spike_limit,
        )

    def stop_monitor(self):
        """Stop the monitor and alert thread, if one was started."""
        if self.monitor_thread is not None:
            self.monitor_thread.stop()


# Module-level manager instance used by the helpers in this file; MonitorManager is declared
# with the SingletonMeta metaclass.
monitor_manager = MonitorManager()


@contextmanager
def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
    """
    Context manager that wraps a training run with monitoring and Feishu alerts.

    With an alert address, the monitor thread and the SIGTERM handler are set up and a start
    message is sent on entry; on exit (also when the body raised) a completion message is sent
    and the monitor thread is stopped. Without an address the context manager does nothing.

    Args:
        job_name (str): The training job name.
        alert_address (str): The Feishu webhook address for sending alert messages.
    """

    # No webhook address: run the body without any monitoring.
    if alert_address is None:
        yield
        return

    try:
        monitor_manager.start_monitor(job_name=job_name, alert_address=alert_address)
        monitor_manager.handle_sigterm(alert_address=alert_address)
        start_notice = f"Training in {socket.gethostname()} is starting."
        send_alert_message(address=alert_address, message=start_notice)
        yield
    finally:
        end_notice = f"Training in {socket.gethostname()} completed."
        send_alert_message(address=alert_address, message=end_notice)
        monitor_manager.stop_monitor()