document-qa-dev / streamlit_app.py
lfoppiano's picture
Upload folder using huggingface_hub
6f06d5d verified
Raw
History Blame Contribute Delete
19.2 kB
"""Streamlit frontend for the Document Q/A system.
This module implements the web UI for uploading scientific PDFs,
asking questions via an LLM-powered RAG pipeline, and viewing
highlighted PDF passages. It is the main entry-point when running::
streamlit run streamlit_app.py
Configuration is loaded from environment variables (see ``.env.example``).
"""
import os
import re
from hashlib import blake2b
from tempfile import NamedTemporaryFile
import dotenv
import streamlit as st
from grobid_quantities.quantities import QuantitiesAPI
from langchain.memory import ConversationBufferMemory
from langchain_openai import ChatOpenAI
from streamlit_pdf_viewer import pdf_viewer
from document_qa.custom_embeddings import ModalEmbeddings
from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations, GrobidServiceError
from document_qa.ner_client_generic import NERClientGeneric
# Pull settings from a local ``.env`` file before any endpoint is resolved.
# ``override=True`` is passed so that values from the file win over variables
# already present in the process environment.
dotenv.load_dotenv(override=True)

# LLM name -> endpoint URL. The URL variables are mandatory: a missing one
# raises ``KeyError`` as soon as the script starts.
API_MODELS = {
    "microsoft/Phi-4-mini-instruct": os.environ["PHI_URL"],
    "Qwen/Qwen3-0.6B": os.environ["QWEN_URL"],
}

