amanapk commited on
Commit
7c331c3
·
verified ·
1 Parent(s): b03fe49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -147
app.py CHANGED
@@ -1,168 +1,240 @@
 
1
  import streamlit as st
2
- import pymupdf
3
- import re
4
- import traceback
5
  import faiss
6
  import numpy as np
7
- import requests
8
- from rank_bm25 import BM25Okapi
9
- from sentence_transformers import SentenceTransformer
10
- from langchain.text_splitter import RecursiveCharacterTextSplitter
11
- from langchain_groq import ChatGroq
12
- import torch
13
- import os
14
-
15
- st.set_page_config(page_title="Financial Insights Chatbot", page_icon="📊", layout="wide")
16
-
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
-
19
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
20
- ALPHA_VANTAGE_API_KEY = os.getenv("ALPHA_VANTAGE_API_KEY")
21
-
22
- try:
23
- llm = ChatGroq(temperature=0, model="llama3-70b-8192", api_key=GROQ_API_KEY)
24
- st.success("✅ LLM initialized successfully. Using llama3-70b-8192")
25
- except Exception as e:
26
- st.error("❌ Failed to initialize Groq LLM.")
27
- traceback.print_exc()
28
-
29
- embedding_model = SentenceTransformer("baconnier/Finance2_embedding_small_en-V1.5", device=device)
30
-
31
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
32
 
33
- def fetch_financial_data(company_ticker):
34
- if not company_ticker:
35
- return "No ticker symbol provided. Please enter a valid company ticker."
36
 
37
- try:
38
- overview_url = f"https://www.alphavantage.co/query?function=OVERVIEW&symbol={company_ticker}&apikey={ALPHA_VANTAGE_API_KEY}"
39
- overview_response = requests.get(overview_url)
40
-
41
- if overview_response.status_code == 200:
42
- overview_data = overview_response.json()
43
- market_cap = overview_data.get("MarketCapitalization", "N/A")
44
- else:
45
- return "Error fetching company overview."
46
 
47
- income_url = f"https://www.alphavantage.co/query?function=INCOME_STATEMENT&symbol={company_ticker}&apikey={ALPHA_VANTAGE_API_KEY}"
48
- income_response = requests.get(income_url)
49
 
50
- if income_response.status_code == 200:
51
- income_data = income_response.json()
52
- annual_reports = income_data.get("annualReports", [])
53
- revenue = annual_reports[0].get("totalRevenue", "N/A") if annual_reports else "N/A"
54
- else:
55
- return "Error fetching income statement."
56
 
57
- return f"Market Cap: ${market_cap}\nTotal Revenue: ${revenue}"
58
-
59
- except Exception as e:
60
- traceback.print_exc()
61
- return "Error fetching financial data."
62
 
