Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- .gitattributes +2 -35
- documents.pkl +3 -0
- faiss_index/.gitattributes +2 -0
- faiss_index/documents.pkl +3 -0
- faiss_index/index.faiss +3 -0
- src/app.py +211 -0
- src/chatbot.py +165 -0
- src/document_processor.py +167 -0
- src/vector_store.py +195 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,2 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.faiss filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
documents.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf9c7941cca8d55333bc9d1c0232934f0a8edf7bb17219a728acd6e6476fd897
|
| 3 |
+
size 2235712
|
faiss_index/.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.faiss filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
faiss_index/documents.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf9c7941cca8d55333bc9d1c0232934f0a8edf7bb17219a728acd6e6476fd897
|
| 3 |
+
size 2235712
|
faiss_index/index.faiss
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:35d72c5b12d9f320e9cfd1836133df9514417bec14ddb4ab7937d746886c5abf
|
| 3 |
+
size 29995053
|
src/app.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
"""Streamlit front-end for the bioethics RAG chatbot.

Flow per rerun: render header/sidebar, read the question, lazily build the
chatbot once, answer with retrieval context, and show debug/citation info.
Rate limiting is tracked in ``st.session_state.query_count``.
NOTE(review): the 50-query cap is per browser session, not "per day" as the
UI text says — confirm whether a server-side counter is intended.
"""
import streamlit as st
from chatbot import BioethicsChatbot
import time
import io
import sys
# NOTE(review): redirect_stdout/redirect_stderr are imported but unused —
# stdout capture below swaps sys.stdout manually instead.
from contextlib import redirect_stdout, redirect_stderr

st.set_page_config(
    page_title="Bioethics AI Assistant",
    page_icon="🧬",
    layout="wide"
)

st.title("🧬 Bioethics AI Assistant")
st.markdown("*Ask questions about medical ethics, informed consent, research ethics, and more*")

# Custom CSS to hide debug output
st.markdown("""
<style>
.debug-output {
    background-color: #f0f0f0;
    padding: 10px;
    border-radius: 5px;
    font-family: monospace;
    font-size: 12px;
    color: #666;
}
</style>
""", unsafe_allow_html=True)

# Sidebar info
with st.sidebar:
    st.markdown("### About")
    st.write("This demo uses Retrieval-Augmented Generation (RAG) with open-access bioethics papers.")

    st.markdown("### Sample Questions")
    sample_questions = [
        "What is informed consent in medical research?",
        "What are the ethical issues with genetic testing?",
        "How should AI bias in healthcare be addressed?",
        "What is the principle of beneficence?",
        "What are the ethics of end-of-life care?"
    ]

    # Clicking a sample question stashes it in session state; the main
    # text_input picks it up as its initial value on the next rerun.
    for q in sample_questions:
        if st.button(q, key=q, use_container_width=True):
            st.session_state.current_question = q

    st.markdown("---")
    st.markdown("### Demo Info")
    st.info("💡 This demo shows sources found and similarity scores for transparency")

# Rate limiting: per-session query counter, initialised once.
if 'query_count' not in st.session_state:
    st.session_state.query_count = 0


# Initialize chatbot (only once)
# NOTE(review): caching is doubled — @st.cache_resource already memoises
# across sessions, and the result is additionally stored in
# st.session_state.bot below. One mechanism should suffice; confirm intent.
@st.cache_resource
def load_chatbot():
    """Load chatbot once and cache it"""
    return BioethicsChatbot("data/")


# Main interface
col1, col2 = st.columns([4, 1])

with col1:
    question = st.text_input(
        "Your question:",
        value=st.session_state.get('current_question', ''),
        placeholder="e.g., What are the ethical considerations in clinical trials?",
        key="question_input"
    )

with col2:
    st.metric("Queries Used", f"{st.session_state.query_count}/50")

# Clear the current_question after it's been used
# (text_input has already read it as its initial value this rerun).
if 'current_question' in st.session_state:
    del st.session_state.current_question

