Al1Abdullah commited on
Commit
593f0ea
·
0 Parent(s):

Initial commit for Hugging Face Space

Browse files
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ PINECONE_API_KEY = "REDACTED_ROTATE_THIS_KEY"
+ # SECURITY: a live Pinecone API key was committed in this file. Rotate the key
+ # immediately and supply it via the Space's secrets / environment settings;
+ # never commit .env files containing credentials to version control.
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from src.ollama_chain import OllamaChain, OllamaRAGChain
4
+ from src.llama_cpp_chains import LlamaChain
5
+ from src.pdf_handler import extract_pdf
6
+ from src.vqa import answer_visual_question
7
+ from src.audio_processor import AudioProcessor
8
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
9
+
10
+ from dotenv import load_dotenv
11
+ import os
12
+
13
+ load_dotenv()
14
+
15
+ audio_processor = AudioProcessor()
16
+
17
+
18
@st.cache_resource
def load_chain(_chat_memory):
    """Build (and cache) the LLM chain for this session.

    Returns a RAG-enabled chain when the 'PDF Chat' toggle is on,
    otherwise a plain conversational chain. The leading underscore on
    _chat_memory keeps Streamlit from trying to hash the argument.
    """
    chain_cls = OllamaRAGChain if st.session_state.pdf_chat else OllamaChain
    return chain_cls(_chat_memory)
24
+
25
+
26
def file_uploader_change():
    """React to PDF uploads: invalidate the cached chain when needed and
    flag the knowledge base as stale so it gets re-indexed."""
    state = st.session_state
    if not state.uploaded_file:
        # All files removed: just drop the cached chain.
        clear_cache()
        return
    if not state.pdf_chat:
        clear_cache()
    state.knowledge_change = True
33
+
34
+
35
def toggle_pdf_chat_change():
    """Rebuild the chain when 'PDF Chat' flips; if PDFs are already
    uploaded, also mark the knowledge base for re-indexing."""
    clear_cache()
    state = st.session_state
    if state.pdf_chat and state.uploaded_file:
        state.knowledge_change = True
39
+
40
+
41
def clear_input_field():
    # Copy the text-input widget's value into user_question so the message
    # survives the rerun that Streamlit performs after the on_change callback.
    st.session_state.user_question = st.session_state.user_input
43
+
44
+
45
def set_send_input():
    """Flag that the user submitted a message (Enter in the text box) and
    stash the typed text for processing on the next rerun."""
    st.session_state.send_input = True
    clear_input_field()
48
+
49
+
50
def clear_cache():
    # Drop every @st.cache_resource entry, forcing load_chain() to rebuild
    # the chain on the next call.
    st.cache_resource.clear()
52
+
53
+
54
def initial_session_state():
    """Seed the session-state keys this app relies on with their defaults."""
    st.session_state.send_input = False
    st.session_state.knowledge_change = False
    st.session_state.user_question = ""