63
- def extract_and_embed_text(pdf_file):
64
- """Processes PDFs and generates embeddings with GPU acceleration using pymupdf."""
65
- try:
66
- docs, tokenized_texts = [], []
67
-
68
- with pymupdf.open(stream=pdf_file.read(), filetype="pdf") as doc:
69
- full_text = "\n".join(page.get_text("text") for page in doc)
70
- chunks = text_splitter.split_text(full_text)
71
- for chunk in chunks:
72
- docs.append(chunk)
73
- tokenized_texts.append(chunk.split())
74
-
75
- embeddings = embedding_model.encode(docs, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
76
-
77
- embedding_dim = embeddings.shape[1]
78
- index = faiss.IndexHNSWFlat(embedding_dim, 32)
79
- index.add(embeddings)
80
-
81
- bm25 = BM25Okapi(tokenized_texts)
82
-
83
- return docs, embeddings, index, bm25
84
- except Exception as e:
85
- traceback.print_exc()
86
- return [], [], None, None
87
-
88
- def retrieve_relevant_docs(user_query, docs, index, bm25):
89
- """Hybrid search using FAISS cosine similarity & BM25 keyword retrieval."""
90
- query_embedding = embedding_model.encode(user_query, convert_to_numpy=True, normalize_embeddings=True)
91
- _, faiss_indices = index.search(np.array([query_embedding]), 8)
92
- bm25_scores = bm25.get_scores(user_query.split())
93
- bm25_indices = np.argsort(bm25_scores)[::-1][:8]
94
- combined_indices = list(set(faiss_indices[0]) | set(bm25_indices))
95
-
96
- return [docs[i] for i in combined_indices[:3]]
97
 
98
- def generate_response(user_query, pdf_ticker, ai_ticker, mode, uploaded_file):
 
99
  try:
100
- if mode == "📄 PDF Upload Mode":
101
- docs, embeddings, index, bm25 = extract_and_embed_text(uploaded_file)
102
- if not docs:
103
- return "❌ Error extracting text from PDF."
104
-
105
- retrieved_docs = retrieve_relevant_docs(user_query, docs, index, bm25)
106
- context = "\n\n".join(retrieved_docs)
107
- prompt = f"Summarize the key financial insights for {pdf_ticker} from this document:\n\n{context}"
108
-
109
- elif mode == "🌍 Live Data Mode":
110
- financial_info = fetch_financial_data(ai_ticker)
111
- prompt = f"Analyze the financial status of {ai_ticker} based on:\n{financial_info}\n\nUser Query: {user_query}"
112
- else:
113
- return "Invalid mode selected."
114
-
115
- response = llm.invoke(prompt)
116
- return response.content
117
  except Exception as e:
118
- traceback.print_exc()
119
- return "Error generating response."
 
 
 
120
 
121
- st.markdown(
122
- "<h1 style='text-align: center; color: #4CAF50;'>📄 FinQuery RAG Chatbot</h1>",
123
- unsafe_allow_html=True
124
- )
125
- st.markdown(
126
- "<h5 style='text-align: center; color: #666;'>Analyze financial reports or fetch live financial data effortlessly!</h5>",
127
- unsafe_allow_html=True
128
- )
129
 
130
- col1, col2 = st.columns(2)
 
 
 
 
131
 
132
- with col1:
133
- st.markdown("### 🏢 **Choose Your Analysis Mode**")
134
- mode = st.radio("", ["📄 PDF Upload Mode", "🌍 Live Data Mode"], horizontal=True)
135
 
136
- with col2:
137
- st.markdown("### 🔎 **Enter Your Query**")
138
- user_query = st.text_input("💬 What financial insights are you looking for?")
139
 
140
- st.markdown("---")
141
- uploaded_file, company_ticker = None, None
142
 
143
- if mode == "📄 PDF Upload Mode":
144
- st.markdown("### 📂 Upload Your Financial Report")
145
- uploaded_file = st.file_uploader("🔼 Upload PDF Report", type=["pdf"])
146
- company_ticker = None
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  else:
149
- st.markdown("### 🌍 Live Market Data")
150
- company_ticker = st.text_input("🏢 Enter Company Ticker Symbol", placeholder="e.g., AAPL, MSFT")
151
- uploaded_file = None
152
-
153
- # 🎯 Submit Button
154
- if st.button("🚀 Analyze Now"):
155
- if mode == "📄 PDF Upload Mode" and not uploaded_file:
156
- st.error("❌ Please upload a PDF file.")
157
- elif mode == "🌍 Live Data Mode" and not company_ticker:
158
- st.error("❌ Please enter a valid company ticker symbol.")
159
- else:
160
- with st.spinner("🔍 Your Query is Processing, this can take upto 5 - 7 minutes⏳"):
161
- response = generate_response(user_query, company_ticker, mode, uploaded_file)
162
- st.markdown("---")
163
- st.markdown("<h3 style='color: #4CAF50;'>💡 AI Response</h3>", unsafe_allow_html=True)
164
- st.write(response)
165
-
166
- # 📌 Footer
167
- st.markdown("---")
168
 
 
import json
import os
import tempfile
import textwrap

import faiss
import numpy as np
import requests  # FIX: used by the Groq chat-completion helpers below but was never imported
import streamlit as st
from google.api_core.client_options import ClientOptions
from google.cloud import documentai_v1
from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
14
 
 
 
 
 
 
 
 
 
 
15
 
 
 
16
 
 
 
 
 
 
 
17
 
18
# ------------------- Secure Credential Loading for Hugging Face ------------------- #
# Loads the GCP Service Account JSON from a Hugging Face Space secret and wires
# it up for Application Default Credentials (ADC).

# 1. Load the Service Account JSON string from the environment variable (secret)
gcp_credentials_json_str = os.getenv("GCP_CREDENTIALS_JSON")
project_id = "wise-env-461717-t5"  # fallback; overwritten from the credentials below
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# 2. Check if the secret is present
if gcp_credentials_json_str:
    try:
        # Write to the /tmp/ directory, which is writable on Hugging Face Spaces.
        credentials_file_path = "/tmp/gcp_service_account.json"

        # 3. Write the JSON string to the file with owner-only permissions.
        # FIX: the original open(..., "w") created the key file with default
        # (typically world-readable) permissions; credentials must be 0o600.
        fd = os.open(credentials_file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(gcp_credentials_json_str)

        # 4. Set the environment variable so ADC picks up this file
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file_path

        # Extract project_id from the credentials for convenience
        creds_dict = json.loads(gcp_credentials_json_str)
        project_id = creds_dict.get("project_id")

    except Exception as e:
        st.error(f"🚨 Failed to process GCP credentials: {e}")
        st.stop()
else:
    st.error("🚨 GCP_CREDENTIALS_JSON secret not found! Please add it to your Hugging Face Space settings.")
    st.stop()
49
 
 
 
 
 
 
 
 
 
50
 
51
# ------------------- Configuration ------------------- #
# Project ID is dynamically loaded from the service account; abort early if
# the credentials did not contain one.
if not project_id:
    st.error("🚨 Project ID could not be found in the GCP credentials.")
    st.stop()

# You still need to provide your Processor ID and location (must match the
# region the Document AI processor was created in).
processor_id = "86a7eec52bbb9616"  # <-- REPLACE WITH YOUR PROCESSOR ID
location = "us"  # e.g., "us" or "eu"
 
 
 
 
61
 
 
 
62
 
 
 
 
 
63
 
64
# ------------------- Google Document AI Client (Uses ADC) ------------------- #
# The client automatically finds and uses the GOOGLE_APPLICATION_CREDENTIALS
# file set during credential loading above.
try:
    # Regional endpoint must match the processor's location.
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    docai_client = documentai_v1.DocumentProcessorServiceClient(client_options=opts)
    # Fully-qualified resource name: projects/{id}/locations/{loc}/processors/{pid}
    full_processor_name = docai_client.processor_path(project_id, location, processor_id)
except Exception as e:
    st.error(f"Error initializing Document AI client: {e}")
    st.stop()
73
+
74
+
75
@st.cache_resource
def load_embedding_model():
    """Load (and cache) the SentenceTransformer used for chunk embeddings.

    Cached with st.cache_resource so the model is downloaded and initialised
    once per Streamlit process rather than on every script rerun.
    """
    # Use a writable cache directory (the default HF cache path is not
    # writable on Hugging Face Spaces).
    cache_dir = "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Point the Hugging Face libraries at the writable cache
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir

    # Load embedding model
    return SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)


# FIX: build_index() and retrieve_context() below reference a module-level
# `embed_model`, but no binding existed anywhere in this file — calling them
# raised NameError. Bind it here (the cached loader makes this cheap on rerun).
embed_model = load_embedding_model()
87
+
88
+
89
# ------------------- Utility Functions ------------------- #
def chunk_text(text, max_chars=500):
    """Split *text* into pieces of at most *max_chars* characters.

    Delegates to textwrap.wrap, which prefers to break on whitespace
    (normalising whitespace runs) and hard-splits overlong words.
    """
    return textwrap.wrap(text, width=max_chars)
92
def extract_text_with_documentai(file_path):
    """OCR a local PDF through Document AI and return its full plain text."""
    with open(file_path, "rb") as pdf:
        pdf_bytes = pdf.read()

    # Synchronous (online) processing of the raw PDF bytes.
    request = documentai_v1.ProcessRequest(
        name=full_processor_name,
        raw_document=documentai_v1.RawDocument(
            content=pdf_bytes, mime_type="application/pdf"
        ),
    )
    return docai_client.process_document(request=request).document.text
100
+
101
def build_index(text):
    """Chunk *text*, embed every chunk, and build an exact L2 FAISS index.

    Returns (index, chunks) so callers can map search hits back to the
    original chunk strings.
    NOTE(review): relies on a module-level ``embed_model`` SentenceTransformer.
    """
    chunks = chunk_text(text)
    vectors = np.array(embed_model.encode(chunks))
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index, chunks
108
+
109
def retrieve_context(query, index, text_chunks, top_k=5):
    """Return up to *top_k* chunks most similar to *query* by L2 distance.

    FIX: when top_k exceeds the number of indexed vectors, FAISS pads the
    result with -1 indices; the original code then silently returned
    ``text_chunks[-1]`` duplicates. Negative indices are now filtered out.
    """
    query_embed = embed_model.encode([query])
    _, indices = index.search(np.array(query_embed), top_k)
    return [text_chunks[i] for i in indices[0] if i >= 0]
113
+
114
# ------------------- Groq API Functions ------------------- #
# NOTE(review): this section header previously said "Gemini", but every call
# below hits Groq's OpenAI-compatible chat-completions endpoint.
def ask_groq_agent(query, context):
    """Answer *query* grounded in *context* via Groq's llama3-70b chat model.

    Raises requests.HTTPError on a non-2xx response instead of failing with
    an opaque KeyError while indexing the error JSON body.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
        json={
            "model": "llama3-70b-8192",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3
        },
        timeout=120,  # don't hang the Streamlit worker forever on a stalled API
    )
    response.raise_for_status()  # surface API errors clearly
    return response.json()["choices"][0]["message"]["content"]
127
def get_summary(text):
    """Ask the Groq model for a concise summary of *text*.

    Only the first 4000 characters are sent, to keep the prompt bounded.
    Raises requests.HTTPError on a non-2xx response.
    """
    prompt = f"Please provide a concise summary of the following document:\n\n{text[:4000]}"
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
        json={
            "model": "llama3-70b-8192",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3
        },
        timeout=120,  # bound the wait on a stalled API
    )
    response.raise_for_status()  # FIX: previously a failed call raised KeyError
    return response.json()["choices"][0]["message"]["content"]
139
+
140
+
141
def generate_flashcards(text_chunks):
    """Generate up to 5 Q/A flashcards from *text_chunks* via the Groq model.

    Returns a list of ``{"question": ..., "answer": ...}`` dicts parsed from
    the model's "Q: ... / A: ..." formatted reply. Raises requests.HTTPError
    on a non-2xx API response.
    """
    joined_text = "\n".join(text_chunks)
    prompt = (
        "Generate 5 helpful flashcards from the following content. "
        "Use the format exactly like this:\n\n"
        "Q: What is ...?\nA: ...\n\nQ: How does ...?\nA: ...\n\n"
        "Text:\n" + joined_text
    )

    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
        json={
            "model": "llama3-70b-8192",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5
        },
        timeout=120,  # bound the wait on a stalled API
    )
    response.raise_for_status()  # FIX: fail loudly instead of KeyError below
    content = response.json()["choices"][0]["message"]["content"]

    # Parse alternating "Q:" / "A:" lines; an answer is only recorded while a
    # question is pending, so stray or malformed lines are ignored.
    flashcards = []
    question = None
    for line in content.strip().splitlines():
        line = line.strip()
        if line.lower().startswith("q:"):
            question = line[2:].strip()
        elif line.lower().startswith("a:") and question:
            answer = line[2:].strip()
            flashcards.append({"question": question, "answer": answer})
            question = None
    return flashcards
172
+
173
st.title("📄 PDF AI Assistant (Groq + DocAI)")

# Initialise per-session state once. Streamlit reruns the whole script on
# every interaction, so guard on a sentinel key to keep state across reruns.
if "index" not in st.session_state:
    st.session_state.index = None        # FAISS index over the current PDF
    st.session_state.text_chunks = []    # chunk strings aligned with the index
    st.session_state.raw_text = ""       # full extracted document text
179
+
180
with st.sidebar:
    st.header("📤 Upload PDF")
    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

    if uploaded_file is not None:
        # FIX: initialise tmp_path before the try block — if NamedTemporaryFile
        # itself failed, the original `finally: os.unlink(tmp_path)` raised
        # UnboundLocalError and masked the real exception.
        tmp_path = None
        try:
            # Document AI helper reads from disk, so persist the upload first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file.flush()
                tmp_path = tmp_file.name

            # DEBUG: File info
            st.write("Saved file at:", tmp_path)
            st.write("File size:", os.path.getsize(tmp_path), "bytes")
            st.write("File exists:", os.path.exists(tmp_path))

            with st.spinner("Extracting text using Document AI..."):
                raw_text = extract_text_with_documentai(tmp_path)
                index, text_chunks = build_index(raw_text)
                st.session_state.index = index
                st.session_state.text_chunks = text_chunks
                st.session_state.raw_text = raw_text
                st.success("✅ Document processed successfully.")
        except Exception as e:
            st.error(f"Error: {e}")
        finally:
            # Always remove the temp copy of the uploaded PDF — if it was created.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)
207
+
208
+
209
# ------------------- Q&A Interface ------------------- #
st.subheader("❓ Ask Questions")
# Explicit None check: don't rely on the truthiness of a faiss index object.
if st.session_state.index is not None:
    question = st.text_input("Enter your question")
    if st.button("Ask"):
        # FIX: guard against an empty query instead of embedding "".
        if not question.strip():
            st.warning("Please enter a question first.")
        else:
            context = "\n\n".join(
                retrieve_context(question, st.session_state.index, st.session_state.text_chunks)
            )
            answer = ask_groq_agent(question, context)
            st.markdown(f"**Answer:** {answer}")
else:
    st.info("Upload a PDF to start asking questions.")
219
+
220
# ------------------- Summary Interface ------------------- #
st.subheader("📝 Document Summary")
# Guard clause: the button is only rendered once a document is loaded.
if not st.session_state.text_chunks:
    st.info("Upload a PDF to get a summary.")
elif st.button("Generate Summary"):
    with st.spinner("Generating summary..."):
        st.markdown(get_summary(" ".join(st.session_state.text_chunks)))
229
+
230
# ------------------- Flashcards ------------------- #
st.subheader("🧠 Flashcards")
# Guard clause: the button is only rendered once a document is loaded.
if not st.session_state.text_chunks:
    st.info("Upload a PDF to generate flashcards.")
elif st.button("Generate Flashcards"):
    with st.spinner("Generating flashcards..."):
        for card in generate_flashcards(st.session_state.text_chunks):
            st.markdown(f"**Q: {card['question']}**\n\nA: {card['answer']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240