File size: 806 Bytes
6800266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72da0f5
 
6800266
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

from typing import Dict, Any, List
from langchain_core.callbacks import BaseCallbackHandler
import schemas
import crud


class LogResponseCallback(BaseCallbackHandler):
    """Callback handler that stores LLM replies and prints outgoing prompts.

    On LLM completion the text of the first generation is wrapped in a
    ``schemas.MessageBase`` with ``type='AI'`` and passed to
    ``crud.add_message`` together with the stored ``db`` object and the
    requesting user's ``username``.
    """

    def __init__(self, user_request: schemas.UserRequest, db) -> None:
        """Remember the request and database handle used when logging.

        Args:
            user_request: Request whose ``username`` is attached to every
                stored message.
            db: Handle forwarded unchanged to ``crud.add_message``
                (presumably a database session -- confirm against ``crud``).
        """
        super().__init__()
        self.user_request = user_request
        self.db = db

    def on_llm_end(self, outputs: Any, **kwargs: Any) -> None:
        """Store the text of the first generation once the LLM finishes.

        Args:
            outputs: Result object exposing ``generations``, a nested
                sequence indexed as ``generations[prompt][candidate]`` whose
                items carry a ``text`` attribute. It is accessed by attribute,
                so it is not a plain dict.
            **kwargs: Extra callback metadata; ignored.

        Only the first candidate of the first prompt is stored. If there is
        no such candidate, nothing is stored.
        """
        generations = outputs.generations
        # Guard against an empty result instead of raising IndexError below.
        if not generations or not generations[0]:
            return

        llm_response = generations[0][0].text
        message = schemas.MessageBase(message=llm_response, type='AI')
        crud.add_message(self.db, message, self.user_request.username)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print every prompt that is about to be sent to the LLM.

        Args:
            serialized: Serialized description of the LLM; unused here.
            prompts: Prompt strings for this call; each is printed to stdout.
            **kwargs: Extra callback metadata; ignored.
        """
        for prompt in prompts:
            print(prompt)