# Embedding model name -> endpoint URL (same mandatory-variable rule).
API_EMBEDDINGS = {
    "intfloat/multilingual-e5-large-instruct-modal": os.environ["EMBEDS_URL"],
}
# Defaults for the session-state keys initialised by this script. The
# dictionary is rebuilt on every rerun, so the mutable defaults (``{}``/``[]``)
# are fresh objects and are never shared between sessions.
_SESSION_DEFAULTS = {
    "rqa": {},  # model name -> DocumentQAEngine
    "model": None,
    "api_keys": {},  # model name -> API key in use
    "doc_id": None,
    "loaded_embeddings": None,
    "hash": None,
    "messages": [],
    "ner_processing": False,
    "uploaded": False,
    "memory": None,
    "binary": None,
    "annotations": None,
    "should_show_annotations": True,
    "pdf": None,
    "embeddings": None,
    "scroll_to_first_annotation": False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    # Only fill in what is missing: existing values must survive reruns.
    if _key not in st.session_state:
        st.session_state[_key] = _default

# The revision is resolved once per session from ``revision.txt`` (if present).
if "git_rev" not in st.session_state:
    _git_rev = "unknown"
    if os.path.exists("revision.txt"):
        with open("revision.txt", "r") as fr:
            # Strip surrounding whitespace: a trailing newline would otherwise
            # end up inside the markdown commit link built from this value.
            _git_rev = fr.read().strip() or "unknown"
    st.session_state["git_rev"] = _git_rev
# Page-level configuration: title, icon, wide layout with the sidebar opened,
# and the entries shown in the application menu.
st.set_page_config(
    page_title="Scientific Document Insights Q/A",
    page_icon="📝",
    initial_sidebar_state="expanded",
    layout="wide",
    menu_items={
        "Get Help": "https://github.com/lfoppiano/document-qa",
        "Report a bug": "https://github.com/lfoppiano/document-qa/issues",
        "About": "Upload a scientific article in PDF, ask questions, get insights.",
    },
)
# Inject a small stylesheet that sets the padding of ``.block-container``.
# ``unsafe_allow_html`` is required for the <style> tag to be emitted as HTML.
st.markdown(
    """
<style>
.block-container {
padding-top: 3rem;
padding-bottom: 1rem;
padding-left: 1rem;
padding-right: 1rem;
}
</style>
""",
    unsafe_allow_html=True,
)
def new_file():
    """Prepare the session for a freshly uploaded document.

    Registered as the ``on_change`` callback of the file uploader. It drops
    the embeddings flag and document id, marks the session as having an
    upload, empties the annotation list and, when a conversation memory
    exists, clears it.
    """
    resets = (
        ("loaded_embeddings", None),
        ("doc_id", None),
        ("uploaded", True),
        ("annotations", []),
    )
    for key, value in resets:
        st.session_state[key] = value

    # The memory is only created once the Q/A engine has been initialised.
    memory = st.session_state["memory"]
    if memory:
        memory.clear()
def clear_memory():
    """Clear the conversation buffer memory (chat history).

    Does nothing when no memory has been created yet: the session key
    starts out as ``None`` and is only populated by ``init_qa``.
    """
    memory = st.session_state["memory"]
    if memory is not None:
        memory.clear()
# NOTE(review): resource caching is left disabled for this function; presumably
# because it stores a new memory object in the per-session state on every
# call -- confirm before re-enabling.
# @st.cache_resource
def init_qa(model_name, embeddings_name):
    """Build the Q/A engine for the chosen LLM and embedding model.

    A new conversation memory is created and stored in
    ``st.session_state['memory']``; the same object is handed to the engine.

    Args:
        model_name: Key of ``API_MODELS`` identifying the LLM endpoint.
        embeddings_name: Key of ``API_EMBEDDINGS`` identifying the
            embedding endpoint.

    Returns:
        DocumentQAEngine: Engine wired to the LLM, the embedding storage,
        the Grobid server from ``GROBID_URL`` and the new memory.
    """
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    st.session_state["memory"] = memory

    llm = ChatOpenAI(
        model=model_name,
        temperature=0.0,
        base_url=API_MODELS[model_name],
        api_key=os.environ.get("API_KEY"),
    )
    embedding_function = ModalEmbeddings(
        url=API_EMBEDDINGS[embeddings_name],
        model_name=embeddings_name,
        api_key=os.environ.get("EMBEDS_API_KEY"),
    )

    return DocumentQAEngine(
        llm,
        DataStorage(embedding_function),
        grobid_url=os.environ["GROBID_URL"],
        memory=memory,
        ping_grobid_server=False,
    )
@st.cache_resource
def init_ner():
    """Create the processor that aggregates quantities and materials NER.

    The quantities client targets ``GROBID_QUANTITIES_URL`` and the materials
    client is configured with ``GROBID_MATERIALS_URL``. The function is
    wrapped in ``st.cache_resource``, so the processor is built once and
    reused on later reruns.

    Returns:
        GrobidAggregationProcessor: Processor combining both clients.
    """
    quantities = QuantitiesAPI(os.environ["GROBID_QUANTITIES_URL"], check_server=True)

    materials = NERClientGeneric(ping=True)
    materials.set_config(
        {
            "grobid": {
                "server": os.environ["GROBID_MATERIALS_URL"],
                "sleep_time": 5,
                "timeout": 60,
                "url_mapping": {
                    # The plain "/service/process/text" endpoint (without the
                    # disableLinking flag) is the alternative mapping.
                    "processText_disable_linking": "/service/process/text?disableLinking=True",
                },
            }
        }
    )

    return GrobidAggregationProcessor(
        grobid_quantities_client=quantities,
        grobid_superconductors_client=materials,
    )


# Module-level processor used to post-process LLM answers.
gqa = init_ner()
def get_file_hash(fname):
    """Return the BLAKE2b hex digest of the file stored at *fname*.

    The file is read in 4096-byte blocks, so it is never held in memory
    as a whole.
    """
    digest = blake2b()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()
def play_old_messages(container):
    """Re-render previous chat messages into *container*.

    Called on Streamlit reruns to restore the visible conversation
    history from ``st.session_state['messages']``.

    Each assistant message is rendered according to the query mode stored
    with it: answers produced in ``"llm"`` mode may carry HTML decorations
    and are rendered as markdown with HTML enabled, every other answer is
    passed to ``write``. Messages with any other role are skipped.

    Args:
        container: Streamlit container receiving the chat messages.
    """
    for message in st.session_state["messages"] or []:
        role = message["role"]
        if role == "user":
            container.chat_message("user").markdown(message["content"])
        elif role == "assistant":
            # Use the mode recorded with the message (values are lower-case,
            # e.g. "llm"), not whatever mode is currently selected in the UI.
            if message.get("mode") == "llm":
                container.chat_message("assistant").markdown(message["content"], unsafe_allow_html=True)
            else:
                container.chat_message("assistant").write(message["content"])
# is_api_key_provided = st.session_state['api_key']
def _default_index(options, env_var):
    """Return the position in *options* of the name stored in *env_var*.

    Falls back to ``0`` (the first option) when the variable is unset,
    empty, or names an option that does not exist, so a stale setting
    cannot crash the page.
    """
    names = list(options)
    wanted = os.environ.get(env_var)
    return names.index(wanted) if wanted in names else 0


with st.sidebar:
    st.title("📝 Document Q/A")
    st.markdown("Upload a scientific article in PDF, ask questions, get insights.")
    st.markdown(
        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: "
    )
    st.markdown("LM and Embeddings are powered by [Modal.com](https://modal.com/)")
    st.divider()

    # Model and embeddings cannot be changed once a document is loaded or
    # an upload has been registered.
    selection_locked = st.session_state["doc_id"] is not None or st.session_state["uploaded"]

    st.session_state["model"] = model = st.selectbox(
        "Model:",
        options=API_MODELS.keys(),
        index=_default_index(API_MODELS, "DEFAULT_MODEL"),
        placeholder="Select model",
        help="Select the LLM model:",
        disabled=selection_locked,
    )
    st.session_state["embeddings"] = embedding_name = st.selectbox(
        "Embeddings:",
        options=API_EMBEDDINGS.keys(),
        index=_default_index(API_EMBEDDINGS, "DEFAULT_EMBEDDING"),
        placeholder="Select embedding",
        help="Select the Embedding function:",
        disabled=selection_locked,
    )

    api_key = os.environ["API_KEY"]
    # Build the engine once per model and remember the key it was built with.
    if model not in st.session_state["rqa"] or model not in st.session_state["api_keys"]:
        with st.spinner("Preparing environment"):
            st.session_state["rqa"][model] = init_qa(model, st.session_state["embeddings"])
            st.session_state["api_keys"][model] = api_key
# Two bordered panels: the left one is 5/9 of the width, the right one 4/9.
columns = st.columns([5, 4])
right_column = columns[1].container(border=True)
left_column = columns[0].container(border=True)

with right_column:
    # The uploader is locked while a model is selected for which no API key
    # has been registered.
    selected_model = st.session_state["model"]
    uploader_locked = selected_model is not None and selected_model not in st.session_state["api_keys"]

    uploaded_file = st.file_uploader(
        "Upload a scientific article",
        type="pdf",
        on_change=new_file,
        disabled=uploader_locked,
        help="The full-text is extracted using Grobid.",
    )

    # Slot used later for progress spinners, followed by the chat history
    # area and the question box (enabled only once a file is present).
    placeholder = st.empty()
    messages = st.container(height=300)
    question = st.chat_input("Ask something about the article", disabled=not uploaded_file)
# Internal query-mode identifiers mapped to the labels shown in the UI.
query_modes = {"llm": "LLM Q/A", "embeddings": "Embeddings", "question_coefficient": "Question coefficient"}

with st.sidebar:
    st.header("Settings")

    no_document = not uploaded_file

    mode = st.radio(
        "Query mode",
        tuple(query_modes),
        disabled=no_document,
        index=0,
        horizontal=True,
        format_func=lambda key: query_modes[key],
        help="LLM will respond the question, Embedding will show the "
        "relevant paragraphs to the question in the paper. "
        "Question coefficient attempt to estimate how effective the question will be answered.",
    )

    st.session_state["scroll_to_first_annotation"] = st.checkbox(
        "Scroll to context",
        help="The PDF viewer will automatically scroll to the first relevant passage in the document.",
    )
    st.session_state["ner_processing"] = st.checkbox(
        "Identify materials and properties.",
        help="The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.",
    )

    # The chunk size can only be chosen before a document is uploaded.
    chunk_size = st.slider(
        "Text chunks size",
        -1,
        2000,
        value=-1,
        help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.",
        disabled=uploaded_file is not None,
    )

    # The context slider counts paragraphs when chunking is off (-1) and
    # chunks otherwise; range and default differ between the two cases.
    if chunk_size == -1:
        ctx_label, ctx_max, ctx_default, ctx_help = (
            "Context size (paragraphs)",
            20,
            10,
            "Number of paragraphs to consider when answering a question",
        )
    else:
        ctx_label, ctx_max, ctx_default, ctx_help = (
            "Context size (chunks)",
            10,
            4,
            "Number of chunks to consider when answering a question",
        )
    context_size = st.slider(
        ctx_label,
        3,
        ctx_max,
        value=ctx_default,
        help=ctx_help,
        disabled=no_document,
    )

    st.divider()

    st.header("Documentation")
    st.markdown("https://github.com/lfoppiano/document-qa")
    st.markdown(
        """Upload a scientific article as PDF document. Once the spinner stops, you can proceed to ask your questions."""
    )

    # Link the running revision to its commit page when it is known.
    revision = st.session_state["git_rev"]
    if revision != "unknown":
        st.markdown(
            f"**Revision number**: [{revision}](https://github.com/lfoppiano/document-qa/commit/{revision})"
        )
# Process a newly uploaded document: store it in a temporary PDF file, let the
# engine build the in-memory embeddings from it, and keep the raw bytes for
# the PDF viewer. Runs only while no embeddings are loaded for the upload.
if uploaded_file and not st.session_state.loaded_embeddings:
    if model not in st.session_state["api_keys"]:
        st.error("Before uploading a document, you must enter the API key. ")
        st.stop()

    with left_column:
        try:
            with st.spinner("Reading file, calling Grobid, and creating in-memory embeddings..."):
                binary = uploaded_file.getvalue()
                tmp_path = None
                try:
                    with NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
                        # Remember the path before writing, so the file is
                        # removed in ``finally`` even when the write fails.
                        tmp_path = tmp_file.name
                        tmp_file.write(binary)
                    # The file is closed (and therefore flushed) at this point.
                    st.session_state["binary"] = binary
                    st.session_state["doc_id"] = st.session_state["rqa"][model].create_memory_embeddings(
                        tmp_path, chunk_size=chunk_size, perc_overlap=0.1
                    )
                finally:
                    # ``delete=False`` means the file must be removed by hand.
                    if tmp_path and os.path.exists(tmp_path):
                        os.unlink(tmp_path)

                st.session_state["loaded_embeddings"] = True
                st.session_state.messages = []
        except GrobidServiceError as exc:
            # Roll the session back so that the document can be submitted again.
            st.session_state["doc_id"] = None
            st.session_state["loaded_embeddings"] = False
            st.session_state["uploaded"] = False

            message = str(exc).strip() or "Grobid is not responding."
            if not message.endswith((".", "!", "?")):
                message += "."
            st.error(f"{message} Please try again later.")
            st.stop()
def rgb_to_hex(rgb):
    """Convert an ``(R, G, B)`` tuple to a ``#rrggbb`` hex string."""
    return "#{:02x}{:02x}{:02x}".format(*rgb)


def generate_color_gradient(num_elements):
    """Generate a warm-to-cold hex colour gradient for annotation ranking.

    The first colour (most relevant passage) is orange; the last (least
    relevant) is blue. Intermediate colours are linearly interpolated,
    truncating each channel to an integer.

    Args:
        num_elements: Number of gradient stops to produce.

    Returns:
        list[str]: Hex colour strings, e.g. ``['#ffa500', …, '#0000ff']``.
        A single stop yields only the warm colour; zero (or a negative
        number of) stops yields an empty list.
    """
    warm_color = (255, 165, 0)  # Orange
    cold_color = (0, 0, 255)  # Blue

    # Interpolate over ``num_elements - 1`` intervals so that the final stop
    # lands exactly on the cold colour; the ``max`` avoids a division by zero
    # when a single stop is requested.
    intervals = max(num_elements - 1, 1)

    gradient = []
    for i in range(num_elements):
        ratio = i / intervals
        blended = tuple(
            int(warm * (1 - ratio) + cold * ratio) for warm, cold in zip(warm_color, cold_color)
        )
        gradient.append(rgb_to_hex(blended))
    return gradient
# Answer a new question (or, when there is none, replay the stored chat).
with right_column:
    if st.session_state.loaded_embeddings and question and st.session_state.doc_id:
        st.session_state.messages.append({"role": "user", "mode": mode, "content": question})

        # Redraw the whole conversation, including the question just added.
        # "embeddings" messages go through ``write``; the other known modes
        # are rendered as markdown with HTML enabled.
        for message in st.session_state.messages:
            if message["mode"] in ("llm", "question_coefficient"):
                messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
            elif message["mode"] == "embeddings":
                messages.chat_message(message["role"]).write(message["content"])

        if model not in st.session_state["rqa"]:
            st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
            st.stop()

        engine = st.session_state["rqa"][model]
        text_response = None
        if mode == "embeddings":
            with placeholder:
                with st.spinner("Fetching the relevant context..."):
                    text_response, coordinates = engine.query_storage(
                        question, st.session_state.doc_id, context_size=context_size
                    )
        elif mode == "llm":
            with placeholder:
                with st.spinner("Generating LLM response..."):
                    _, text_response, coordinates = engine.query_document(
                        question, st.session_state.doc_id, context_size=context_size
                    )
        elif mode == "question_coefficient":
            with st.spinner("Estimate question/context relevancy..."):
                text_response, coordinates = engine.analyse_query(
                    question, st.session_state.doc_id, context_size=context_size
                )

        # One list of boxes per retrieved passage; each box is described by a
        # comma-separated string.
        annotations = [
            [GrobidAggregationProcessor.box_to_dict(c.split(",")) for c in coord_doc]
            for coord_doc in coordinates
        ]
        # Colour the passages by rank and give the first one a dotted border.
        gradients = generate_color_gradient(len(annotations))
        for rank, (color, annotation_doc) in enumerate(zip(gradients, annotations)):
            for annotation in annotation_doc:
                annotation["color"] = color
                if rank == 0:
                    annotation["border"] = "dotted"
        st.session_state["annotations"] = [
            annotation for annotation_doc in annotations for annotation in annotation_doc
        ]

        if not text_response:
            # Nothing to show: report it and do not post-process or store an
            # empty answer (NER on a missing response would fail).
            st.error("Something went wrong. Contact info AT sciencialab.com to report the issue through GitHub.")
        else:
            if mode == "llm":
                if st.session_state["ner_processing"]:
                    with st.spinner("Processing NER on LLM response..."):
                        entities = gqa.process_single_text(text_response)
                        decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
                        # Replace the CSS classes with inline colours:
                        # materials in green, every other label in orange.
                        decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
                        decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
                        text_response = decorated_text
                messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True)
            else:
                messages.chat_message("assistant").write(text_response)
            st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})

    elif st.session_state.loaded_embeddings and st.session_state.doc_id:
        play_old_messages(messages)
# Show the uploaded PDF, with the current annotations, in the left panel.
with left_column:
    pdf_bytes = st.session_state["binary"]
    if pdf_bytes:
        pdf_annotations = st.session_state["annotations"] or []
        # Scroll to the first annotation only when there is one and the
        # "Scroll to context" option is enabled.
        if pdf_annotations and st.session_state["scroll_to_first_annotation"]:
            scroll_target = 1
        else:
            scroll_target = None

        with st.container(height=600):
            pdf_viewer(
                input=pdf_bytes,
                annotation_outline_size=2,
                annotations=pdf_annotations,
                render_text=True,
                scroll_to_annotation=scroll_target,
            )