Spaces:
Runtime error
Runtime error
Commit
·
593f0ea
0
Parent(s):
Initial commit for Hugging Face Space
Browse files- .env +1 -0
- app.py +148 -0
- config.yaml +7 -0
- packages.txt +2 -0
- requirements.txt +24 -0
- src/__pycache__/audio_processor.cpython-311.pyc +0 -0
- src/__pycache__/llama_cpp_chains.cpython-311.pyc +0 -0
- src/__pycache__/ollama_chain.cpython-311.pyc +0 -0
- src/__pycache__/pdf_handler.cpython-311.pyc +0 -0
- src/__pycache__/utils.cpython-311.pyc +0 -0
- src/__pycache__/vectorstore.cpython-311.pyc +0 -0
- src/__pycache__/vqa.cpython-311.pyc +0 -0
- src/audio_processor.py +38 -0
- src/llama_cpp_chains.py +38 -0
- src/ollama_chain.py +134 -0
- src/pdf_handler.py +49 -0
- src/utils.py +9 -0
- src/vectorstore.py +90 -0
- src/vqa.py +34 -0
.env
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
PINECONE_API_KEY = "<REDACTED-PINECONE-API-KEY>"  # SECURITY: a live Pinecone API key was committed to this public repo — rotate the key immediately and supply it via Space secrets / environment variables, not a tracked .env file
|
app.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
from src.ollama_chain import OllamaChain, OllamaRAGChain
|
| 4 |
+
from src.llama_cpp_chains import LlamaChain
|
| 5 |
+
from src.pdf_handler import extract_pdf
|
| 6 |
+
from src.vqa import answer_visual_question
|
| 7 |
+
from src.audio_processor import AudioProcessor
|
| 8 |
+
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
audio_processor = AudioProcessor()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@st.cache_resource
def load_chain(_chat_memory):
    """Build (and cache) the active chain: RAG when PDF chat is on, plain chat otherwise.

    The leading underscore tells Streamlit not to hash the message history;
    mode switches rely on clear_cache() evicting the cached instance, since
    st.session_state.pdf_chat is not part of the cache key.
    """
    chain_cls = OllamaRAGChain if st.session_state.pdf_chat else OllamaChain
    return chain_cls(_chat_memory)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def file_uploader_change():
    # Callback fired when the PDF uploader's contents change.
    # NOTE(review): the nesting below follows the nearest-if parse of the
    # scraped source; the knowledge_change flag being set when pdf_chat is
    # OFF looks inverted relative to toggle_pdf_chat_change — confirm against
    # the original file.
    if st.session_state.uploaded_file:
        if not st.session_state.pdf_chat:
            clear_cache()
            st.session_state.knowledge_change = True
        else:
            clear_cache()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def toggle_pdf_chat_change():
    """Invalidate the cached chain when PDF-chat mode flips; if files are
    already uploaded, flag the knowledge base for re-indexing."""
    clear_cache()
    if not (st.session_state.pdf_chat and st.session_state.uploaded_file):
        return
    st.session_state.knowledge_change = True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def clear_input_field():
    """Stash the typed text into user_question (the text_input widget key
    cannot be read after the rerun, so it is copied here first)."""
    st.session_state['user_question'] = st.session_state['user_input']
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def set_send_input():
    """Mark that a message is ready to send, then stash the typed text."""
    st.session_state['send_input'] = True
    clear_input_field()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def clear_cache():
    # Drop every @st.cache_resource entry so the LLM chain is rebuilt on the
    # next load_chain() call (load_chain's cache key ignores pdf_chat).
    st.cache_resource.clear()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def initial_session_state():
    """Seed the session-state flags used by the chat flow with their defaults."""
    defaults = {
        'send_input': False,
        'knowledge_change': False,
        'user_question': "",
    }
    for key, value in defaults.items():
        st.session_state[key] = value
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main():
    """Streamlit entry point: render the chat UI and route text/PDF/image/audio input."""
    st.title('OVERKILL LLM')
    os.makedirs('./.cache/temp_files', exist_ok=True)  # Ensure temp folder exists
    chat_container = st.container()

    # Sidebar
    st.sidebar.toggle('PDF Chat', value=False, key='pdf_chat', on_change=toggle_pdf_chat_change)
    uploaded_pdf = st.sidebar.file_uploader(
        'Upload your pdf files',
        type='pdf',
        accept_multiple_files=True,
        key='uploaded_file',
        on_change=file_uploader_change
    )

    uploaded_image = st.sidebar.file_uploader('Upload Images', type=['jpg', 'jpeg', 'png'], key='uploaded_image')
    st.sidebar.file_uploader('Upload Audio', type=['wav', 'mp3'], key='uploaded_audio')
    uploaded_audio = st.session_state.get('uploaded_audio')

    # Optional reset
    if st.sidebar.button("🔄 Reset Chat"):
        st.session_state.clear()
        # Fix: st.experimental_rerun() was deprecated/removed in modern Streamlit
        # (requirements pin streamlit>=1.32); st.rerun() is the supported API.
        st.rerun()

    # Input objects
    user_input = st.text_input('Message OVERKILL', key='user_input', on_change=set_send_input)
    send_button = st.button('Send', key='send_button')

    # Initial session setup (only on the first run of a session)
    if 'send_input' not in st.session_state or 'user_question' not in st.session_state:
        initial_session_state()

    chat_history = StreamlitChatMessageHistory(key='history')

    # Replay the conversation so far.
    with chat_container:
        for msg in chat_history.messages:
            st.chat_message(msg.type).write(msg.content)

    try:
        llm_chain = load_chain(chat_history)
    except Exception as e:
        st.error(f"Error loading LLM chain: {e}")
        return

    # Re-index uploaded PDFs when a callback flagged a knowledge-base change.
    if st.session_state.knowledge_change:
        with st.spinner('Updating knowledge base'):
            try:
                llm_chain.update_chain(uploaded_pdf)
                st.session_state.knowledge_change = False
            except Exception as e:
                st.error(f"Error updating knowledge base: {e}")
                return

    if (send_button or st.session_state.send_input) and st.session_state.user_question != "":
        with chat_container:
            st.chat_message('user').write(st.session_state.user_question)

            try:
                if uploaded_image:
                    # Visual question answering: persist the image, then run BLIP VQA.
                    image_path = os.path.join('./.cache/temp_files', uploaded_image.name)
                    with open(image_path, 'wb') as f:
                        f.write(uploaded_image.getvalue())
                    llm_response = answer_visual_question(image_path, st.session_state.user_question)

                elif uploaded_audio:
                    # Audio path: transcribe the uploaded clip, feed the text to the LLM.
                    audio_path = os.path.join('./.cache/temp_files', uploaded_audio.name)
                    with open(audio_path, 'wb') as f:
                        f.write(uploaded_audio.getvalue())
                    st.write(f"Processing audio file: {audio_path}")
                    question = audio_processor.audio_to_text(audio_path)
                    st.write(f"Converted audio to text: {question}")
                    llm_response = llm_chain.run(user_input=question)

                else:
                    llm_response = llm_chain.run(user_input=st.session_state.user_question)

                st.session_state.user_question = ""
                st.chat_message('ai').write(llm_response)

                # Speak the reply.
                audio_file = audio_processor.text_to_speech(llm_response)
                # Fix: read via a context manager instead of leaking the file handle.
                with open(audio_file, 'rb') as audio_fh:
                    audio_bytes = audio_fh.read()
                st.audio(audio_bytes, format='audio/mp3')

            except Exception as e:
                st.error(f"Error during chat: {e}")
            finally:
                # Fix: always drop the send flag so a stale flag cannot
                # re-trigger a send on the next rerun.
                st.session_state.send_input = False
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# Script entry point: run the Streamlit app when executed directly.
if __name__ == '__main__':
    main()
