|
|
from typing import Any |
|
|
|
|
|
import socketio |
|
|
from loguru import logger |
|
|
|
|
|
from langflow.services.base import Service |
|
|
from langflow.services.cache.base import AsyncBaseCacheService, CacheService |
|
|
from langflow.services.deps import get_chat_service |
|
|
from langflow.services.socket.utils import build_vertex, get_vertices |
|
|
|
|
|
|
|
|
class SocketIOService(Service):
    """Service that binds Socket.IO event handlers to a server and proxies cache access.

    The cache backend is supplied at construction time; the Socket.IO server
    is attached later through :meth:`init`, which also registers the handlers
    defined on this class.
    """

    name = "socket_service"

    def __init__(self, cache_service: CacheService | AsyncBaseCacheService):
        """Store the cache backend (sync or async) used by get_cache/set_cache."""
        self.cache_service = cache_service

    def init(self, sio: socketio.AsyncServer) -> None:
        """Attach ``sio`` and register this service's event handlers on it.

        Handlers are only registered when ``sio`` is truthy. The per-connection
        ``sessions`` mapping is (re)created on every call.
        """
        self.sio = sio
        if self.sio:
            # ``event`` registers a handler under the handler's own name.
            for implicit_handler in (self.connect, self.disconnect):
                self.sio.event(implicit_handler)

            # ``on(name)`` registers a handler under an explicit event name.
            explicit_handlers = (
                ("message", self.message),
                ("get_vertices", self.on_get_vertices),
                ("build_vertex", self.on_build_vertex),
            )
            for event_name, handler in explicit_handlers:
                self.sio.on(event_name)(handler)

        # Maps a socket id to the environ received when it connected.
        self.sessions: dict[str, Any] = {}

    async def emit_error(self, sid, error) -> None:
        """Emit ``error`` on the "error" event to the client ``sid``."""
        await self.sio.emit("error", data=error, to=sid)

    async def connect(self, sid, environ) -> None:
        """Log the new connection and remember its ``environ`` under ``sid``."""
        logger.info(f"Socket connected: {sid}")
        self.sessions[sid] = environ

    async def disconnect(self, sid) -> None:
        """Log the disconnection and forget the session for ``sid``, if any."""
        logger.info(f"Socket disconnected: {sid}")
        self.sessions.pop(sid, None)

    async def message(self, sid, data=None) -> None:
        """Send ``data`` back to ``sid``; a placeholder payload is used when it is falsy."""
        payload = data or {"foo": "bar", "baz": [1, 2, 3]}
        await self.emit_message(to=sid, data=payload)

    async def emit_message(self, to, data) -> None:
        """Emit ``data`` on the "message" event to the recipient ``to``."""
        await self.sio.emit("message", data=data, to=to)

    async def emit_token(self, to, data) -> None:
        """Emit ``data`` on the "token" event to the recipient ``to``."""
        await self.sio.emit("token", data=data, to=to)

    async def on_get_vertices(self, sid, flow_id) -> None:
        """Handle "get_vertices" by delegating to the ``get_vertices`` helper."""
        chat_service = get_chat_service()
        await get_vertices(self.sio, sid, flow_id, chat_service)

    async def on_build_vertex(self, sid, flow_id, vertex_id) -> None:
        """Handle "build_vertex" by delegating to the ``build_vertex`` helper.

        The helper receives this service's cache accessors so it can read and
        store results through them.
        """
        helper_kwargs = {
            "sio": self.sio,
            "sid": sid,
            "flow_id": flow_id,
            "vertex_id": vertex_id,
            "get_cache": self.get_cache,
            "set_cache": self.set_cache,
        }
        await build_vertex(**helper_kwargs)

    async def get_cache(self, sid: str) -> Any:
        """Return the cache entry stored under ``sid``.

        With an async cache backend the value returned by ``get`` is awaited
        before being handed back; with a sync backend it is returned as-is.
        """
        cached = self.cache_service.get(sid)
        uses_async_backend = isinstance(self.cache_service, AsyncBaseCacheService)
        return (await cached) if uses_async_backend else cached

    async def set_cache(self, sid: str, build_result: Any) -> bool:
        """Upsert ``build_result`` (with its type) under ``sid`` and report presence.

        Returns:
            Whether ``sid`` is present in the cache after the upsert.
        """
        entry = {
            "result": build_result,
            "type": type(build_result),
        }
        pending = self.cache_service.upsert(sid, entry)

        # Sync backend: the upsert has already happened; use ``in`` directly.
        if not isinstance(self.cache_service, AsyncBaseCacheService):
            return sid in self.cache_service

        # Async backend: both the upsert and the membership check are awaitable.
        await pending
        return await self.cache_service.contains(sid)
|
|
|