mechark committed on
Commit
c4d8214
·
2 Parent(s): 1830858 420cb36

feat: add citations display

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. .gitignore +17 -0
  3. src/core/config.py +2 -1
  4. src/gradio_app.py +43 -10
  5. src/prompts.py +10 -2
  6. src/rag/pipeline.py +21 -5
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ vector_store/index.faiss filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Env
10
+ .venv
11
+ .env
12
+
13
+ # UV lock
14
+ uv.lock
15
+
16
+ # Store
17
+ /vector_store
src/core/config.py CHANGED
@@ -21,8 +21,9 @@ class Settings(BaseSettings):
21
  def __init__(self, **kwargs):
22
  super().__init__(**kwargs)
23
  # Try to get token from environment if not set
 
24
  if not self.HUGGINGFACE_TOKEN:
25
- self.HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "")
26
 
27
 
28
  settings = Settings()
 
21
  def __init__(self, **kwargs):
22
  super().__init__(**kwargs)
23
  # Try to get token from environment if not set
24
+ # HuggingFace Spaces uses HF_TOKEN by default
25
  if not self.HUGGINGFACE_TOKEN:
26
+ self.HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGINGFACE_TOKEN", ""))
27
 
28
 
29
  settings = Settings()
src/gradio_app.py CHANGED
@@ -2,19 +2,52 @@ import gradio as gr
2
  from src.rag.pipeline import answer_question
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def run_gradio():
6
- with gr.Blocks(title="PaperMate") as demo:
7
- gr.Markdown("# 📄 PaperMate — Ask about research papers")
8
- question = gr.Textbox(
9
- label="Enter your question",
10
- placeholder="e.g. What is the NEMO paper about?",
11
- lines=15,
12
- max_lines=30,
13
  )
14
- output = gr.Textbox(label="Answer", lines=15, max_lines=30)
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- btn = gr.Button("Ask")
17
- btn.click(fn=answer_question, inputs=question, outputs=output)
 
 
 
 
 
 
18
 
19
  demo.launch()
20
 
 
2
  from src.rag.pipeline import answer_question
3
 
4
 
5
+ def format_answer(question: str) -> tuple[str, str]:
6
+ """Format answer with citations."""
7
+ answer, citations = answer_question(question)
8
+
9
+ # Format citations
10
+ if citations:
11
+ citations_text = "\n\n### 📚 Sources:\n"
12
+ for i, citation in enumerate(citations, 1):
13
+ citations_text += f"{i}. **{citation['title']}** ({citation['year']})\n"
14
+ else:
15
+ citations_text = ""
16
+
17
+ return answer, citations_text
18
+
19
+
20
  def run_gradio():
21
+ with gr.Blocks(title="PaperMate", theme=gr.themes.Soft()) as demo:
22
+ gr.Markdown(
23
+ """
24
+ # 📄 PaperMate — Research Paper Q&A Assistant
25
+ Ask questions about research papers and get answers backed by scientific literature.
26
+ """
 
27
  )
28
+
29
+ with gr.Row():
30
+ with gr.Column():
31
+ question = gr.Textbox(
32
+ label="Your Question",
33
+ placeholder="e.g., What techniques are used to handle out-of-vocabulary words in NLP?",
34
+ lines=3,
35
+ )
36
+ btn = gr.Button("🔍 Search & Answer", variant="primary")
37
+
38
+ with gr.Row():
39
+ with gr.Column():
40
+ output = gr.Textbox(label="Answer", lines=10, max_lines=20)
41
+ citations = gr.Markdown(label="Sources")
42
 
43
+ btn.click(fn=format_answer, inputs=question, outputs=[output, citations])
44
+
45
+ gr.Markdown(
46
+ """
47
+ ---
48
+ 💡 **Tip:** Questions are answered using relevant papers from the ArXiv dataset.
49
+ """
50
+ )
51
 
52
  demo.launch()
53
 
src/prompts.py CHANGED
@@ -1,10 +1,18 @@
1
  SYSTEM_PROMPT = """
2
- You are a helpful assistant that provides accurate and concise information about scientific papers based on the given context.
3
- Do not provide any information that is not included in the context. Do not mention context details in your answer.
 
 
 
 
 
 
4
 
5
  Context:
6
  {context}
7
 
8
  User Question:
9
  {question}
 
 
10
  """
 
1
  SYSTEM_PROMPT = """
2
+ You are a knowledgeable research assistant that provides accurate, well-structured answers about scientific papers.
3
+
4
+ Guidelines:
5
+ - Base your answer ONLY on the information provided in the context below
6
+ - Provide clear, concise, and informative responses
7
+ - When relevant, mention specific findings or methodologies from the papers
8
+ - Do NOT make up information or reference papers not in the context
9
+ - Do NOT mention "the context" or "the provided papers" in your response - answer naturally
10
 
11
  Context:
12
  {context}
13
 
14
  User Question:
15
  {question}
16
+
17
+ Provide a comprehensive answer based on the information above:
18
  """
src/rag/pipeline.py CHANGED
@@ -21,8 +21,23 @@ def create_context(docs) -> str:
21
  return context
22
 
23
 
24
- def answer_question(question: str) -> str:
25
- """Answer a question using retrieved and reranked documents."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Retrieve more documents than needed for reranking
27
  try:
28
  retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
@@ -34,12 +49,13 @@ def answer_question(question: str) -> str:
34
  )
35
 
36
  context = create_context(reranked_results)
 
37
  logging.info(f"Constructed context for LLM: {context}")
38
  chain = get_chain()
39
 
40
  response = chain.invoke({"context": context, "question": question})
 
41
  except Exception as e:
42
  logging.error(f"Error occurred while answering question: {e}")
43
- response = "Sorry, exception occurred while processing your request. See logs for details."
44
-
45
- return response
 
21
  return context
22
 
23
 
24
+ def extract_citations(docs) -> list[dict]:
25
+ """Extract paper titles and years for citations."""
26
+ citations = []
27
+ for doc in docs:
28
+ citations.append({
29
+ "title": doc.metadata.get('Titles', 'No Title'),
30
+ "year": doc.metadata.get('Years', 'Unknown')
31
+ })
32
+ return citations
33
+
34
+
35
+ def answer_question(question: str) -> tuple[str, list[dict]]:
36
+ """Answer a question using retrieved and reranked documents.
37
+
38
+ Returns:
39
+ tuple: (answer, citations) where citations is a list of dicts with 'title' and 'year'
40
+ """
41
  # Retrieve more documents than needed for reranking
42
  try:
43
  retriever.k = settings.RETRIEVER_K_BEFORE_RERANK
 
49
  )
50
 
51
  context = create_context(reranked_results)
52
+ citations = extract_citations(reranked_results)
53
  logging.info(f"Constructed context for LLM: {context}")
54
  chain = get_chain()
55
 
56
  response = chain.invoke({"context": context, "question": question})
57
+ return response, citations
58
  except Exception as e:
59
  logging.error(f"Error occurred while answering question: {e}")
60
+ error_msg = f"Sorry, an error occurred while processing your request: {str(e)}"
61
+ return error_msg, []