feat: add citations showup
Browse files- .gitattributes +1 -0
- .gitignore +17 -0
- src/core/config.py +2 -1
- src/gradio_app.py +43 -10
- src/prompts.py +10 -2
- src/rag/pipeline.py +21 -5
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
vector_store/index.faiss filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Env
|
| 10 |
+
.venv
|
| 11 |
+
.env
|
| 12 |
+
|
| 13 |
+
# UV lock
|
| 14 |
+
uv.lock
|
| 15 |
+
|
| 16 |
+
# Store
|
| 17 |
+
/vector_store
|
src/core/config.py
CHANGED
|
@@ -21,8 +21,9 @@ class Settings(BaseSettings):
|
|
| 21 |
def __init__(self, **kwargs):
|
| 22 |
super().__init__(**kwargs)
|
| 23 |
# Try to get token from environment if not set
|
|
|
|
| 24 |
if not self.HUGGINGFACE_TOKEN:
|
| 25 |
-
self.HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "")
|
| 26 |
|
| 27 |
|
| 28 |
settings = Settings()
|
|
|
|
| 21 |
def __init__(self, **kwargs):
|
| 22 |
super().__init__(**kwargs)
|
| 23 |
# Try to get token from environment if not set
|
| 24 |
+
# HuggingFace Spaces uses HF_TOKEN by default
|
| 25 |
if not self.HUGGINGFACE_TOKEN:
|
| 26 |
+
self.HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGINGFACE_TOKEN", ""))
|
| 27 |
|
| 28 |
|
| 29 |
settings = Settings()
|
src/gradio_app.py
CHANGED
|
@@ -2,19 +2,52 @@ import gradio as gr
|
|
| 2 |
from src.rag.pipeline import answer_question
|
| 3 |
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def run_gradio():
|
| 6 |
-
with gr.Blocks(title="PaperMate") as demo:
|
| 7 |
-
gr.Markdown(
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
max_lines=30,
|
| 13 |
)
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
btn =
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
demo.launch()
|
| 20 |
|
|
|
|
| 2 |
from src.rag.pipeline import answer_question
|
| 3 |
|
| 4 |
|
| 5 |
+
def format_answer(question: str) -> tuple[str, str]:
|
| 6 |
+
"""Format answer with citations."""
|
| 7 |
+
answer, citations = answer_question(question)
|
| 8 |
+
|
| 9 |
+
# Format citations
|
| 10 |
+
if citations:
|
| 11 |
+
citations_text = "\n\n### π Sources:\n"
|
| 12 |
+
for i, citation in enumerate(citations, 1):
|
| 13 |
+
citations_text += f"{i}. **{citation['title']}** ({citation['year']})\n"
|
| 14 |
+
else:
|
| 15 |
+
citations_text = ""
|
| 16 |
+
|
| 17 |
+
return answer, citations_text
|
| 18 |
+
|
| 19 |
+
|
| 20 |
def run_gradio():
|
| 21 |
+
with gr.Blocks(title="PaperMate", theme=gr.themes.Soft()) as demo:
|
| 22 |
+
gr.Markdown(
|
| 23 |
+
"""
|
| 24 |
+
# π PaperMate β Research Paper Q&A Assistant
|
| 25 |
+
Ask questions about research papers and get answers backed by scientific literature.
|
| 26 |
+
"""
|
|
|
|
| 27 |
)
|
| 28 |
+
|
| 29 |
+
with gr.Row():
|
| 30 |
+
with gr.Column():
|
| 31 |
+
question = gr.Textbox(
|
| 32 |
+
label="Your Question",
|
| 33 |
+
placeholder="e.g., What techniques are used to handle out-of-vocabulary words in NLP?",
|
| 34 |
+
lines=3,
|
| 35 |
+
)
|
| 36 |
+
btn = gr.Button("π Search & Answer", variant="primary")
|
| 37 |
+
|
| 38 |
+
with gr.Row():
|
| 39 |
+
with gr.Column():
|
| 40 |
+
output = gr.Textbox(label="Answer", lines=10, max_lines=20)
|
| 41 |
+
citations = gr.Markdown(label="Sources")
|
| 42 |
|
| 43 |
+
btn.click(fn=format_answer, inputs=question, outputs=[output, citations])
|
| 44 |
+
|
| 45 |
+
gr.Markdown(
|
| 46 |
+
"""
|
| 47 |
+
---
|
| 48 |
+
π‘ **Tip:** Questions are answered using relevant papers from the ArXiv dataset.
|
| 49 |
+
"""
|
| 50 |
+
)
|
| 51 |
|
| 52 |
demo.launch()
|
| 53 |
|
src/prompts.py
CHANGED
|
@@ -1,10 +1,18 @@
|
|
| 1 |
SYSTEM_PROMPT = """
|
| 2 |
-
You are a
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
Context:
|
| 6 |
{context}
|
| 7 |
|
| 8 |
User Question:
|
| 9 |
{question}
|
|
|
|
|
|
|
| 10 |
"""
|
|
|
|
| 1 |
SYSTEM_PROMPT = """
|
| 2 |
+
You are a knowledgeable research assistant that provides accurate, well-structured answers about scientific papers.
|
| 3 |
+
|
| 4 |
+
Guidelines:
|
| 5 |
+
- Base your answer ONLY on the information provided in the context below
|
| 6 |
+
- Provide clear, concise, and informative responses
|
| 7 |
+
- When relevant, mention specific findings or methodologies from the papers
|
| 8 |
+
- Do NOT make up information or reference papers not in the context
|
| 9 |
+
- Do NOT mention "the context" or "the provided papers" in your response - answer naturally
|
| 10 |
|
| 11 |
Context:
|
| 12 |
{context}
|
| 13 |
|
| 14 |
User Question:
|
| 15 |
{question}
|
| 16 |
+
|
| 17 |
+
Provide a comprehensive answer based on the information above:
|
| 18 |
"""
|
src/rag/pipeline.py
CHANGED
|
@@ -21,8 +21,23 @@ def create_context(docs) -> str:
|
|
| 21 |
return context
|
| 22 |
|
| 23 |
|
| 24 |
-
def
|
| 25 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Retrieve more documents than needed for reranking
|
| 27 |
try:
|
| 28 |
retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
|
|
@@ -34,12 +49,13 @@ def answer_question(question: str) -> str:
|
|
| 34 |
)
|
| 35 |
|
| 36 |
context = create_context(reranked_results)
|
|
|
|
| 37 |
logging.info(f"Constructed context for LLM: {context}")
|
| 38 |
chain = get_chain()
|
| 39 |
|
| 40 |
response = chain.invoke({"context": context, "question": question})
|
|
|
|
| 41 |
except Exception as e:
|
| 42 |
logging.error(f"Error occurred while answering question: {e}")
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
return response
|
|
|
|
| 21 |
return context
|
| 22 |
|
| 23 |
|
| 24 |
+
def extract_citations(docs) -> list[dict]:
|
| 25 |
+
"""Extract paper titles and years for citations."""
|
| 26 |
+
citations = []
|
| 27 |
+
for doc in docs:
|
| 28 |
+
citations.append({
|
| 29 |
+
"title": doc.metadata.get('Titles', 'No Title'),
|
| 30 |
+
"year": doc.metadata.get('Years', 'Unknown')
|
| 31 |
+
})
|
| 32 |
+
return citations
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def answer_question(question: str) -> tuple[str, list[dict]]:
|
| 36 |
+
"""Answer a question using retrieved and reranked documents.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
tuple: (answer, citations) where citations is a list of dicts with 'title' and 'year'
|
| 40 |
+
"""
|
| 41 |
# Retrieve more documents than needed for reranking
|
| 42 |
try:
|
| 43 |
retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
context = create_context(reranked_results)
|
| 52 |
+
citations = extract_citations(reranked_results)
|
| 53 |
logging.info(f"Constructed context for LLM: {context}")
|
| 54 |
chain = get_chain()
|
| 55 |
|
| 56 |
response = chain.invoke({"context": context, "question": question})
|
| 57 |
+
return response, citations
|
| 58 |
except Exception as e:
|
| 59 |
logging.error(f"Error occurred while answering question: {e}")
|
| 60 |
+
error_msg = f"Sorry, an error occurred while processing your request: {str(e)}"
|
| 61 |
+
return error_msg, []
|
|
|