|
config.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
chat_model:
|
| 2 |
+
'model': "llama3:latest"
|
| 3 |
+
'temperature': 0.75
|
| 4 |
+
'num_gpu': 1
|
| 5 |
+
vector_database:
|
| 6 |
+
chroma:
|
| 7 |
+
chat_session_path: './chat_session/'
|
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
poppler-utils
|
| 2 |
+
ffmpeg
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
torch
|
| 3 |
+
transformers
|
| 4 |
+
torchaudio
|
| 5 |
+
langchain
|
| 6 |
+
langchain-community
|
| 7 |
+
langchain-pinecone
|
| 8 |
+
langchain-chroma
|
| 9 |
+
pinecone-client
|
| 10 |
+
sentence-transformers
|
| 11 |
+
pypdf
|
| 12 |
+
PyMuPDF
|
| 13 |
+
pdf2image
|
| 14 |
+
pillow
|
| 15 |
+
opencv-python
|
| 16 |
+
ffmpeg-python
|
| 17 |
+
gtts
|
| 18 |
+
pydub
|
| 19 |
+
speechrecognition
|
| 20 |
+
streamlit>=1.32
|
| 21 |
+
google-generativeai
|
| 22 |
+
requests
|
| 23 |
+
python-dotenv
|
| 24 |
+
PyYAML
|
src/__pycache__/audio_processor.cpython-311.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
src/__pycache__/llama_cpp_chains.cpython-311.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
src/__pycache__/ollama_chain.cpython-311.pyc
ADDED
|
Binary file (7.36 kB). View file
|
|
|
src/__pycache__/pdf_handler.cpython-311.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
src/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (626 Bytes). View file
|
|
|
src/__pycache__/vectorstore.cpython-311.pyc
ADDED
|
Binary file (4.94 kB). View file
|
|
|
src/__pycache__/vqa.cpython-311.pyc
ADDED
|
Binary file (2.93 kB). View file
|
|
|
src/audio_processor.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import speech_recognition as sr
|
| 2 |
+
from gtts import gTTS
|
| 3 |
+
import tempfile
|
| 4 |
+
from pydub import AudioSegment
|
| 5 |
+
|
| 6 |
+
class AudioProcessor:
    """Speech utilities: transcribe uploaded audio and synthesize spoken replies."""

    def __init__(self):
        # One shared recognizer instance for all transcriptions.
        self.recognizer = sr.Recognizer()

    def audio_to_text(self, audio_file):
        """Process an uploaded audio file and convert it to text.

        Returns the recognized text, or a human-readable error string on
        failure (callers treat the result as display text either way).
        """
        import os  # local import keeps the module's public surface unchanged
        wav_path = None
        try:
            # Convert audio file to WAV format (the Google recognizer needs PCM WAV).
            audio = AudioSegment.from_file(audio_file)
            wav_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            # Fix: close the handle before export/AudioFile reopen the path
            # (required on Windows, and avoids a descriptor leak per call).
            wav_file.close()
            wav_path = wav_file.name
            audio.export(wav_path, format="wav")
            print(f"Converted audio to WAV: {wav_path}")  # Debug statement

            with sr.AudioFile(wav_path) as source:
                audio = self.recognizer.record(source)
            try:
                text = self.recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                text = "Could not understand audio"
            except sr.RequestError:
                text = "Could not request results"
            print(f"Recognized text: {text}")  # Debug statement
            return text
        except Exception as e:
            print(f"Error processing audio file: {e}")  # Debug statement
            return f"Error processing audio file: {e}"
        finally:
            # Fix: remove the temporary WAV instead of leaking one file per call.
            if wav_path and os.path.exists(wav_path):
                os.remove(wav_path)

    def text_to_speech(self, text):
        """Convert text to speech using gTTS and save as a .mp3 file.

        Returns the path of the generated file; the caller is responsible
        for deleting it once played.
        """
        tts = gTTS(text)
        temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        # Fix: gTTS.save reopens the path by name; writing while the
        # NamedTemporaryFile is still open fails on Windows.
        temp_audio.close()
        tts.save(temp_audio.name)
        return temp_audio.name