if question and st.session_state.query_count < 50:

    # Load chatbot
    try:
        if 'bot' not in st.session_state:
            with st.spinner("🔄 Loading bioethics knowledge base..."):
                st.session_state.bot = load_chatbot()
            st.success("✅ Knowledge base loaded!")

        st.session_state.query_count += 1

        # Create columns for response
        response_col, debug_col = st.columns([2, 1])

        with response_col:
            st.markdown("### 🤖 Assistant Response")

            # Capture the streaming output and debug info
            start_time = time.time()

            # Capture stdout to get debug prints: BioethicsChatbot.ask()
            # streams tokens and debug lines via print(); swallow them so
            # they don't pollute the server log shown to users.
            old_stdout = sys.stdout
            old_stderr = sys.stderr
            stdout_capture = io.StringIO()
            stderr_capture = io.StringIO()

            try:
                # Redirect prints to capture debug info
                sys.stdout = stdout_capture
                sys.stderr = stderr_capture

                # Get the answer (this will stream to captured stdout)
                answer = st.session_state.bot.ask(question)

            finally:
                # Restore stdout/stderr
                sys.stdout = old_stdout
                sys.stderr = old_stderr

            response_time = time.time() - start_time

            # Display the final answer
            st.write(answer)

        with debug_col:
            st.markdown("### 🔍 Debug Info")

            # Show search results info
            if 'bot' in st.session_state:
                # Get search results for debug display
                # NOTE(review): this re-runs the vector search (and its
                # embedding API call) a second time just for display.
                search_results = st.session_state.bot.vector_store.search(question, k=3)
                with st.expander("📊 Search Results", expanded=True):
                    for i, r in enumerate(search_results):
                        st.write(f"**Result {i + 1}** (Score: {r.get('similarity_score', 0):.3f})")
                        st.write(f"Source: {r['metadata'].get('filename', 'Unknown')}")
                        st.write(f"Preview: {r['content'][:200]}...")
                        st.write("---")

                # Show response metadata
                st.metric("Response Time", f"{response_time:.2f}s")
                st.metric("Model", "GPT-4o-mini")

                # Show conversation history count
                if hasattr(st.session_state.bot, 'history'):
                    st.metric("Conversation Turn", len(st.session_state.bot.history))

        # Show source information
        with st.expander("📚 About the Sources"):
            st.markdown("""
This assistant searches through open-access bioethics papers to find relevant information.

**Search Process:**
1. Your question is converted to embeddings
2. Similar text chunks are found using FAISS vector search
3. Only chunks with similarity score ≥ 0.65 are used for citations
4. The language model synthesizes an answer from these sources
""")

    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
        st.info("Please try refreshing the page or try a different question.")

elif st.session_state.query_count >= 50:
    st.error("📈 Demo limit reached for today. This prevents API abuse.")
    st.info("💡 For unlimited use, clone the repository and use your own API key.")

    with st.expander("🚀 How to run locally"):
        st.code("""
# Clone the repository
git clone your-repo-url
cd bioethics-chatbot

# Install dependencies
pip install -r requirements.txt

# Set your OpenAI API key
export OPENAI_API_KEY="your-key-here"

# Run locally
streamlit run app.py
""", language="bash")

# Footer
st.markdown("---")
col1, col2, col3 = st.columns(3)

with col1:
    st.markdown("**🔗 Links**")
    st.markdown("- [GitHub Repository](your-repo-link)")
    st.markdown("- [Open Source Papers Used](./data/LICENSE_INFO.md)")

with col2:
    st.markdown("**🛠️ Tech Stack**")
    st.markdown("- Python & Streamlit")
    st.markdown("- OpenAI GPT-4o-mini")
    st.markdown("- FAISS Vector Search")
    st.markdown("- LangChain")

with col3:
    st.markdown("**📊 Demo Stats**")
    if 'bot' in st.session_state and hasattr(st.session_state.bot, 'vector_store'):
        doc_count = len(st.session_state.bot.vector_store.documents)
        st.markdown(f"- {doc_count} text chunks indexed")
        st.markdown(f"- Vector dimension: {st.session_state.bot.vector_store.dimension}")
        st.markdown(f"- Queries today: {st.session_state.query_count}")

