| """Callback handlers used in the app.""" |
| from typing import Any, Dict, List |
|
|
| from langchain.callbacks.base import AsyncCallbackHandler |
|
|
| from models import ChatResponse |
|
|
|
|
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses.

    Forwards each newly generated token to the client over a websocket as a
    ``ChatResponse`` with ``sender="bot"`` and ``type="stream"``.
    """

    def __init__(self, websocket: Any) -> None:
        """Store the connection that tokens are pushed to.

        Args:
            websocket: Connection used to send messages. It must provide an
                awaitable ``send_json(dict)`` method (presumably a
                Starlette/FastAPI ``WebSocket`` -- confirm with the caller).
        """
        # Run the base-class initializer so cooperative inheritance stays
        # intact if AsyncCallbackHandler (or a mixin) gains init-time state.
        super().__init__()
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Send one generated token to the client.

        Args:
            token: Token text produced by the LLM.
            **kwargs: Extra callback arguments supplied by LangChain; unused.
        """
        resp = ChatResponse(sender="bot", message=token, type="stream")
        # NOTE(review): ``.dict()`` is the pydantic v1 serializer; if
        # ChatResponse is a pydantic v2 model, prefer ``model_dump()``.
        await self.websocket.send_json(resp.dict())
|
|
|
class QuestionGenCallbackHandler(AsyncCallbackHandler):
    """Callback handler for question generation.

    When the question-generating LLM starts, notifies the client over a
    websocket with an informational ``ChatResponse`` (``type="info"``).
    """

    def __init__(self, websocket: Any) -> None:
        """Store the connection that status messages are pushed to.

        Args:
            websocket: Connection used to send messages. It must provide an
                awaitable ``send_json(dict)`` method (presumably a
                Starlette/FastAPI ``WebSocket`` -- confirm with the caller).
        """
        # Run the base-class initializer so cooperative inheritance stays
        # intact if AsyncCallbackHandler (or a mixin) gains init-time state.
        super().__init__()
        self.websocket = websocket

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running.

        Sends a fixed "Synthesizing question..." status message to the client.

        Args:
            serialized: Serialized description of the LLM; unused here.
            prompts: Prompts the LLM was started with; unused here.
            **kwargs: Extra callback arguments supplied by LangChain; unused.
        """
        resp = ChatResponse(
            sender="bot", message="Synthesizing question...", type="info"
        )
        # NOTE(review): ``.dict()`` is the pydantic v1 serializer; if
        # ChatResponse is a pydantic v2 model, prefer ``model_dump()``.
        await self.websocket.send_json(resp.dict())
|