AskIDF / chat.py
Jesus Sanchez
fixed sql chain
54cda27
raw
history blame
1.5 kB
from langchain.llms.base import get_prompts
from sqlalchemy import label
import streamlit as st
from typing import Callable
# Keys under which the chat transcript is kept in ``st.session_state``.
PROMPT_LABEL = "chat_prompt"      # list of submitted prompts
RESPONSE_LABEL = "chat_response"  # list of (response, on_render) tuples
class Chat:
    """Chat widget whose transcript is stored in ``st.session_state``.

    Prompts are kept under ``PROMPT_LABEL`` and ``(response, on_render)``
    tuples under ``RESPONSE_LABEL``. The two lists are paired by index, so
    they must always have the same length.
    """

    def __init__(self):
        # Create each transcript list only if it is not already present.
        for key in (RESPONSE_LABEL, PROMPT_LABEL):
            if key not in st.session_state:
                st.session_state[key] = []

    @staticmethod
    def _render_exchange(prompt, response, on_render):
        """Draw one user prompt followed by the assistant's rendered response."""
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            on_render(response)

    def process(self, process_prompt: Callable, *args):
        """Render the stored transcript, then handle a newly submitted prompt.

        Args:
            process_prompt: Callback invoked as ``process_prompt(prompt, *args)``.
                It must return a ``(response, on_render)`` tuple, where
                ``on_render(response)`` draws the response inside the
                assistant chat message.
            *args: Extra positional arguments forwarded to ``process_prompt``.

        Raises:
            Whatever ``process_prompt`` raises. In that case nothing is added
            to the stored transcript.
        """
        # Render history, most recent exchange first.
        # NOTE(review): history is drawn newest-first, but a freshly submitted
        # exchange is drawn below it, so it moves to the top on the next run.
        # Confirm this ordering is intended.
        history = list(
            zip(st.session_state[PROMPT_LABEL], st.session_state[RESPONSE_LABEL])
        )
        for past_prompt, (past_response, past_render) in reversed(history):
            self._render_exchange(past_prompt, past_response, past_render)

        # Compute prompt
        if prompt := st.chat_input("Ask IDF Anything"):
            # Run the callback BEFORE storing anything: if it raises, neither
            # list grows, so prompts and responses stay index-aligned.
            response, on_render = process_prompt(prompt, *args)
            st.session_state[PROMPT_LABEL].append(prompt)
            st.session_state[RESPONSE_LABEL].append((response, on_render))
            self._render_exchange(prompt, response, on_render)