# Add some spacing
st.markdown("<br>", unsafe_allow_html=True)
|
src/chatbot.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from document_processor import DocumentProcessor
|
| 2 |
+
from vector_store import FAISSVectorStore
|
| 3 |
+
from langchain_openai import ChatOpenAI
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(level=logging.INFO)
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
| 12 |
+
|
| 13 |
+
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that mirrors each streamed LLM token to stdout
    while accumulating the full response in ``current_text``.

    ``current_text`` is part of the public contract: callers reset it to
    ``""`` before each LLM invocation and read it back via ``get_text()``.
    """

    def __init__(self):
        # Running transcript of the tokens seen so far.
        self.current_text = ""

    def on_llm_new_token(self, token: str, **kwargs):
        """Record one token and echo it to the console immediately."""
        self.current_text += token
        # flush=True so the console shows the stream in real time rather
        # than waiting for a newline.
        print(token, end="", flush=True)

    def get_text(self):
        """Return everything streamed since the last reset."""
        return self.current_text
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BioethicsChatbot:
    """RAG chatbot over a corpus of bioethics PDFs.

    On construction it loads a saved FAISS index from disk or, failing
    that, builds one from the PDFs in ``data_dir``. Each question is
    answered by retrieving the most similar chunks and prompting a
    streaming OpenAI chat model with confidence-tiered citation rules.
    """

    def __init__(self, data_dir: str = "data/sample_papers"):
        """Set up processing, retrieval and the streaming LLM.

        Args:
            data_dir: Directory scanned for ``*.pdf`` files when no saved
                index can be loaded.

        Raises:
            ValueError: If no saved index exists and ``data_dir`` contains
                no PDFs.
        """
        self.processor = DocumentProcessor()
        self.vector_store = FAISSVectorStore()
        self.history = []
        # Similarity-score cutoffs mapping a retrieved chunk to a citation
        # confidence tier (see get_citation_confidence).
        self.confidence_thresholds = {
            'high': 0.8,
            'medium': 0.65,
            'low': 0.5,
        }

        if not self.vector_store.load_index():
            print("No existing vector store, creating one...")
            pdf_files = list(Path(data_dir).glob("*.pdf"))
            if not pdf_files:
                raise ValueError(f"No PDFs found in {data_dir}")

            chunks = self.processor.process_documents([str(p) for p in pdf_files])
            self.vector_store.add_documents(chunks)
            logger.info("Indexed %d documents.", len(chunks))
        else:
            logger.info("Index loaded from disk")

        # One shared handler: ask() resets handler.current_text per call.
        self.stream_handler = StreamHandler()
        self.llm = ChatOpenAI(model="gpt-4o-mini", streaming=True,
                              callbacks=[self.stream_handler])

    def add_new_document(self, pdf_path: str):
        """Process one PDF and add its chunks to the index (idempotent).

        A file whose name is already present in the index metadata is
        skipped, so calling this twice on the same PDF is safe.
        """
        filename = Path(pdf_path).name

        # Check if already in the index
        existing_files = {doc["metadata"].get("filename") for doc in self.vector_store.documents}
        if filename in existing_files:
            # Fix: original message was an f-string that never interpolated
            # the filename ("Skipping (unknown): ...").
            print(f"Skipping {filename}: already indexed.")
            return

        # Otherwise process & add
        chunks = self.processor.process_document(pdf_path)
        self.vector_store.add_documents(chunks)
        print(f"Added {len(chunks)} chunks from {pdf_path}")

    def get_citation_confidence(self, similarity_score: float) -> str:
        """Map a cosine-similarity score to a citation confidence tier.

        Returns one of ``"high_confidence"``, ``"medium_confidence"``,
        ``"low_confidence"`` or ``"context_only"`` (below every threshold).
        """
        if similarity_score >= self.confidence_thresholds['high']:
            return "high_confidence"
        if similarity_score >= self.confidence_thresholds['medium']:
            return "medium_confidence"
        if similarity_score >= self.confidence_thresholds['low']:
            return "low_confidence"
        return "context_only"

    def ask(self, question: str, k: int = 10) -> str:
        """Answer ``question`` using the top-``k`` retrieved chunks.

        The LLM response streams to stdout via the callback handler; the
        complete answer string is returned and appended to ``history``.
        """
        # Step 1: Retrieve relevant chunks
        results = self.vector_store.search(question, k=k)

        # DEBUG: Print what we found
        print(f"Found {len(results)} results for query: '{question}'")
        for i, r in enumerate(results[:3]):  # Show top 3
            # Fix: formatting the 'N/A' fallback with ':.3f' raised
            # ValueError whenever a result lacked a similarity score.
            score = r.get('similarity_score', 0.0)
            print(f"Result {i + 1} (score: {score:.3f}): {r['content'][:200]}...")

        if not results:
            return "I couldn't find relevant information in the documents."

        # Step 2: Build context from retrieved chunks
        context_blocks = []
        citation_groups = {
            'high_confidence': [],
            'medium_confidence': [],
            'low_confidence': [],
        }
        for r in results:
            title = r["metadata"].get("title")
            authors = r["metadata"].get("authors")
            year = r["metadata"].get("year", "n.d.")

            confidence = self.get_citation_confidence(r.get("similarity_score", 0.0))

            block = (
                f"Source: {authors} ({year}). *{title}* "
                f"[chunk {r['metadata'].get('chunk_id', '?')}, confidence: {confidence}]\n"
                f"{r['content']}\n"
            )

            context_blocks.append(block)
            # Fix: "context_only" has no bucket; indexing
            # citation_groups[confidence] raised KeyError for any chunk
            # scoring below the 'low' threshold with a known author.
            if authors is not None and authors != "Unknown Author(s)" and confidence in citation_groups:
                citation_groups[confidence].append(block)

        history_text = "\n".join(
            [f"User: {u}\nBot: {b}" for u, b in self.history[-4:]]
        ) or "No previous conversation."

        # Fix: joins are hoisted out of the f-string — a backslash inside
        # an f-string expression is a SyntaxError before Python 3.12.
        sources_text = "\n\n".join(context_blocks)
        high_text = "\n\n".join(citation_groups['high_confidence']) or "None"
        medium_text = "\n\n".join(citation_groups['medium_confidence']) or "None"
        low_text = "\n\n".join(citation_groups['low_confidence']) or "None"

        context = f"""
