File size: 1,504 Bytes
25e33cc
1d53447
e9f9de0
 
 
 
 
 
1d53447
e9f9de0
 
1d53447
e9f9de0
 
 
 
 
 
 
54cda27
e9f9de0
 
 
 
 
 
 
 
25e33cc
e9f9de0
e395a4e
e9f9de0
 
 
 
54cda27
e9f9de0
 
 
 
 
e395a4e
e9f9de0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from langchain.llms.base import get_prompts
from sqlalchemy import label
import streamlit as st
from typing import Callable



# Keys into ``st.session_state`` used by ``Chat``.
# The response list holds ``(response, on_render)`` tuples.
RESPONSE_LABEL = "chat_response"
# The prompt list holds the raw prompt strings entered by the user.
PROMPT_LABEL = "chat_prompt"

class Chat:
    """Chat widget whose conversation is kept in ``st.session_state``.

    Prompts and responses are stored in two parallel lists, under
    ``PROMPT_LABEL`` and ``RESPONSE_LABEL``. The entries at the same index
    of both lists belong to the same turn; each response entry is a
    ``(response, on_render)`` tuple.
    """

    def __init__(self):
        """Create the two history lists in the session state if missing.

        Lists that already exist are left untouched, so previously stored
        turns are preserved.
        """
        if RESPONSE_LABEL not in st.session_state:
            st.session_state[RESPONSE_LABEL] = []

        if PROMPT_LABEL not in st.session_state:
            st.session_state[PROMPT_LABEL] = []

    @staticmethod
    def _render_turn(prompt, response, on_render) -> None:
        """Draw one turn: the user's prompt, then the assistant's response."""
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            on_render(response)

    def process(self, process_prompt: Callable, *args):
        """Render the stored conversation and handle a newly entered prompt.

        Args:
            process_prompt: Callback invoked as
                ``process_prompt(prompt, *args)``. It receives the prompt
                text and must return a ``(response, on_render)`` tuple,
                where ``on_render(response)`` draws the response.
            *args: Extra positional arguments forwarded to
                ``process_prompt``.

        Raises:
            Whatever ``process_prompt`` raises. In that case the stored
            history is left unchanged.
        """
        prompts = st.session_state[PROMPT_LABEL]
        responses = st.session_state[RESPONSE_LABEL]

        # Render history, most recent turn first.
        history = list(zip(prompts, responses))
        for past_prompt, (past_response, past_render) in reversed(history):
            self._render_turn(past_prompt, past_response, past_render)

        # Compute prompt
        prompt = st.chat_input("Ask IDF Anything")
        if not prompt:
            return

        # Call the callback BEFORE touching the history. If it raises,
        # neither list is modified, so the two parallel lists can never
        # get out of step (which would pair later prompts with the wrong
        # responses when the history is zipped above).
        response, on_render = process_prompt(prompt, *args)
        prompts.append(prompt)
        responses.append((response, on_render))

        self._render_turn(prompt, response, on_render)