SudeshKK5 commited on
Commit
416a463
·
verified ·
1 Parent(s): 4ce961c

Create app.py

Browse files

Created chatbot with RAG model

Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import PyPDF2
3
+ import faiss
4
+ import torch
5
+ import streamlit as st
6
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
7
+ from sentence_transformers import SentenceTransformer
8
+
9
# Sentence-embedding model: vectorizes paper sections and user queries for
# FAISS similarity search.
embedding_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")

# Seq2seq LLM used to generate the final answer from the retrieved context.
llm_model_name = "google/flan-t5-small"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
# NOTE(review): float16 halves the model's memory footprint, but T5-family
# models are known to be numerically unstable in fp16 — and fp16 inference on
# CPU is poorly supported. Confirm this runs on a GPU, or drop torch_dtype to
# load in full precision.
llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name, torch_dtype=torch.float16)
17
+
18
+
19
+ # Function to extract text from PDF
20
def extract_text_from_pdf(pdf_path):
    """Extract the text of every page of a PDF as one string.

    Args:
        pdf_path: Filesystem path to the PDF file.

    Returns:
        The concatenated page texts, one newline appended after each page.
    """
    with open(pdf_path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        # extract_text() may return None for image-only/scanned pages; the
        # original `text += page.extract_text() + "\n"` would raise TypeError
        # there. Coerce to "" and join (avoids quadratic string concatenation).
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
28
+
29
+ # Function to chunk text into sections
30
def chunk_text(text, chunk_size=300):
    """Group a paper's lines into named sections by detecting headers.

    A line that starts (case-insensitively) with "abstract", "introduction",
    or "conclusion" opens the corresponding section; every line, including the
    header line itself, is collected into the currently open section. Lines
    before the first recognized header land in "Other". A repeated header
    resets that section's contents.

    Args:
        text: Full paper text, newline-separated.
        chunk_size: Currently unused; kept for interface compatibility.

    Returns:
        Mapping of section title -> single space-joined text string.
    """
    known_headers = ("abstract", "introduction", "conclusion")
    buckets = {"Other": []}
    active = "Other"

    for raw_line in text.split("\n"):
        stripped = raw_line.strip()
        lowered = stripped.lower()
        for header in known_headers:
            if lowered.startswith(header):
                active = header.capitalize()
                buckets[active] = []
                break
        buckets[active].append(stripped)

    # Collapse each section's collected lines into one string.
    return {title: " ".join(lines) for title, lines in buckets.items()}
57
+
58
+ # Function to create FAISS vector database
59
def build_vector_database(sections):
    """Build a FAISS L2 index over the section texts of a paper.

    Args:
        sections: Mapping of section title -> section text.

    Returns:
        Tuple of (faiss index, embedding matrix, section titles, section texts),
        with titles/texts in the dict's iteration order.
    """
    titles = list(sections.keys())
    texts = list(sections.values())

    vectors = embedding_model.encode(texts)
    # Flat (exact) L2 index sized to the embedding dimensionality.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    return index, vectors, titles, texts
67
+
68
+ # Function to retrieve relevant context
69
def retrieve_context(query, index, embeddings, section_titles, section_texts, top_k=1):
    """Retrieve the top_k most relevant sections for a query.

    Args:
        query: The user's question.
        index: FAISS index built over the section embeddings.
        embeddings: Section embedding matrix (unused here; kept for interface
            compatibility with existing callers).
        section_titles: Section names, aligned with the index rows.
        section_texts: Section contents, aligned with the index rows.
        top_k: Number of sections to retrieve.

    Returns:
        Retrieved sections joined by newlines, each as "**title**: text".
    """
    query_embedding = embedding_model.encode([query])
    # Removed dead statement `embeddings = torch.tensor(embeddings)`: it
    # rebound a local that was never read afterward.
    _distances, indices = index.search(query_embedding, top_k)
    retrieved = [f"**{section_titles[idx]}**: {section_texts[idx]}" for idx in indices[0]]
    return "\n".join(retrieved)
76
+
77
+
78
+
79
+ # Function to generate a concise answer
80
def generate_answer_rag(question, context, max_length=512):
    """Generate a concise answer to *question* grounded in *context*.

    Args:
        question: The user's question.
        context: Retrieved context text to ground the answer in.
        max_length: Character cap applied to the context before prompting.

    Returns:
        The decoded model answer (special tokens stripped).
    """
    # NOTE(review): max_length truncates the context by *characters* here,
    # while the tokenizer call below uses 512 as a *token* limit — two
    # different units sharing one number.
    prompt = f"Question: {question}\nContext: {context[:max_length]}"
    encoded = llm_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    generated = llm_model.generate(encoded.input_ids, max_length=150)
    return llm_tokenizer.decode(generated[0], skip_special_tokens=True)
86
+
87
+
88
+
89
+
90
+ # Streamlit UI
91
def main():
    """Streamlit UI: upload a paper, index it, and answer questions about it."""
    st.title("AI Research Paper RAG Chatbot")

    uploaded_pdf = st.file_uploader("Upload a Research Paper (PDF)", type=["pdf"])
    if uploaded_pdf is None:
        return

    # Persist the upload so the PDF reader can open it from disk.
    pdf_path = "uploaded_paper.pdf"
    with open(pdf_path, "wb") as out_file:
        out_file.write(uploaded_pdf.read())
    st.success("PDF uploaded successfully!")

    # Extract the text and split it into titled sections.
    paper_text = extract_text_from_pdf(pdf_path)
    text_sections = chunk_text(paper_text)

    # Index the sections for similarity search.
    index, embeddings, section_titles, section_texts = build_vector_database(text_sections)
    st.write(f"Paper processed into {len(text_sections)} sections for efficient retrieval.")

    # Retrieval-augmented question answering loop.
    user_question = st.text_input("Ask a question about the paper:")
    if user_question:
        context = retrieve_context(user_question, index, embeddings, section_titles, section_texts)
        answer = generate_answer_rag(user_question, context)

        st.write(f"**Retrieved Context:**\n{context}")
        st.write(f"**Generated Answer:**\n{answer}")

if __name__ == "__main__":
    main()