58
+
59
+
60
def main():
    """Streamlit entry point.

    Renders the chat UI, routes user input (text, PDF-RAG, image VQA, or
    audio transcription) to the appropriate chain, and plays the answer
    back through gTTS.
    """
    st.title('OVERKILL LLM')
    os.makedirs('./.cache/temp_files', exist_ok=True)  # Ensure temp folder exists
    chat_container = st.container()

    # Sidebar: mode toggle and file uploaders.
    st.sidebar.toggle('PDF Chat', value=False, key='pdf_chat', on_change=toggle_pdf_chat_change)
    uploaded_pdf = st.sidebar.file_uploader(
        'Upload your pdf files',
        type='pdf',
        accept_multiple_files=True,
        key='uploaded_file',
        on_change=file_uploader_change
    )

    uploaded_image = st.sidebar.file_uploader('Upload Images', type=['jpg', 'jpeg', 'png'], key='uploaded_image')
    st.sidebar.file_uploader('Upload Audio', type=['wav', 'mp3'], key='uploaded_audio')
    uploaded_audio = st.session_state.get('uploaded_audio')

    # Optional reset
    if st.sidebar.button("🔄 Reset Chat"):
        st.session_state.clear()
        # BUG FIX: st.experimental_rerun() was removed in modern Streamlit;
        # requirements.txt pins streamlit>=1.32, so use the stable st.rerun().
        st.rerun()

    # Input widgets.
    user_input = st.text_input('Message OVERKILL', key='user_input', on_change=set_send_input)
    send_button = st.button('Send', key='send_button')

    # First-run session setup.
    if 'send_input' not in st.session_state or 'user_question' not in st.session_state:
        initial_session_state()

    chat_history = StreamlitChatMessageHistory(key='history')

    # Replay the stored conversation.
    with chat_container:
        for msg in chat_history.messages:
            st.chat_message(msg.type).write(msg.content)

    try:
        llm_chain = load_chain(chat_history)
    except Exception as e:
        st.error(f"Error loading LLM chain: {e}")
        return

    # Re-index uploaded PDFs when the uploader/toggle callbacks flagged a change.
    if st.session_state.knowledge_change:
        with st.spinner('Updating knowledge base'):
            try:
                llm_chain.update_chain(uploaded_pdf)
                st.session_state.knowledge_change = False
            except Exception as e:
                st.error(f"Error updating knowledge base: {e}")
                return

    if (send_button or st.session_state.send_input) and st.session_state.user_question != "":
        with chat_container:
            st.chat_message('user').write(st.session_state.user_question)

            try:
                if uploaded_image:
                    # Image present: answer as a visual question instead.
                    image_path = os.path.join('./.cache/temp_files', uploaded_image.name)
                    with open(image_path, 'wb') as f:
                        f.write(uploaded_image.getvalue())
                    llm_response = answer_visual_question(image_path, st.session_state.user_question)

                elif uploaded_audio:
                    # Audio present: transcribe it and use the transcript as the question.
                    audio_path = os.path.join('./.cache/temp_files', uploaded_audio.name)
                    with open(audio_path, 'wb') as f:
                        f.write(uploaded_audio.getvalue())
                    st.write(f"Processing audio file: {audio_path}")
                    question = audio_processor.audio_to_text(audio_path)
                    st.write(f"Converted audio to text: {question}")
                    llm_response = llm_chain.run(user_input=question)

                else:
                    llm_response = llm_chain.run(user_input=st.session_state.user_question)

                st.session_state.user_question = ""
                st.chat_message('ai').write(llm_response)

                # BUG FIX: read the TTS file through a context manager so the
                # handle is closed (the old code leaked an open file object).
                audio_file = audio_processor.text_to_speech(llm_response)
                with open(audio_file, 'rb') as af:
                    audio_bytes = af.read()
                st.audio(audio_bytes, format='audio/mp3')

            except Exception as e:
                st.error(f"Error during chat: {e}")
145
+
146
+
147
if __name__ == '__main__':
    # Entry point when executed directly (e.g. `streamlit run app.py`).
    main()
config.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ chat_model:
2
+ 'model': "llama3:latest"
3
+ 'temperature': 0.75
4
+ 'num_gpu': 1
5
+ vector_database:
6
+ chroma:
7
+ chat_session_path: './chat_session/'
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ poppler-utils
2
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ torch
3
+ transformers
4
+ torchaudio
5
+ langchain
6
+ langchain-community
7
+ langchain-pinecone
8
+ langchain-chroma
9
+ pinecone-client
10
+ sentence-transformers
11
+ pypdf
12
+ PyMuPDF
13
+ pdf2image
14
+ pillow
15
+ opencv-python
16
+ ffmpeg-python
17
+ gtts
18
+ pydub
19
+ speechrecognition
20
+ streamlit>=1.32
21
+ google-generativeai
22
+ requests
23
+ python-dotenv
24
+ PyYAML
src/__pycache__/audio_processor.cpython-311.pyc ADDED
Binary file (3.15 kB). View file
 
src/__pycache__/llama_cpp_chains.cpython-311.pyc ADDED
Binary file (2.24 kB). View file
 
src/__pycache__/ollama_chain.cpython-311.pyc ADDED
Binary file (7.36 kB). View file
 
src/__pycache__/pdf_handler.cpython-311.pyc ADDED
Binary file (2.57 kB). View file
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (626 Bytes). View file
 
