import torch
import streamlit as st
from streamlit_chat import message
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.llms.huggingface_hub import HuggingFaceHub
from typing import Dict
import json
from io import StringIO
from random import randint
from transformers import AutoTokenizer
# Streamlit page setup — set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Document Analysis", page_icon=":robot:")
st.header("Chat with your bot (Model: Falcon-7B-Instruct)")
# The tokenizer is loaded locally but is only used further down to count the
# tokens held in conversation memory; generation itself goes through the
# HuggingFaceHub endpoint wrapper — see CustomHuggingFaceEndpoint below.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
class CustomHuggingFaceEndpoint(HuggingFaceHub):
    """HuggingFaceHub wrapper speaking the text-generation-inference JSON
    protocol, which also strips the echoed prompt from the model's reply."""

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize *prompt* into the JSON request body for the endpoint.

        Also records the prompt length on the instance so that
        transform_output can slice the echoed prompt off the generation.
        """
        # BUG FIX: the original bound len(prompt) to a local variable
        # (`len_prompt = len(prompt)`), so transform_output's read of
        # self.len_prompt raised AttributeError on every response.
        self.len_prompt = len(prompt)
        input_str = json.dumps({
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 100,
                # Stop as soon as the model starts hallucinating the next
                # "Human:" turn of the conversation.
                "stop": ["Human:"],
                "do_sample": False,
                "repetition_penalty": 1.1
            }
        })
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Decode the endpoint's response and return only the new text.

        The endpoint echoes the full prompt in `generated_text`; drop the
        first self.len_prompt characters, then trim any trailing partial
        "Human" turn that slipped past the stop sequence.
        """
        response_json = output.decode('utf-8')
        res = json.loads(response_json)
        ans = res[0]['generated_text'][self.len_prompt:]
        ans = ans[:ans.rfind("Human")].strip()
        return ans
def load_chain():
    """Build a ConversationChain backed by the custom Falcon endpoint with a
    fresh ConversationBufferMemory.

    Returns:
        ConversationChain: chain whose `.memory` accumulates the dialogue.
    """
    llm = CustomHuggingFaceEndpoint(repo_id=model_name)
    memory = ConversationBufferMemory()
    chain = ConversationChain(llm=llm, memory=memory)
    return chain

# BUG FIX: Streamlit re-executes this script on every interaction, so calling
# load_chain() unconditionally rebuilt the chain — silently wiping the
# conversation memory each rerun. Keep one chain per browser session instead;
# the clear-button handler below still resets it explicitly.
if 'chatchain' not in st.session_state:
    st.session_state['chatchain'] = load_chain()
chatchain = st.session_state['chatchain']
# First run for this session: create the two transcript lists (bot replies in
# 'generated', user turns in 'past') and a random key for the file uploader.
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    # A missing 'past' means a brand-new session, so start the LLM memory clean.
    chatchain.memory.clear()
    st.session_state['past'] = []
if 'widget_key' not in st.session_state:
    st.session_state['widget_key'] = str(randint(1000, 100000000))
# --- Sidebar: title, a clear-conversation button, and a txt file uploader ---
st.sidebar.title("Sidebar")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

if clear_button:
    # Forget everything: both transcript lists, the chain's memory, and the
    # uploader widget (a fresh random key forces Streamlit to recreate it empty).
    st.session_state['past'] = []
    st.session_state['generated'] = []
    chatchain.memory.clear()
    st.session_state['widget_key'] = str(randint(1000, 100000000))

uploaded_file = st.sidebar.file_uploader("Upload a txt file", type=["txt"], key=st.session_state['widget_key'])

# Past conversation renders into response_container; the input form gets its
# own container below it.
response_container = st.container()
container = st.container()
with container:
    # Input form: a text area plus a Send button; clear_on_submit empties the
    # text area after each message.
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        # Typed message: run it through the chain and record both sides of
        # the exchange for the render loop below.
        response = chatchain(user_input)["response"]
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(response)
    elif uploaded_file is not None:
        # Uploaded file: wrap its text in delimiters, feed it to the chain,
        # and ask the model to acknowledge having read it.
        file_text = StringIO(uploaded_file.getvalue().decode("utf-8")).read().strip()
        content = (
            "=== BEGIN FILE ===\n"
            + file_text
            + "\n=== END FILE ===\nPlease confirm that you have read that file by saying 'Yes, I have read the file'"
        )
        response = chatchain(content)["response"]
        st.session_state['past'].append("I have uploaded a file. Please confirm that you have read that file.")
        st.session_state['generated'].append(response)

    # Report how many tokens the conversation memory currently occupies.
    history = chatchain.memory.load_memory_variables({})["history"]
    token_count = len(tokenizer.tokenize(history))
    st.write(f"Number of tokens in memory: {token_count}")
# Render the accumulated conversation, oldest first: each exchange is the
# user's turn followed by the bot's reply.
if st.session_state['generated']:
    with response_container:
        for idx, bot_msg in enumerate(st.session_state['generated']):
            message(st.session_state["past"][idx], is_user=True, key=str(idx) + '_user')
            message(bot_msg, key=str(idx))