tbaig1605 commited on
Commit
8b442b3
·
verified ·
1 Parent(s): 6e9dc92

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdfplumber
3
+ import docx
4
+ import pandas as pd
5
+ import numpy as np
6
+ import streamlit as st
7
+ from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoTokenizer
9
+ import faiss
10
+ from groq import Groq
11
+
12
# ==========================================================
# GROQ API KEY (use HF Secrets)
# ==========================================================
# Bug fix: os.getenv() expects the environment-variable NAME. The original
# code passed a literal API key (now leaked in git history — rotate it!),
# so os.getenv() returned None and `os.environ[...] = None` raised a
# TypeError at startup. The Groq() client reads GROQ_API_KEY from the
# environment on its own, so here we only verify the secret is configured.
if not os.getenv("GROQ_API_KEY"):
    # Warn early instead of failing later inside Groq(); goes to app logs.
    print("WARNING: GROQ_API_KEY is not set — add it in the Space secrets.")
16
+
17
# ==========================================================
# STREAMLIT UI
# ==========================================================
# Page chrome: wide layout so long document answers use the full width.
# NOTE: set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Universal RAG App", layout="wide")
st.title("📄 Universal Document RAG (PDF | Word | Excel)")

# Single-file upload widget; the extensions allowed here must stay in sync
# with the extension dispatch in the app logic further down.
uploaded_file = st.file_uploader(
    "Upload a document",
    type=["pdf", "docx", "xlsx"]
)
27
+
28
+ # ==========================================================
29
+ # TEXT EXTRACTION FUNCTIONS (UNCHANGED)
30
+ # ==========================================================
31
def read_pdf_with_plumber(pdf_path):
    """Extract text from a PDF, one entry per page that yielded text.

    Returns a list of {"page": 1-based page number, "text": page text}
    dicts; pages where pdfplumber extracts nothing are skipped.
    """
    extracted = []
    with pdfplumber.open(pdf_path) as pdf:
        # x_tolerance=2 merges characters that are nearly adjacent
        # horizontally, reducing spurious intra-word splits.
        for page_no, page in enumerate(pdf.pages, start=1):
            content = page.extract_text(x_tolerance=2)
            if content:
                extracted.append({"page": page_no, "text": content})
    return extracted
39
+
40
def read_word(doc_path):
    """Extract all non-empty paragraphs from a .docx file.

    Word documents have no page boundaries accessible via python-docx,
    so the whole document is returned as a single "page 1" entry.
    """
    document = docx.Document(doc_path)
    paragraphs = [p.text for p in document.paragraphs if p.text.strip() != ""]
    return [{"page": 1, "text": "\n\n".join(paragraphs)}]
44
+
45
def read_excel(xlsx_path):
    """Flatten every sheet of an .xlsx workbook into plain text.

    Each sheet becomes one entry whose "page" field is the sheet name and
    whose text is the sheet's cells joined with spaces (rows) and
    newlines (between rows). NaNs are rendered as empty strings.
    """
    workbook = pd.read_excel(xlsx_path, sheet_name=None)  # dict: name -> DataFrame
    sheets = []
    for sheet_name, frame in workbook.items():
        row_texts = frame.fillna("").astype(str).agg(" ".join, axis=1)
        sheets.append({"page": sheet_name, "text": row_texts.str.cat(sep="\n")})
    return sheets
52
+
53
+ # ==========================================================
54
+ # CORE RAG FUNCTIONS (UNCHANGED)
55
+ # ==========================================================
56
def chunk_text(pages, chunk_size=800):
    """Split page texts into ~chunk_size-character chunks on paragraph breaks.

    Paragraphs (split on blank lines) are greedily packed into a buffer;
    when adding the next paragraph would exceed chunk_size, the buffer is
    flushed as one chunk. A single paragraph longer than chunk_size is
    kept whole (it is never split mid-paragraph).

    Args:
        pages: list of {"page": ..., "text": ...} dicts.
        chunk_size: soft character limit per chunk.

    Returns:
        list of {"page": ..., "text": ...} chunk dicts; never contains
        empty-text chunks.

    Bug fix vs. original: when the buffer was empty and the first
    paragraph exceeded chunk_size (or a page's text was empty), an
    empty-text chunk was appended. Empty buffers are now skipped.
    """
    chunks = []
    for page in pages:
        buffer = ""
        for para in page["text"].split("\n\n"):
            if len(buffer) + len(para) <= chunk_size:
                buffer += " " + para
            else:
                if buffer.strip():  # fix: don't emit empty chunks
                    chunks.append({"page": page["page"], "text": buffer.strip()})
                buffer = para
        if buffer.strip():
            chunks.append({"page": page["page"], "text": buffer.strip()})
    return chunks