src/__pycache__/vectorstore.cpython-311.pyc ADDED
Binary file (4.94 kB). View file
 
src/__pycache__/vqa.cpython-311.pyc ADDED
Binary file (2.93 kB). View file
 
src/audio_processor.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import speech_recognition as sr
2
+ from gtts import gTTS
3
+ import tempfile
4
+ from pydub import AudioSegment
5
+
6
class AudioProcessor:
    """Speech helpers: speech-to-text via the Google Web Speech API and
    text-to-speech via gTTS."""

    def __init__(self):
        self.recognizer = sr.Recognizer()

    def audio_to_text(self, audio_file):
        """Process an uploaded audio file and convert it to text.

        The input is first transcoded to WAV (SpeechRecognition only reads
        WAV/AIFF/FLAC). Returns the transcript, or an explanatory message
        string on failure — callers treat errors as best-effort text.
        """
        import os  # local import keeps this module's import surface unchanged

        wav_path = None
        try:
            # Convert the source audio to WAV for the recognizer.
            audio = AudioSegment.from_file(audio_file)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_file:
                wav_path = wav_file.name
            audio.export(wav_path, format="wav")
            print(f"Converted audio to WAV: {wav_path}")  # Debug statement

            with sr.AudioFile(wav_path) as source:
                audio = self.recognizer.record(source)
            try:
                text = self.recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                text = "Could not understand audio"
            except sr.RequestError:
                text = "Could not request results"
            print(f"Recognized text: {text}")  # Debug statement
            return text
        except Exception as e:
            print(f"Error processing audio file: {e}")  # Debug statement
            return f"Error processing audio file: {e}"
        finally:
            # BUG FIX: the old code created the temp WAV with delete=False and
            # never removed it, leaking one file per call.
            if wav_path and os.path.exists(wav_path):
                os.remove(wav_path)

    def text_to_speech(self, text):
        """Convert text to speech using gTTS and save as a .mp3 file.

        Returns the path to the generated file; the caller is responsible
        for deleting it once the audio has been consumed.
        """
        tts = gTTS(text)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
            tts.save(temp_audio.name)
        return temp_audio.name
src/llama_cpp_chains.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.llms import LlamaCpp
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain.memory import ConversationBufferWindowMemory
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from langchain_core.runnables import RunnableSequence
6
+
7
+ from src.utils import load_config
8
+
9
+
10
class LlamaChain:
    """Plain conversational chain backed by a local llama.cpp model.

    Mirrors OllamaChain's interface: construct with a LangChain chat-message
    store, then call run(user_input) to get the assistant's reply as a string.
    """

    def __init__(self, chat_memory) -> None:
        # Llama-3 chat template: system preamble, then prior turns and the
        # new question in the user slot.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
            <|start_header_id|>system<|end_header_id|>
            You are a helpful and knowledgeable AI assistant.
            <|eot_id|>
            <|start_header_id|>user<|end_header_id|>
            Previous conversation={chat_history}
            Question: {input}
            Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input']
        )

        # Keep only the last k=3 exchanges to bound the prompt size.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True
        )

        config = load_config()
        llm = LlamaCpp(**config['chat_model'])

        # BUG FIX: the old code built
        #   RunnableSequence(prompt | llm | self.memory | StrOutputParser())
        # but a ConversationBufferWindowMemory is not a Runnable, so the
        # pipeline failed, and run() then indexed the parser's *string*
        # output with ['text'].  Compose prompt -> llm -> parser and handle
        # memory explicitly in run() instead.
        self.llm_chain = prompt | llm | StrOutputParser()

    def run(self, user_input):
        """Generate a reply to *user_input*, reading and updating chat memory."""
        history = self.memory.load_memory_variables({})['chat_history']
        response = self.llm_chain.invoke({'chat_history': history, 'input': user_input})
        self.memory.save_context({'input': user_input}, {'output': response})
        return response