|
src/llama_cpp_chains.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.llms import LlamaCpp
|
| 2 |
+
from langchain.prompts import PromptTemplate
|
| 3 |
+
from langchain.memory import ConversationBufferWindowMemory
|
| 4 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 5 |
+
from langchain_core.runnables import RunnableSequence
|
| 6 |
+
|
| 7 |
+
from src.utils import load_config
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LlamaChain:
    """Conversational chain backed by a local llama.cpp model.

    Mirrors OllamaChain: the last k=3 exchanges are kept in a windowed
    buffer memory and rendered into a Llama-3 style prompt.
    """

    def __init__(self, chat_memory) -> None:
        # Llama-3 chat template; {chat_history} and {input} are filled per call.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are a helpful and knowledgeable AI assistant.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Previous conversation={chat_history}
Question: {input}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input']
        )

        # Windowed memory: only the 3 most recent turns enter the prompt.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True
        )

        config = load_config()
        # NOTE(review): config['chat_model'] carries Ollama-oriented keys
        # (model/temperature/num_gpu); confirm LlamaCpp accepts them.
        llm = LlamaCpp(**config['chat_model'])

        # Fix: the original composed prompt | llm | self.memory | StrOutputParser(),
        # but a ConversationBufferWindowMemory is not a Runnable, so invoking the
        # chain raised at runtime. Memory is now read/written explicitly in run().
        self.llm_chain = prompt | llm | StrOutputParser()

    def run(self, user_input):
        """Generate a reply to ``user_input``, threading the chat history through the prompt."""
        history = self.memory.load_memory_variables({})['chat_history']
        response = self.llm_chain.invoke({'chat_history': history, 'input': user_input})
        # Persist the exchange so subsequent turns see it in {chat_history}.
        self.memory.save_context({'input': user_input}, {'output': response})
        # Fix: StrOutputParser yields a plain string; the original's
        # response['text'] indexing would have raised TypeError.
        return response