Conversation so far:
{history_text}

Relevant sources (use them to guide your answer, but cite only the ones in citation groups):
{sources_text}

Do not cite if the author is "Unknown Author(s)".
CITATION GUIDELINES:
- HIGH CONFIDENCE sources: Use direct citations "(Author, Year)"
- MEDIUM CONFIDENCE sources: Use "According to Author (Year)..."
- LOW CONFIDENCE sources: Use "(see Author, Year)"

High confidence sources:
{high_text}

Medium confidence sources:
{medium_text}

Low confidence sources:
{low_text}

"""

        # Step 3: Construct prompt
        # Fix: original said "doesn't concern neither ... nor", a double
        # negative instructing the model to refuse on-topic questions.
        prompt = f"""
You are a bioethics expert assistant.
Answer the user's question using the context provided below.
Draw justified connections between concepts even if not explicitly stated.
If you need to make reasonable inferences based on the context, do so.
If the context doesn't contain enough information, say what you do know from the context and indicate what information is missing.
If the question concerns neither bioethics nor previous questions, inform the user about it and don't answer it.
Context:
{context}

Question: {question}
Answer:
"""

        # Reset the shared handler so this call's tokens start clean.
        self.stream_handler.current_text = ""

        _ = self.llm.invoke(prompt)  # streaming happens here
        print()  # newline after streaming

        answer = self.stream_handler.get_text()
        self.history.append((question, answer))

        return answer
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
src/document_processor.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fitz
|
| 2 |
+
import re
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import logging
|
| 6 |
+
import PyPDF2
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(level=logging.INFO)
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class DocumentProcessor:
    """Turn PDF files into overlapping text chunks with per-file metadata.

    Extraction uses PyMuPDF (``fitz``); metadata extraction uses PyPDF2
    with first-page heuristics as a fallback.
    """

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        """
        Args:
            chunk_size: Soft maximum number of characters per chunk.
            chunk_overlap: Characters of trailing context carried into the
                next chunk.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract all page text from a PDF; return "" on any failure."""
        try:
            doc = fitz.open(pdf_path)
            text = ""

            for page in doc:
                text += page.get_text()
                text += f"\n--- Page {page.number + 1} ---\n"  # page.number is 0-indexed

            logger.info(f"Extracted text from {pdf_path}: {len(text)} characters, {len(doc)} pages")
            doc.close()
            return text

        except Exception as e:
            # Best-effort: callers get an empty string and zero chunks
            # rather than a crash on one unreadable file.
            logger.error(f"Error extracting text from {pdf_path}: {e}")
            return ""

    def clean_text(self, text: str) -> str:
        """Normalise whitespace and strip page/figure artefacts."""
        text = re.sub(r'\n{2,}', '\n', text)  # keep single newlines
        text = re.sub(r'[ \t]+', ' ', text)   # collapse spaces/tabs

        # Remove page headers/footers
        # NOTE(review): this only removes from "Page N" to end of line, so
        # the leading "--- " of the markers added by extract_text_from_pdf
        # survives — confirm whether that残 prefix is acceptable.
        text = re.sub(r'Page \d+.*?\n', '', text)

        # Remove references to figures/tables
        text = re.sub(r'\[Figure \d+\]|\[Table \d+\]', '', text)

        return text.strip()

    @staticmethod
    def _make_chunk(text: str, metadata: Dict, chunk_id: int) -> Dict:
        """Build one chunk record (single place for the chunk schema)."""
        return {
            "text": text.strip(),
            "metadata": metadata or {},
            "chunk_id": chunk_id,
        }

    def chunk_text(self, text: str, metadata: Dict = None) -> List[Dict]:
        """Split text into ~chunk_size pieces on sentence boundaries.

        Consecutive chunks share ``chunk_overlap`` trailing characters so
        sentences cut at a boundary keep some context.
        """
        if not text:
            return []

        sentences = text.split('. ')
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            # If adding this sentence would exceed chunk size
            if len(current_chunk) + len(sentence) > self.chunk_size:
                if current_chunk:
                    chunks.append(self._make_chunk(current_chunk, metadata, len(chunks)))

                    # Start new chunk with overlap
                    overlap_text = current_chunk[-self.chunk_overlap:] if len(
                        current_chunk) > self.chunk_overlap else current_chunk
                    current_chunk = overlap_text + " " + sentence
                else:
                    current_chunk = sentence
            else:
                current_chunk += ". " + sentence if current_chunk else sentence

        # Add final chunk
        if current_chunk:
            chunks.append(self._make_chunk(current_chunk, metadata, len(chunks)))

        logger.info(f"Created {len(chunks)} chunks")
        return chunks

    def extract_metadata(self, pdf_path: str) -> dict:
        """Extract metadata (title, authors, year, filename, file_size) from a PDF."""

        metadata = {
            "filename": Path(pdf_path).name,
            "file_size": Path(pdf_path).stat().st_size,
            "title": None,
            "authors": None,
            "year": None
        }

        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)

            # 1. Try embedded PDF metadata
            pdf_meta = reader.metadata
            if pdf_meta:
                # Fix: metadata values can be None even when the key
                # exists; the original called .strip() on them directly.
                title = (pdf_meta.get("/Title") or "").strip()
                author = (pdf_meta.get("/Author") or "").strip()

                if title and title.lower() not in ["", "untitled", "unknown"]:
                    metadata["title"] = title

                if author and author.lower() not in ["", "anonymous", "unknown"]:
                    metadata["authors"] = author

            # 2. Fallback: look at first page
            if not metadata["title"] or not metadata["authors"]:
                try:
                    first_page = reader.pages[0].extract_text() or ""
                    lines = [line.strip() for line in first_page.split("\n") if line.strip()]

                    # crude heuristic: first line = title
                    if not metadata["title"] and lines:
                        metadata["title"] = lines[0]

                    # crude heuristic: authors in line(s) after title
                    if not metadata["authors"] and len(lines) > 1:
                        possible_authors = lines[1]
                        if re.search(r"[A-Z][a-z]+(?: [A-Z][a-z]+)*", possible_authors):
                            metadata["authors"] = possible_authors

                    # crude heuristic: find year (e.g., 2023, 2024)
                    year_match = re.search(r"\b(19|20)\d{2}\b", first_page)
                    if year_match:
                        metadata["year"] = year_match.group(0)

                except Exception:
                    # Heuristics are best-effort only; defaults applied below.
                    pass

        # Defaults if missing
        metadata["title"] = metadata["title"] or "Unknown Title"
        metadata["authors"] = metadata["authors"] if metadata["authors"] else None
        metadata["year"] = metadata["year"] or "n.d."

        return metadata

    def process_document(self, pdf_path: str) -> List[Dict]:
        """Extract, clean and chunk one PDF; returns the chunk list."""
        try:
            file_path = Path(pdf_path)
        except TypeError as e:  # Catches specifically if pdf_path is the wrong type
            logger.error(f"Invalid path type: {pdf_path}: {e}")
            raise
        except OSError as e:  # Catches other filesystem-related errors
            logger.error(f"OS error with path: {pdf_path}: {e}")
            raise

        metadata = self.extract_metadata(pdf_path)

        raw_text = self.extract_text_from_pdf(pdf_path)
        clean_text = self.clean_text(raw_text)
        chunks = self.chunk_text(clean_text, metadata)
        logger.info(f"Processed {pdf_path}: {len(chunks)} chunks created")
        return chunks

    def process_documents(self, pdf_paths: List[str]) -> List[Dict]:
        """Process many PDFs and concatenate their chunk lists."""
        documents = []
        for path in pdf_paths:
            documents.extend(self.process_document(path))
        return documents
|
src/vector_store.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import faiss
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from langchain_openai import OpenAIEmbeddings
|
| 6 |
+
from threading import Lock
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
logger.addHandler(logging.NullHandler())
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DocumentChunk = Dict[str, Any]
|
| 15 |
+
|
| 16 |
+
class FAISSVectorStore:
|
| 17 |
+
    def __init__(
        self,
        dimension: int = 3072,
        index_path: str = "data/faiss_index",
        embedding_model: str = "text-embedding-3-large",  # 3072-dim vectors
    ):
        """Create (or reload) a cosine-similarity FAISS store.

        Args:
            dimension: Expected embedding dimension; must match the
                embedding model's output (3072 for text-embedding-3-large).
            index_path: Directory where the index and documents are
                persisted; created if missing.
            embedding_model: OpenAI embedding model name.

        Raises:
            ImportError: If OpenAIEmbeddings could not be imported.
        """
        # Guard for environments where the langchain import failed and the
        # module set OpenAIEmbeddings = None instead of raising at import.
        if OpenAIEmbeddings is None:
            raise ImportError(
                "Could not import OpenAIEmbeddings from langchain. "
                "Install langchain or adapt the import to your environment."
            )

        self.dimension = dimension
        self.index_path = Path(index_path)
        # Lock serialising index mutations (see add_documents).
        self._lock = Lock()
        self.index_path.mkdir(parents=True, exist_ok=True)

        # Instantiate embeddings (may make API calls later when embedding)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)

        # in-memory structures: chunk dicts parallel to the index rows
        self.documents: List[DocumentChunk] = []

        # Create a new FAISS index (will be replaced by load if a saved index exists)
        # IndexFlatIP = exact inner-product search; combined with L2
        # normalisation at add time this gives cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)  # All vectors must be this length

        # If there's a saved index, load it (overwrites the index created above).
        self.load_index()  # safe: will return False if nothing to load
|
| 45 |
+
|
| 46 |
+
    def _ensure_index_dim(self, d: int):
        """Ensure FAISS index has dimension d.

        An *empty* index whose dimension differs is silently recreated with
        dimension ``d`` (also updating ``self.dimension``); a *non-empty*
        index with a different dimension raises, because its stored vectors
        would be incompatible with the new embeddings.

        Raises:
            ValueError: If the existing, non-empty index has dimension != d.
        """
        # If current index has no vectors, and d != self.dimension, recreate.
        # Using getattr for defensive programming
        if getattr(self.index, "ntotal", 0) == 0 and getattr(self.index, "d", None) != d:
            logger.info("Recreating an empty index with dimension %d", d)
            self.dimension = d
            self.index = faiss.IndexFlatIP(self.dimension)
        elif getattr(self.index, "d", None) is not None and self.index.d != d:
            raise ValueError(f"Embedding dimension ({d}) does not match existing index dimension ({self.index.d}).")
|
| 56 |
+
|
| 57 |
+
def add_documents(self, chunks: List[DocumentChunk], save: bool = True):
|
| 58 |
+
"""
|
| 59 |
+
Add list of chunks to the FAISS index. Each chunk MUST contain 'text'.
|
| 60 |
+
If index is empty and embedding dimension differs, the index will be re-created.
|
| 61 |
+
"""
|
| 62 |
+
with self._lock:
|
| 63 |
+
if not chunks:
|
| 64 |
+
logger.debug("No chunks to add.")
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
+
texts = []
|
| 68 |
+
for i, chunk in enumerate(chunks):
|
| 69 |
+
if not isinstance(chunk, dict):
|
| 70 |
+
raise ValueError(f"Chunk {i} is not a dictionary")
|
| 71 |
+
if "text" not in chunk:
|
| 72 |
+
raise ValueError(f"Chunk {i} missing required 'text' field")
|
| 73 |
+
if not isinstance(chunk["text"], str):
|
| 74 |
+
raise ValueError(f"Chunk {i} 'text' field must be a string")
|
| 75 |
+
if not chunk["text"].strip():
|
| 76 |
+
logger.warning(f"Chunk {i} has empty text content")
|
| 77 |
+
continue
|
| 78 |
+
texts.append(chunk["text"])
|
| 79 |
+
|
| 80 |
+
# Get embeddings from the embedding provider (call to a model)
|
| 81 |
+
embeddings = self.embeddings.embed_documents(texts)
|
| 82 |
+
embeddings_np = np.asarray(embeddings, dtype=np.float32)
|
| 83 |
+
|
| 84 |
+
# Embedding shape checks
|
| 85 |
+
if embeddings_np.ndim == 1:
|
| 86 |
+
# single vector returned as 1D array -> reshape to (1, d)
|
| 87 |
+
embeddings_np = embeddings_np.reshape(1, -1)
|
| 88 |
+
|
| 89 |
+
emb_d = embeddings_np.shape[1]
|
| 90 |
+
# If needed, recreate the index dimension (only possible if index currently empty)
|
| 91 |
+
self._ensure_index_dim(emb_d)
|
| 92 |
+
|
| 93 |
+
if emb_d != self.index.d:
|
| 94 |
+
raise ValueError(f"Embedding dim {emb_d} != index dim {self.index.d}")
|
| 95 |
+
|
| 96 |
+
# L2-normalize rows (in place) so inner product == cosine similarity
|
| 97 |
+
faiss.normalize_L2(embeddings_np)
|
| 98 |
+
|
| 99 |
+
# Add to index
|
| 100 |
+
self.index.add(embeddings_np)
|
| 101 |
+
# The documentation of "add" suggests we have to put the number of vectors,
|
| 102 |
+
# as a first argument, but Python does it for us.
|
| 103 |
+
|
| 104 |
+
# Append documents (simple positional mapping: index position -> documents list)
|
| 105 |
+
self.documents.extend(chunks)
|
| 106 |
+
|
| 107 |
+
if save:
|
| 108 |
+
self.save_index()
|
| 109 |
+
|
| 110 |
+
def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
| 111 |
+
"""
|
| 112 |
+
Search similar documents for `query`. Returns up to k results.
|
| 113 |
+
Each result: { "content": <text>, "metadata": <metadata>, "similarity_score": <float> }
|
| 114 |
+
similarity_score is the inner product of normalized vectors => cosine similarity in [-1,1].
|
| 115 |
+
"""
|
| 116 |
+
with self._lock:
|
| 117 |
+
# guard: no vectors at all
|
| 118 |
+
if getattr(self.index, "ntotal", 0) == 0:
|
| 119 |
+
logger.debug("Search called but index is empty.")
|
| 120 |
+
return []
|
| 121 |
+
|
| 122 |
+
# embed query
|
| 123 |
+
q_emb = self.embeddings.embed_query(query)
|
| 124 |
+
q_np = np.asarray([q_emb], dtype=np.float32)
|
| 125 |
+
if q_np.ndim == 1:
|
| 126 |
+
q_np = q_np.reshape(1, -1)
|
| 127 |
+
|
| 128 |
+
if q_np.shape[1] != self.index.d:
|
| 129 |
+
# if index is empty we could recreate; but at this point we know index has vectors.
|
| 130 |
+
raise ValueError(f"Query embedding dim {q_np.shape[1]} does not match index dimension {self.index.d}")
|
| 131 |
+
|
| 132 |
+
faiss.normalize_L2(q_np)
|
| 133 |
+
|
| 134 |
+
# clamp k
|
| 135 |
+
k = min(k, int(self.index.ntotal))
|
| 136 |
+
|
| 137 |
+
distances, indices = self.index.search(q_np, k) # distances shape (1,k) ; indices shape (1,k)
|
| 138 |
+
|
| 139 |
+
results = []
|
| 140 |
+
for score, idx in zip(distances[0], indices[0]):
|
| 141 |
+
if idx < 0:
|
| 142 |
+
# FAISS returns -1 for "empty" slots sometimes; skip
|
| 143 |
+
continue
|
| 144 |
+
if idx >= len(self.documents):
|
| 145 |
+
logger.warning("Index returned idx %d but documents list has length %d", idx, len(self.documents))
|
| 146 |
+
continue
|
| 147 |
+
doc = self.documents[idx]
|
| 148 |
+
results.append({
|
| 149 |
+
"content": doc.get("text"),
|
| 150 |
+
"metadata": doc.get("metadata", {}),
|
| 151 |
+
"similarity_score": float(score) # already cosine because of normalization
|
| 152 |
+
})
|
| 153 |
+
return results
|
| 154 |
+
|
| 155 |
+
def save_index(self):
|
| 156 |
+
"""Persist index and documents to disk."""
|
| 157 |
+
self.index_path.mkdir(parents=True, exist_ok=True)
|
| 158 |
+
faiss.write_index(self.index, str(self.index_path / "index.faiss"))
|
| 159 |
+
with open(self.index_path / "documents.pkl", "wb") as f:
|
| 160 |
+
pickle.dump(self.documents, f)
|
| 161 |
+
logger.debug("FAISS index and documents saved to %s", self.index_path)
|
| 162 |
+
|
| 163 |
+
def load_index(self) -> bool:
|
| 164 |
+
"""Load index and documents from disk. Returns True if loaded."""
|
| 165 |
+
index_file = self.index_path / "index.faiss"
|
| 166 |
+
docs_file = self.index_path / "documents.pkl"
|
| 167 |
+
|
| 168 |
+
if index_file.exists() and docs_file.exists():
|
| 169 |
+
self.index = faiss.read_index(str(index_file))
|
| 170 |
+
with open(docs_file, "rb") as f:
|
| 171 |
+
self.documents = pickle.load(f)
|
| 172 |
+
|
| 173 |
+
# update dimension to match loaded index
|
| 174 |
+
if getattr(self.index, "d", None) is not None:
|
| 175 |
+
self.dimension = int(self.index.d)
|
| 176 |
+
|
| 177 |
+
if self.index.d == 0 or len(self.documents) != self.index.ntotal:
|
| 178 |
+
logger.error("Corrupted index detected, deleting...")
|
| 179 |
+
index_file.unlink()
|
| 180 |
+
docs_file.unlink()
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
# warn if counts differ
|
| 184 |
+
if len(self.documents) != self.index.ntotal:
|
| 185 |
+
logger.warning(
|
| 186 |
+
"Loaded documents list length (%d) differs from index.ntotal (%d). "
|
| 187 |
+
"This can lead to mismatches. Using what's available.",
|
| 188 |
+
len(self.documents),
|
| 189 |
+
self.index.ntotal,
|
| 190 |
+
)
|
| 191 |
+
logger.info("Loaded FAISS index from %s (ntotal=%d, dim=%d)",
|
| 192 |
+
index_file, int(self.index.ntotal), int(self.index.d))
|
| 193 |
+
return True
|
| 194 |
+
return False
|
| 195 |
+
|