src/ollama_chain.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.llms import Ollama
2
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
3
+ from langchain.memory import ConversationBufferWindowMemory
4
+ from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
5
+ from langchain.chains.combine_documents import create_stuff_documents_chain
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.runnables.history import RunnableWithMessageHistory
8
+ from langchain.schema import Document
9
+
10
+ from src.utils import load_config
11
+ from src.vectorstore import VectorDB
12
+
13
+
14
def format_docs(docs: list[Document]):
    """Concatenate retrieved documents' text, separated by blank lines."""
    parts = [doc.page_content for doc in docs]
    return '\n\n'.join(parts)
16
+
17
+
18
class OllamaChain:
    """Plain (non-RAG) conversational chain over a local Ollama model.

    Wraps a LangChain LLMChain with a sliding-window memory bound to the
    chat-message store passed in as *chat_memory*.
    """

    def __init__(self, chat_memory) -> None:
        # Llama-3 chat template: system preamble, then prior turns and the
        # new question in the user slot.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
            <|start_header_id|>system<|end_header_id|>
            You are a honest and unbiased AI assistant
            <|eot_id|>
            <|start_header_id|>user<|end_header_id|>
            Previous conversation={chat_history}
            Question: {input}
            Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input']
        )

        # Keep only the last k=3 exchanges so the prompt stays bounded.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True
        )

        # Model name / temperature / GPU count come from config.yaml.
        config = load_config()
        llm = Ollama(**config['chat_model'])
        # llm = Ollama(model='llama3:latest', temperature=0.75, num_gpu=1)

        # NOTE(review): LLMChain is deprecated in recent LangChain releases;
        # kept as-is here because memory integration relies on it.
        self.llm_chain = LLMChain(prompt=prompt, llm=llm, memory=self.memory, output_parser=StrOutputParser())
        # runnable = prompt | llm

    def run(self, user_input):
        """Return the model's reply text for *user_input*.

        Memory read/write is handled internally by LLMChain; the invoke
        result is a dict whose 'text' key holds the parsed output.
        """
        response = self.llm_chain.invoke(user_input)

        return response['text']