|
src/ollama_chain.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.llms import Ollama
|
| 2 |
+
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
|
| 3 |
+
from langchain.memory import ConversationBufferWindowMemory
|
| 4 |
+
from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
|
| 5 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 6 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 7 |
+
from langchain_core.runnables.history import RunnableWithMessageHistory
|
| 8 |
+
from langchain.schema import Document
|
| 9 |
+
|
| 10 |
+
from src.utils import load_config
|
| 11 |
+
from src.vectorstore import VectorDB
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def format_docs(docs: list[Document]):
    """Join document bodies into one blank-line-separated context string."""
    contents = (doc.page_content for doc in docs)
    return '\n\n'.join(contents)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OllamaChain:
    """Plain (non-RAG) conversational chain backed by a local Ollama model.

    Keeps the last k=3 exchanges in a windowed buffer memory and renders
    them into a Llama-3 style prompt.
    """

    def __init__(self, chat_memory) -> None:
        # Llama-3 chat template; {chat_history} and {input} are filled per call.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are a honest and unbiased AI assistant
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Previous conversation={chat_history}
Question: {input}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input']
        )

        # Windowed memory: only the 3 most recent turns are injected into the prompt.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True
        )

        # Model name/params come from config.yaml's chat_model section.
        config = load_config()
        llm = Ollama(**config['chat_model'])
        # llm = Ollama(model='llama3:latest', temperature=0.75, num_gpu=1)

        self.llm_chain = LLMChain(prompt=prompt, llm=llm, memory=self.memory, output_parser=StrOutputParser())
        # runnable = prompt | llm

    def run(self, user_input):
        # LLMChain coerces a bare string into its single input key; the
        # result dict carries the generated answer under 'text'.
        response = self.llm_chain.invoke(user_input)

        return response['text']
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class OllamaRAGChain:
    """History-aware retrieval-augmented chat chain over an Ollama model.

    Pipeline: (1) rewrite the latest question into a standalone query using
    the chat history, (2) retrieve context from the vector store, (3) answer
    with a stuff-documents QA chain, all wrapped in message-history plumbing.
    """

    def __init__(self, chat_memory, uploaded_file=None):
        # initialize vector db using config
        from src.utils import load_config  # NOTE(review): shadows the module-level import; harmless but redundant
        config = load_config()
        vector_db_config = config.get('vector_database', {})
        # Pinecone wins when configured; otherwise fall back to local Chroma.
        db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        index_name = 'default'
        self.vector_db = VectorDB(db_name, index_name)
        if uploaded_file:
            self.update_knowledge_base(uploaded_file)

        # initialize llm
        config = load_config()
        self.llm = Ollama(**config['chat_model'])

        # initialize memory (the raw Streamlit message history, not a buffer wrapper)
        self.chat_memory = chat_memory

        # initialize sub chain with history message: rewrites follow-up
        # questions into standalone queries before retrieval.
        contextual_q_system_prompt = """Given a chat history and the latest user question which might refer to context \
in the chat history. Check if the user's question refers to the chat history or not. If does, formulate a \
standalone question which is incorporated from the latest question and history and can be understood without \
the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""

        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        # Retriever that first condenses (history + question) into one query.
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )

        # initialize qa chain: answers strictly from the retrieved {context}.
        qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved\
 context to answer the question. If you don't know the answer, just say that you don't know.
Context: {context}"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)

        rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)

        # Single-session history wrapper: every session_id maps to the same
        # Streamlit-backed message store.
        self.conversation_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            lambda session_id: chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer'
        )

    def run(self, user_input):
        # session_id is required by RunnableWithMessageHistory but unused
        # (see the constant-history lambda above).
        config = {"configurable": {"session_id": "any"}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)

        return response['answer']

    def update_chain(self, uploaded_pdf):
        # Re-index the PDFs, then rebuild the retriever + RAG pipeline so
        # subsequent queries see the new documents.
        self.update_knowledge_base(uploaded_pdf)
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        self.conversation_rag_chain = RunnableWithMessageHistory(
            create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain),
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer'
        )

    def update_knowledge_base(self, uploaded_pdf):
        # Delegate extraction/splitting/indexing to the vector DB facade.
        self.vector_db.index(uploaded_pdf)
|
src/pdf_handler.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from langchain_community.document_loaders import PyPDFLoader, PyPDFDirectoryLoader
|
| 4 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 5 |
+
from langchain.schema.document import Document
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_cache_dir(directory=None):
    """Ensure a cache directory exists and return its path.

    Args:
        directory: Target directory; defaults to './.cache' when falsy.

    Returns:
        The directory path, guaranteed to exist on disk.
    """
    if not directory:
        directory = './.cache'

    # Fix: the original always created './.cache' regardless of the argument,
    # so a caller-supplied directory was returned without being created.
    os.makedirs(directory, exist_ok=True)
    return directory
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_pdf(file_path):
    """Load a single PDF into LangChain documents (one per page)."""
    return PyPDFLoader(file_path).load()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_pdf_directory(directory):
    """Load every PDF found in *directory* into LangChain documents."""
    return PyPDFDirectoryLoader(directory).load()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def split_pdf(pdfs: list[Document]):
    """Chunk loaded PDF documents into 512-character pieces with 64-character overlap."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=64,
        length_function=len,
        is_separator_regex=False,
    )
    return text_splitter.split_documents(pdfs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def extract_pdf(uploaded_pdf):
    """Persist uploaded PDF file(s) into the temp cache directory.

    Accepts a single uploaded file or a list of them; returns the directory
    the files were written into.
    """
    target_dir = os.path.join(create_cache_dir(), 'temp_files')
    os.makedirs(target_dir, exist_ok=True)
    # Support both single file and list of files
    files = uploaded_pdf if isinstance(uploaded_pdf, list) else [uploaded_pdf]
    for uploaded in files:
        destination = os.path.join(target_dir, uploaded.name)
        with open(destination, 'wb') as handle:
            handle.write(uploaded.getvalue())
    return target_dir
|
src/utils.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import yaml
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def load_config():
    """Load application settings from config.yaml in the working directory."""
    with open('./config.yaml', 'r') as config_file:
        return yaml.safe_load(config_file)
|
src/vectorstore.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pinecone import Pinecone, ServerlessSpec, PodSpec
|
| 2 |
+
from langchain_pinecone import PineconeVectorStore
|
| 3 |
+
from langchain_chroma import Chroma
|
| 4 |
+
from langchain_community.embeddings import OllamaEmbeddings
|
| 5 |
+
from langchain.indexes import SQLRecordManager, index
|
| 6 |
+
|
| 7 |
+
from src.pdf_handler import extract_pdf, load_pdf_directory, split_pdf
|
| 8 |
+
from src.utils import load_config
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def setup_pinecone(index_name, embedding_model, embedding_dim, metric='cosine', use_serverless=True):
    """Create a fresh Pinecone index and wrap it in a LangChain vector store.

    WARNING: destructive — if an index named ``index_name`` already exists it
    is deleted and recreated, discarding all previously indexed vectors.
    Requires PINECONE_API_KEY in the environment.
    """
    pc = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
    if use_serverless:
        spec = ServerlessSpec(cloud='aws', region='us-east-1')
    else:
        # NOTE(review): PodSpec() with no arguments — confirm the default
        # environment/pod type is what the deployment expects.
        spec = PodSpec()

    # Drop-and-recreate: every call starts from an empty index.
    if index_name in pc.list_indexes().names():
        pc.delete_index(index_name)

    pc.create_index(
        index_name,
        dimension=embedding_dim,
        metric=metric,
        spec=spec
    )

    db = PineconeVectorStore(index_name=index_name, embedding=embedding_model)
    return db
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def setup_chroma(index_name, embedding_model, persist_directory=None):
    """Open (or create) a persistent local Chroma collection named *index_name*."""
    target_dir = persist_directory or './.cache/database'
    os.makedirs(target_dir, exist_ok=True)
    return Chroma(index_name, embedding_function=embedding_model, persist_directory=target_dir)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class VectorDB:
    """Facade over a Pinecone or Chroma vector store plus a LangChain
    SQLRecordManager used for incremental, de-duplicated indexing."""

    def __init__(self, db_name=None, index_name=None, cache_dir=None):
        config = load_config()
        vector_db_config = config.get('vector_database', {})
        # Determine DB type from config, fallback to argument or chroma
        if db_name is None:
            db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        if index_name is None:
            index_name = 'default'
        # Embeddings served by local Ollama; nomic-embed-text emits 768-dim
        # vectors, matching the hard-coded dimension passed to setup_pinecone.
        embedding = OllamaEmbeddings(model='nomic-embed-text:latest', num_gpu=1)
        if not cache_dir:
            cache_dir = './.cache/database'
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        if db_name == 'pinecone':
            if not os.environ.get('PINECONE_API_KEY'):
                raise ValueError("PINECONE_API_KEY environment variable is not set. Please set it in your .env file or environment.")
            # NOTE(review): setup_pinecone drops and recreates the index on
            # every call — constructing a VectorDB wipes existing vectors.
            self.vectorstore = setup_pinecone(index_name, embedding, 768, 'cosine')
        else:
            self.vectorstore = setup_chroma(index_name, embedding, self.cache_dir)
        # The record manager tracks document hashes in SQLite so re-indexing
        # the same content is idempotent.
        namespace = f'{db_name}/{index_name}'
        self.record_manager = SQLRecordManager(namespace,
                                               db_url=f'sqlite:///{self.cache_dir}/record_manager_cache.sql')
        self.record_manager.create_schema()

    def index(self, uploaded_file):
        """Extract, split and (re)index uploaded PDFs into the vector store."""
        directory = extract_pdf(uploaded_file)
        docs = load_pdf_directory(directory)
        chunks = split_pdf(docs)

        # Module-level langchain.indexes.index (not this method) does the
        # work; cleanup='full' prunes vectors whose source files disappeared.
        index(
            docs_source=chunks,
            record_manager=self.record_manager,
            vector_store=self.vectorstore,
            cleanup='full',
            source_id_key='source'
        )

        # Remove the extracted temp PDFs once they are indexed.
        for file in os.listdir(directory):
            os.remove(os.path.join(directory, file))

    def as_retriever(self):
        # Default retriever over the active vector store.
        return self.vectorstore.as_retriever()
|
src/vqa.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 3 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import BlipProcessor, BlipForQuestionAnswering, Wav2Vec2Processor, Wav2Vec2ForCTC
|
| 8 |
+
import os
|
| 9 |
+
import torchaudio
|
| 10 |
+
|
| 11 |
+
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
| 12 |
+
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
| 13 |
+
|
| 14 |
+
audio_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
| 15 |
+
audio_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
| 16 |
+
|
| 17 |
+
def answer_visual_question(image_path_or_url: str, question: str) -> str:
    """Answer *question* about an image (local path or URL) with the BLIP VQA model."""
    if os.path.isfile(image_path_or_url):
        source_image = Image.open(image_path_or_url)
    else:
        source_image = Image.open(requests.get(image_path_or_url, stream=True).raw)
    raw_image = source_image.convert('RGB')

    encoded = processor(raw_image, question, return_tensors="pt")
    generated = model.generate(**encoded)
    return processor.decode(generated[0], skip_special_tokens=True)
|
| 26 |
+
|
| 27 |
+
def transcribe_audio(audio_path: str) -> str:
    """Transcribe speech from *audio_path* with the wav2vec2 CTC model.

    Returns the decoded transcription string for the first (only) item
    in the batch.
    """
    # Fix: torch was referenced (no_grad, argmax) but never imported at
    # module level, so every call raised NameError.
    import torch

    waveform, sample_rate = torchaudio.load(audio_path)
    # NOTE(review): wav2vec2-base-960h was trained on 16 kHz audio — inputs
    # at other sample rates may need resampling first; confirm upstream.
    inputs = audio_processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = audio_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = audio_processor.batch_decode(predicted_ids)
    return transcription[0]
|