# -*- coding: utf-8 -*-
"""
pluto/bus.py — Lightweight in-memory message bus for agent communication.

Agents post messages tagged with a message type; other agents read them back
filtered by sender (role) and/or message type.
This is the communication backbone for multi-agent coordination.
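
The coordination flow can be sketched in a self-contained way. This is a
minimal illustration, not part of this module: `Msg` and `MiniBus` are
hypothetical stand-ins that mirror `MessageBus.subscribe`, `post`, `read` and
`latest` but skip `jsonable_encoder`, so the sketch runs without FastAPI. The
agent names and payload fields are illustrative only.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Msg:
    # Hypothetical stand-in for Message (timestamp omitted for brevity).
    sender: str
    msg_type: str
    payload: dict[str, Any]


class MiniBus:
    # Hypothetical stand-in for MessageBus, without jsonable_encoder.
    def __init__(self) -> None:
        self._messages: list[Msg] = []
        self._listeners: list[Callable[[str, str, dict], None]] = []

    def subscribe(self, cb: Callable[[str, str, dict], None]) -> None:
        self._listeners.append(cb)

    def post(self, sender: str, msg_type: str, payload: dict[str, Any]) -> None:
        self._messages.append(Msg(sender, msg_type, payload))
        for cb in self._listeners:
            try:
                cb(sender, msg_type, payload)
            except Exception:
                pass  # one faulty listener must not block the others

    def read(self, msg_type: str | None = None, sender: str | None = None) -> list[Msg]:
        # Filters combine with AND; None means "no filter on this field".
        return [
            m
            for m in self._messages
            if (msg_type is None or m.msg_type == msg_type)
            and (sender is None or m.sender == sender)
        ]

    def latest(self, msg_type: str) -> Msg | None:
        matches = self.read(msg_type=msg_type)
        return matches[-1] if matches else None


def broken_listener(sender: str, msg_type: str, payload: dict) -> None:
    raise RuntimeError("listener bug")  # swallowed by MiniBus.post


# One pipeline run: a broken listener and a recording listener are attached.
bus = MiniBus()
seen: list[tuple[str, str]] = []
bus.subscribe(broken_listener)
bus.subscribe(lambda s, t, p: seen.append((s, t)))

bus.post("planner", "chunk_plan", {"chunks": ["C1", "C2", "C3"]})
bus.post("worker-C3", "fact_tuple", {"subject": "x", "value": 1})
bus.post("planner", "chunk_plan", {"chunks": ["C1", "C3"]})

print(len(bus.read(msg_type="chunk_plan")))                # 2
print(bus.latest("chunk_plan").payload["chunks"])          # ['C1', 'C3']
print([m.msg_type for m in bus.read(sender="worker-C3")])  # ['fact_tuple']
print(seen[0])                                             # ('planner', 'chunk_plan')

# A fresh bus for the next run starts empty: no cross-run state.
print(MiniBus().read())                                    # []
```

Because the log is append-only, `latest` only has to take the last match, and
creating one bus per run is what keeps runs from seeing each other's messages.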
"""

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
import time

from fastapi.encoders import jsonable_encoder


@dataclass
class Message:
    sender: str        # e.g. "planner", "strategist", "worker-C3", "critic"
    msg_type: str      # e.g. "chunk_plan", "audit", "fact_tuple", "challenge", "retraction"
    payload: dict[str, Any]
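    # Monotonic time.perf_counter() value: good for ordering messages and
    # measuring intervals within one process, not a wall-clock (epoch) time.
    timestamp: float = field(default_factory=time.perf_counter)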


class MessageBus:
    """
    Append-only message log shared across all agents in one pipeline run.
    Each run gets a fresh bus instance — no cross-run state.
    """

    def __init__(self) -> None:
        self._messages: list[Message] = []
        self._listeners: list[Callable[[str, str, dict], None]] = []

    def subscribe(self, callback: Callable[[str, str, dict], None]) -> None:
        """Register a callback(sender, msg_type, payload) for new messages."""
        self._listeners.append(callback)

    def post(self, sender: str, msg_type: str, payload: dict[str, Any]) -> None:
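        """Post a message to the bus and notify every subscriber.

        The payload is passed through jsonable_encoder first, so what is stored
        and what listeners receive contain only JSON-compatible types.
        """
        safe_payload = jsonable_encoder(payload)
        msg = Message(sender=sender, msg_type=msg_type, payload=safe_payload)
        self._messages.append(msg)
        for cb in self._listeners:
            try:
                cb(sender, msg_type, safe_payload)
            except Exception:
                # A failing listener must not block delivery to the others.
                # Catching Exception (not a bare except) still lets
                # KeyboardInterrupt and SystemExit propagate.
                pass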

    def read(self, msg_type: str | None = None, sender: str | None = None) -> list[Message]:
        """Read messages in posting order, optionally filtered by type and/or sender.

        Returns a new list, so callers cannot mutate the bus's internal log.
        """
        return [
            m
            for m in self._messages
            if (msg_type is None or m.msg_type == msg_type)
            and (sender is None or m.sender == sender)
        ]

    def latest(self, msg_type: str) -> Message | None:
        """Return the most recent message of the given type, or None if there is none."""
        matches = self.read(msg_type=msg_type)
        return matches[-1] if matches else None

    def dump(self) -> list[dict]:
        """Export all messages as plain dicts for tracing/debugging.

        "t" is the message's perf_counter timestamp rounded to 0.1 ms.
        """
        return [
            {
                "sender": m.sender,
                "type": m.msg_type,
                "payload": m.payload,
                "t": round(m.timestamp, 4),
            }
            for m in self._messages
        ]