import streamlit as st

from src.ollama_chain import OllamaChain, OllamaRAGChain
from src.llama_cpp_chains import LlamaChain
from src.pdf_handler import extract_pdf
from src.vqa import answer_visual_question
from src.audio_processor import AudioProcessor
from langchain_community.chat_message_histories import StreamlitChatMessageHistory

from dotenv import load_dotenv
import os

# Load environment variables from a local .env file (e.g. model/API settings).
load_dotenv()

# Single shared processor used for both speech-to-text and text-to-speech below.
audio_processor = AudioProcessor()

@st.cache_resource
def load_chain(_chat_memory):
    """Build (and cache) the LLM chain for this session.

    Returns an ``OllamaRAGChain`` when PDF-chat mode is on, otherwise a plain
    ``OllamaChain``.  The leading underscore on ``_chat_memory`` tells
    Streamlit not to hash that argument for the cache key.

    NOTE(review): the cache key also ignores ``st.session_state.pdf_chat``,
    so toggling the mode alone would NOT rebuild the chain — the app relies
    on explicit ``clear_cache()`` calls in the widget callbacks to force a
    rebuild.  Confirm every mode-changing path calls ``clear_cache()``.
    """
    if st.session_state.pdf_chat:
        return OllamaRAGChain(_chat_memory)
    else:
        return OllamaChain(_chat_memory)


def file_uploader_change():
    """Callback for the PDF uploader: keep pdf_chat mode in sync with the file list."""
    state = st.session_state

    if not state.uploaded_file:
        # Every file was removed: drop the cached chain and leave PDF mode.
        clear_cache()
        state.pdf_chat = False
        return

    if not state.pdf_chat:
        # First upload while in plain-chat mode: switch over to PDF mode.
        clear_cache()
        state.pdf_chat = True

    # Flag that the knowledge base must be re-indexed from the new files.
    state.knowledge_change = True


def toggle_pdf_chat_change():
    """Callback for the PDF-chat toggle: force a chain rebuild on mode change."""
    clear_cache()
    # Re-index only when switching INTO PDF mode with files already uploaded.
    needs_reindex = st.session_state.pdf_chat and st.session_state.uploaded_file
    if needs_reindex:
        st.session_state.knowledge_change = True


def clear_input_field():
    """Move the typed message into user_question so the input widget can be emptied."""
    state = st.session_state
    state.user_question = state.user_input
    state.user_input = ""


def set_send_input():
    """on_change callback for the text input: flag a pending send and stash the text."""
    st.session_state['send_input'] = True
    clear_input_field()


def clear_cache():
    """Drop every @st.cache_resource entry so load_chain rebuilds on next call."""
    st.cache_resource.clear()


def initial_session_state():
    """Seed the session-state flags driving the send / knowledge-update workflow."""
    for flag in ('send_input', 'knowledge_change'):
        st.session_state[flag] = False


def _save_upload(uploaded_file):
    """Persist a Streamlit UploadedFile into the temp cache dir; return its path.

    Creates the directory on first use so writes don't fail on a fresh checkout
    (the original code assumed ./.cache/temp_files already existed).
    """
    temp_dir = './.cache/temp_files'
    os.makedirs(temp_dir, exist_ok=True)
    path = os.path.join(temp_dir, uploaded_file.name)
    with open(path, 'wb') as f:
        f.write(uploaded_file.getvalue())
    return path


def main():
    """Render the OVERKILL chat UI and handle one Streamlit rerun cycle."""
    # Title and main chat area
    st.title('OVERKILL LLM')
    chat_container = st.container()

    # Sidebar: PDF-chat toggle and uploaders (PDF / image / audio)
    st.sidebar.toggle('PDF Chat', value=False, key='pdf_chat', on_change=toggle_pdf_chat_change)
    uploaded_pdf = st.sidebar.file_uploader('Upload your pdf files',
                                            type='pdf',
                                            accept_multiple_files=True,
                                            key='uploaded_file',
                                            on_change=file_uploader_change)

    uploaded_image = st.sidebar.file_uploader('Upload Images', type=['jpg', 'jpeg', 'png'], key='uploaded_image')

    uploaded_audio = st.sidebar.file_uploader('Upload Audio', type=['wav', 'mp3'], key='uploaded_audio')

    # Input widgets; the text_input's value is consumed via set_send_input,
    # so its return value is deliberately not bound.
    st.text_input('Message OVERKILL', key='user_input', on_change=set_send_input)
    send_button = st.button('Send', key='send_button')

    # Initialize session-state flags on the very first run only.
    if 'send_input' not in st.session_state:
        initial_session_state()
    # ----------------------------------------------------------------------------------------------------

    chat_history = StreamlitChatMessageHistory(key='history')

    # Replay the conversation so far.
    with chat_container:
        for msg in chat_history.messages:
            st.chat_message(msg.type).write(msg.content)

    llm_chain = load_chain(chat_history)
    if st.session_state.knowledge_change:
        with st.spinner('Updating knowledge base'):
            llm_chain.update_chain(uploaded_pdf)
            st.session_state.knowledge_change = False

    # "or" because the user can press Enter (on_change) instead of the Send button.
    # .get() guards the first run, where user_question may not exist yet.
    if (send_button or st.session_state.send_input) and st.session_state.get('user_question', "") != "":
        with chat_container:
            st.chat_message('user').write(st.session_state.user_question)

            if uploaded_image:
                image_path = _save_upload(uploaded_image)
                llm_response = answer_visual_question(image_path, st.session_state.user_question)
            elif uploaded_audio:
                audio_path = _save_upload(uploaded_audio)
                st.write(f"Processing audio file: {audio_path}")  # Debug statement
                # The transcription replaces whatever text was typed.
                st.session_state.user_question = audio_processor.audio_to_text(audio_path)
                st.write(f"Converted audio to text: {st.session_state.user_question}")  # Debug statement
                llm_response = llm_chain.run(user_input=st.session_state.user_question)
            else:
                llm_response = llm_chain.run(user_input=st.session_state.user_question)

            st.session_state.user_question = ""
            st.chat_message('ai').write(llm_response)

            # Convert the response to speech and play it back.
            audio_file = audio_processor.text_to_speech(llm_response)
            # Close the handle promptly instead of leaking it (original used a bare open()).
            with open(audio_file, 'rb') as f:
                audio_bytes = f.read()
            st.audio(audio_bytes, format='audio/mp3')


# Script entry point: run the Streamlit app.
if __name__ == '__main__':
    main()