50
+
51
+
52
class OllamaRAGChain:
    """Retrieval-augmented conversational chain over a local Ollama model.

    Pipeline: a history-aware retriever first rewrites the user's question
    into a standalone form, retrieves matching chunks from the vector DB,
    then a stuff-documents QA chain answers from that context.  Chat history
    is threaded through RunnableWithMessageHistory using *chat_memory*.
    """

    def __init__(self, chat_memory, uploaded_file=None):
        # initialize vector db using config
        from src.utils import load_config  # NOTE(review): redundant — load_config is already imported at module level
        config = load_config()
        vector_db_config = config.get('vector_database', {})
        # Pinecone is preferred when configured; otherwise fall back to Chroma.
        db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        index_name = 'default'
        self.vector_db = VectorDB(db_name, index_name)
        if uploaded_file:
            self.update_knowledge_base(uploaded_file)

        # initialize llm (NOTE(review): load_config is called a second time here)
        config = load_config()
        self.llm = Ollama(**config['chat_model'])

        # initialize memory
        self.chat_memory = chat_memory

        # initialize sub chain with history message
        contextual_q_system_prompt = """Given a chat history and the latest user question which might refer to context \
        in the chat history. Check if the user's question refers to the chat history or not. If does, formulate a \
        standalone question which is incorporated from the latest question and history and can be understood without \
        the chat history.
        Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""

        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        # Rewrites the question using chat history before hitting the retriever.
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )

        # initialize qa chain: answers strictly from the retrieved {context}.
        qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved\
        context to answer the question. If you don't know the answer, just say that you don't know.
        Context: {context}"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)

        rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)

        # Single shared history: the session_id argument is ignored and the
        # same message store is returned for every session.
        self.conversation_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            lambda session_id: chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer'
        )

    def run(self, user_input):
        """Answer *user_input* with retrieval context; returns the answer text."""
        # session_id is required by RunnableWithMessageHistory but unused
        # (the history factory above ignores it).
        config = {"configurable": {"session_id": "any"}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)

        return response['answer']

    def update_chain(self, uploaded_pdf):
        """Re-index *uploaded_pdf* and rebuild the retrieval pipeline on top
        of the refreshed vector store."""
        self.update_knowledge_base(uploaded_pdf)
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        self.conversation_rag_chain = RunnableWithMessageHistory(
            create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain),
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer'
        )

    def update_knowledge_base(self, uploaded_pdf):
        """Push uploaded PDF file(s) into the vector store index."""
        self.vector_db.index(uploaded_pdf)
src/pdf_handler.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from langchain_community.document_loaders import PyPDFLoader, PyPDFDirectoryLoader
4
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
5
+ from langchain.schema.document import Document
6
+
7
+
8
def create_cache_dir(directory=None):
    """Ensure the cache directory exists and return its path.

    Defaults to './.cache' when *directory* is falsy.

    BUG FIX: the old body always created './.cache' regardless of the
    *directory* argument, so callers passing a custom path got back a
    path that was never created.
    """
    if not directory:
        directory = './.cache'

    os.makedirs(directory, exist_ok=True)
    return directory
14
+
15
+
16
def load_pdf(file_path):
    """Load a single PDF into LangChain documents (one per page)."""
    return PyPDFLoader(file_path).load()
20
+
21
+
22
def load_pdf_directory(directory):
    """Load every PDF under *directory* into LangChain documents."""
    return PyPDFDirectoryLoader(directory).load()
25
+
26
+
27
def split_pdf(pdfs: list[Document]):
    """Chunk documents into 512-character pieces with 64 characters of overlap."""
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=64,
        length_function=len,
        is_separator_regex=False,
    )
    return chunker.split_documents(pdfs)
36
+
37
+
38
def extract_pdf(uploaded_pdf):
    """Persist uploaded PDF file(s) into the temp cache and return that directory.

    Accepts a single Streamlit UploadedFile or a list of them.
    """
    temp_dir = os.path.join(create_cache_dir(), 'temp_files')
    os.makedirs(temp_dir, exist_ok=True)
    # Normalise to a list so single uploads and multi-uploads share one code path.
    files = uploaded_pdf if isinstance(uploaded_pdf, list) else [uploaded_pdf]
    for uploaded in files:
        target = os.path.join(temp_dir, uploaded.name)
        with open(target, 'wb') as out:
            out.write(uploaded.getvalue())
    return temp_dir
src/utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import yaml
3
+
4
+
5
def load_config():
    """Parse ./config.yaml (relative to the working directory) into a dict."""
    with open('./config.yaml', 'r') as stream:
        return yaml.safe_load(stream)
src/vectorstore.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pinecone import Pinecone, ServerlessSpec, PodSpec
2
+ from langchain_pinecone import PineconeVectorStore
3
+ from langchain_chroma import Chroma
4
+ from langchain_community.embeddings import OllamaEmbeddings
5
+ from langchain.indexes import SQLRecordManager, index
6
+
7
+ from src.pdf_handler import extract_pdf, load_pdf_directory, split_pdf
8
+ from src.utils import load_config
9
+
10
+ import os
11
+ import shutil
12
+ from dotenv import load_dotenv
13
+
14
+ load_dotenv()
15
+
16
+
17
def setup_pinecone(index_name, embedding_model, embedding_dim, metric='cosine', use_serverless=True):
    """Create a fresh Pinecone index and wrap it in a LangChain vector store.

    WARNING: if an index named *index_name* already exists it is DELETED and
    recreated, discarding all previously stored vectors.

    Args:
        index_name: Name of the Pinecone index to (re)create.
        embedding_model: LangChain embeddings object used by the store.
        embedding_dim: Dimensionality of the embedding vectors.
        metric: Similarity metric for the index (default 'cosine').
        use_serverless: Use a serverless spec (aws / us-east-1) instead of pods.
    """
    # API key is read from the environment (.env loaded via load_dotenv above).
    pc = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
    if use_serverless:
        spec = ServerlessSpec(cloud='aws', region='us-east-1')
    else:
        spec = PodSpec()

    # Drop-and-recreate so repeated runs always start from an empty index.
    if index_name in pc.list_indexes().names():
        pc.delete_index(index_name)

    pc.create_index(
        index_name,
        dimension=embedding_dim,
        metric=metric,
        spec=spec
    )

    db = PineconeVectorStore(index_name=index_name, embedding=embedding_model)
    return db
36
+
37
+
38
def setup_chroma(index_name, embedding_model, persist_directory=None):
    """Create (or reopen) a persistent Chroma collection named *index_name*."""
    target = persist_directory or './.cache/database'
    os.makedirs(target, exist_ok=True)
    return Chroma(index_name, embedding_function=embedding_model, persist_directory=target)
46
+
47
+
48
class VectorDB:
    """Facade over either a Pinecone or a Chroma vector store.

    Handles store construction, incremental (de-duplicated) indexing of
    uploaded PDFs via LangChain's SQLRecordManager, and retriever access.
    """

    def __init__(self, db_name=None, index_name=None, cache_dir=None):
        config = load_config()
        vector_db_config = config.get('vector_database', {})
        # Determine DB type from config, fallback to argument or chroma
        if db_name is None:
            db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        if index_name is None:
            index_name = 'default'
        # nomic-embed-text embeddings; 768 below must match this model's output
        # dimension — TODO confirm against the deployed Ollama model.
        embedding = OllamaEmbeddings(model='nomic-embed-text:latest', num_gpu=1)
        if not cache_dir:
            cache_dir = './.cache/database'
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        if db_name == 'pinecone':
            if not os.environ.get('PINECONE_API_KEY'):
                raise ValueError("PINECONE_API_KEY environment variable is not set. Please set it in your .env file or environment.")
            self.vectorstore = setup_pinecone(index_name, embedding, 768, 'cosine')
        else:
            self.vectorstore = setup_chroma(index_name, embedding, self.cache_dir)
        # The record manager tracks which chunks are already indexed so
        # re-indexing can de-duplicate / clean up stale entries.
        namespace = f'{db_name}/{index_name}'
        self.record_manager = SQLRecordManager(namespace,
                                               db_url=f'sqlite:///{self.cache_dir}/record_manager_cache.sql')
        self.record_manager.create_schema()

    def index(self, uploaded_file):
        """Extract, split, and index uploaded PDF file(s) into the store."""
        directory = extract_pdf(uploaded_file)
        docs = load_pdf_directory(directory)
        chunks = split_pdf(docs)

        # NOTE: this calls the module-level langchain `index()` helper, not
        # this method — inside the body, the bare name resolves to the global.
        index(
            docs_source=chunks,
            record_manager=self.record_manager,
            vector_store=self.vectorstore,
            cleanup='full',
            source_id_key='source'
        )

        # Remove the cached temp copies once they are indexed.
        for file in os.listdir(directory):
            os.remove(os.path.join(directory, file))

    def as_retriever(self):
        """Return a LangChain retriever view of the underlying store."""
        return self.vectorstore.as_retriever()
src/vqa.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
3
+ warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
4
+
5
+ import requests
6
+ from PIL import Image
7
+ from transformers import BlipProcessor, BlipForQuestionAnswering, Wav2Vec2Processor, Wav2Vec2ForCTC
8
+ import os
9
+ import torchaudio
10
+
11
# Load BLIP VQA and Wav2Vec2 ASR weights once at import time (downloaded on
# first run, then served from the local Hugging Face cache).
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

audio_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
audio_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
16
+
17
def answer_visual_question(image_path_or_url: str, question: str) -> str:
    """Answer *question* about an image given as a local path or an HTTP URL."""
    # Local file takes precedence; anything else is treated as a URL.
    if os.path.isfile(image_path_or_url):
        source = Image.open(image_path_or_url)
    else:
        response = requests.get(image_path_or_url, stream=True)
        source = Image.open(response.raw)
    raw_image = source.convert('RGB')

    encoded = processor(raw_image, question, return_tensors="pt")
    generated = model.generate(**encoded)
    return processor.decode(generated[0], skip_special_tokens=True)
26
+
27
def transcribe_audio(audio_path: str) -> str:
    """Transcribe a speech audio file with Wav2Vec2 and return the text.

    BUG FIX: the original body called torch.no_grad() and torch.argmax but
    `torch` was never imported in this module, so any call raised NameError.
    """
    import torch  # local import so the module's import surface is unchanged

    waveform, sample_rate = torchaudio.load(audio_path)
    inputs = audio_processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = audio_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = audio_processor.batch_decode(predicted_ids)
    return transcription[0]