File size: 4,149 Bytes
cbdbe9c
4313e45
f0d663b
b8b0675
 
34328de
b8b0675
 
 
 
 
 
 
079959b
b8b0675
18e8d0d
b8b0675
 
a0d509e
f0220fd
c465bdd
f85115e
 
 
 
 
 
 
 
 
764a615
c465bdd
b8b0675
 
 
 
 
 
4313e45
a0d509e
b8b0675
a0d509e
b8b0675
 
 
4313e45
b8b0675
4313e45
b8b0675
 
 
 
 
 
 
4313e45
b8b0675
 
 
4313e45
b8b0675
 
 
 
 
4313e45
b8b0675
 
4313e45
b8b0675
 
 
 
4313e45
b8b0675
 
 
 
 
4313e45
b8b0675
 
 
 
 
 
 
 
 
 
 
 
 
 
4313e45
b8b0675
 
 
4313e45
b8b0675
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import streamlit as st
from streamlit_chat import message
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.llms.huggingface_hub import HuggingFaceHub
from typing import Dict
import json
from io import StringIO
from random import randint
from transformers import AutoTokenizer

# Page configuration has to be the first Streamlit command the script issues.
st.set_page_config(page_icon=":robot:", page_title="Document Analysis")
st.header("Chat with your bot (Model: Falcon-7B-Instruct)")

# Hub id of the hosted model. The same id loads the matching tokenizer, which
# this script uses only to count the tokens held in conversation memory.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

class CustomHuggingFaceEndpoint(HuggingFaceHub):
    """HuggingFaceHub LLM with explicit request-building and reply-parsing hooks.

    NOTE(review): ``transform_input``/``transform_output`` mirror the hook names
    of LangChain's SageMaker ``ContentHandler``; nothing in this file calls them
    and it is not confirmed that ``HuggingFaceHub`` invokes them — verify.
    """

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize *prompt* into a JSON text-generation request body.

        Args:
            prompt: Full prompt text to send to the model.
            model_kwargs: Extra generation parameters. They override the
                defaults below; ``None`` is treated as "no overrides".

        Returns:
            UTF-8 encoded JSON of the form ``{"inputs": ..., "parameters": ...}``.
        """
        parameters = {
            "max_new_tokens": 100,
            # Stop before the model starts writing the next user turn.
            "stop": ["Human:"],
            "do_sample": False,
            "repetition_penalty": 1.1,
            # Ask for the completion only. The reply can then be parsed without
            # knowing the prompt length, which this pydantic-based model has no
            # declared field to remember between the two hooks.
            "return_full_text": False,
        }
        # Caller-supplied kwargs were previously ignored; let them take effect.
        parameters.update(model_kwargs or {})
        payload = {"inputs": prompt, "parameters": parameters}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the model's reply from a JSON text-generation response.

        Args:
            output: Raw response body, expected to be a JSON list whose first
                element carries a ``generated_text`` string.

        Returns:
            The generated text, cut at the last "Human" marker (if any) and
            stripped of surrounding whitespace.

        Raises:
            ValueError: if the response is an error object or lacks generated text.
        """
        res = json.loads(output.decode("utf-8"))
        if isinstance(res, dict) and "error" in res:
            raise ValueError(f"Text generation endpoint returned an error: {res['error']}")
        if not isinstance(res, list) or not res or "generated_text" not in res[0]:
            raise ValueError(f"Unexpected text generation response: {res!r}")
        ans = res[0]["generated_text"]
        # Drop a trailing partial "Human:" turn. Only truncate when the marker
        # exists: rfind() returns -1 otherwise, and ans[:-1] would silently chop
        # the last character of a valid reply.
        cut = ans.rfind("Human")
        if cut != -1:
            ans = ans[:cut]
        return ans.strip()


def load_chain():
    """Build a conversation chain around the hosted model.

    Returns:
        A ``ConversationChain`` whose LLM targets ``model_name`` and whose
        memory is a fresh, empty ``ConversationBufferMemory``.
    """
    return ConversationChain(
        llm=CustomHuggingFaceEndpoint(repo_id=model_name),
        memory=ConversationBufferMemory(),
    )

# Streamlit re-executes this whole script on every interaction. A chain built
# unconditionally here would start with empty memory on each rerun, so the bot
# would forget earlier turns and any uploaded file. Keep one chain per browser
# session in session_state instead.
if 'chatchain' not in st.session_state:
    st.session_state['chatchain'] = load_chain()
chatchain = st.session_state['chatchain']

# Parallel lists of bot replies ('generated') and user messages ('past').
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
    chatchain.memory.clear()
# Key for the file uploader; changing it forces a new, empty uploader widget.
if 'widget_key' not in st.session_state:
    st.session_state['widget_key'] = str(randint(1000, 100000000))

# Sidebar: "Clear Conversation" empties the displayed history, rotates the
# uploader's widget key (so a new, empty uploader is created below) and
# flushes the conversation memory of the chain.
st.sidebar.title("Sidebar")
clear_button = st.sidebar.button("Clear Conversation", key="clear")

if clear_button:
    st.session_state['past'] = []
    st.session_state['generated'] = []
    st.session_state['widget_key'] = str(randint(1000, 100000000))
    chatchain.memory.clear()

# Plain-text upload; its key follows widget_key so clearing resets it.
uploaded_file = st.sidebar.file_uploader(
    "Upload a txt file",
    type=["txt"],
    key=st.session_state['widget_key'],
)

# Container that displays the past conversation (filled in at the end of the script).
response_container = st.container()
# Container holding the input form.
container = st.container()

with container:
    # Input box; clear_on_submit empties it after each send.
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        # Typed message: send it to the chain and record the exchange.
        output = chatchain(user_input)["response"]
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
    elif uploaded_file is not None:
        raw = uploaded_file.getvalue()
        # The uploader keeps returning the same file on every rerun. Without
        # this marker, e.g. an empty "Send" would push the whole file to the
        # model again and duplicate the chat entry. widget_key is part of the
        # marker so re-uploading the same file after "Clear Conversation" is
        # ingested again.
        upload_marker = (st.session_state['widget_key'], uploaded_file.name, len(raw))
        if st.session_state.get('ingested_upload') != upload_marker:
            # errors="replace" keeps a non-UTF-8 .txt file from crashing the app.
            file_text = raw.decode("utf-8", errors="replace").strip()
            content = (
                "=== BEGIN FILE ===\n"
                + file_text
                + "\n=== END FILE ===\nPlease confirm that you have read that file by saying 'Yes, I have read the file'"
            )
            output = chatchain(content)["response"]
            st.session_state['past'].append("I have uploaded a file. Please confirm that you have read that file.")
            st.session_state['generated'].append(output)
            # Mark as ingested only after the chain call succeeded, so a failed
            # call is retried on the next rerun.
            st.session_state['ingested_upload'] = upload_marker

    # Show how many tokenizer tokens the conversation memory currently holds.
    history = chatchain.memory.load_memory_variables({})["history"]
    tokens = tokenizer.tokenize(history)
    st.write(f"Number of tokens in memory: {len(tokens)}")

# Render the stored conversation inside the history container: for each turn,
# the user's message first, then the bot's reply.
if st.session_state['generated']:
    with response_container:
        for idx, reply in enumerate(st.session_state["generated"]):
            message(st.session_state["past"][idx], is_user=True, key=f"{idx}_user")
            message(reply, key=f"{idx}")