70
+
71
def tokenize_chunks(chunks, model_name="sentence-transformers/all-mpnet-base-v2"):
    """Tokenize each chunk's text with the given HF tokenizer.

    Returns a list of input_ids lists, one per chunk, truncated to the
    model's maximum sequence length.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    token_ids = []
    for chunk in chunks:
        token_ids.append(tokenizer(chunk["text"], truncation=True)["input_ids"])
    return token_ids
74
+
75
def create_embeddings(chunks, model_name="allenai/specter"):
    """Embed every chunk's text with a SentenceTransformer model.

    Returns (embedder, embeddings) so the caller can reuse the same
    model instance to embed queries later.
    """
    model = SentenceTransformer(model_name)
    chunk_texts = [chunk["text"] for chunk in chunks]
    vectors = model.encode(chunk_texts, show_progress_bar=False)
    return model, np.array(vectors)
80
+
81
def store_embeddings(embeddings):
    """Build a FAISS inner-product index over L2-normalized embeddings.

    With normalized vectors, inner product equals cosine similarity.
    NOTE: faiss.normalize_L2 modifies `embeddings` in place.
    """
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index
87
+
88
def retrieve_chunks(query, embedder, index, chunks, top_k=None):
    """Return the top_k chunks most similar to the query.

    top_k defaults (when falsy) to min(20, number of chunks). The query
    is embedded with the same model and L2-normalized to match the
    cosine-similarity index.
    """
    k = top_k or min(20, len(chunks))
    query_vec = embedder.encode([query])
    faiss.normalize_L2(query_vec)
    _scores, ids = index.search(query_vec, k)
    return [chunks[i] for i in ids[0]]
95
+
96
def build_safe_context(retrieved_chunks, max_chars=12000):
    """Concatenate retrieved chunks into a size-bounded context string.

    The top three chunks are always included, even past max_chars, so the
    strongest matches are never dropped; remaining chunks are appended in
    order until adding one would push the total over max_chars.
    """
    pieces = []
    total = 0
    for rank, chunk in enumerate(retrieved_chunks):
        block = f"(Page {chunk['page']}) {chunk['text']}\n\n"
        # Only enforce the budget from the fourth chunk onward.
        if rank >= 3 and total + len(block) > max_chars:
            break
        pieces.append(block)
        total += len(block)
    return "".join(pieces)
110
+
111
def generate_answer(query, context):
    """Ask the Groq LLM to answer `query` using only `context`.

    Groq() reads its API key from the GROQ_API_KEY environment variable.
    Low temperature (0.3) keeps answers grounded in the document.
    """
    groq_client = Groq()
    prompt = f"""
You are a document-based assistant.
Use the context to answer the question clearly.
If the answer is partially available, summarize it.
If the answer is not present, you may say 'Not found in the document'.

Context:
{context}

Question:
{query}
"""
    completion = groq_client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3
    )
    return completion.choices[0].message.content
131
+
132
# ==========================================================
# APP LOGIC
# ==========================================================
if uploaded_file:
    file_name = uploaded_file.name

    # Streamlit reruns this whole script on every interaction (e.g. each
    # query submitted), so cache the expensive read/chunk/embed pipeline
    # per uploaded file in session state instead of redoing it every time.
    if st.session_state.get("indexed_file") != file_name:
        with st.spinner("📄 Reading document..."):
            # Persist the upload to disk so the file-based readers can open it.
            with open(file_name, "wb") as f:
                f.write(uploaded_file.getbuffer())

            lower_name = file_name.lower()
            if lower_name.endswith(".pdf"):
                pages = read_pdf_with_plumber(file_name)
            elif lower_name.endswith(".docx"):
                pages = read_word(file_name)
            elif lower_name.endswith(".xlsx"):
                pages = read_excel(file_name)
            else:
                st.error("Unsupported file type")
                # Bug fix: the original fell through here and crashed with
                # a NameError on `pages` in the chunking step below.
                st.stop()

        with st.spinner("✂️ Chunking & embedding document..."):
            chunks = chunk_text(pages)
            # NOTE: the original also called tokenize_chunks(chunks) here and
            # discarded the result; that dead work (tokenizer download) is removed.
            embedder, embeddings = create_embeddings(chunks)
            index = store_embeddings(embeddings)

        # Cache everything the query path needs across reruns.
        st.session_state["indexed_file"] = file_name
        st.session_state["chunks"] = chunks
        st.session_state["embedder"] = embedder
        st.session_state["index"] = index

    st.success("✅ Document indexed successfully")

    query = st.text_input("❓ Ask a question")

    if query:
        with st.spinner("🤖 Generating answer..."):
            retrieved_chunks = retrieve_chunks(
                query,
                st.session_state["embedder"],
                st.session_state["index"],
                st.session_state["chunks"],
            )
            context = build_safe_context(retrieved_chunks)
            answer = generate_answer(query, context)

        st.markdown("### ✅ Answer")
        st.write(answer)