Kakaarot commited on
Commit
dd37de0
·
verified ·
1 Parent(s): e559cdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +284 -147
app.py CHANGED
@@ -1,174 +1,311 @@
1
  import gradio as gr
2
- import google.generativeai as genai
3
- import numpy as np
4
- import pandas as pd
5
- import chromadb
6
- from sklearn.metrics.pairwise import cosine_similarity
7
- from retry import retry
8
  import os
9
  import json
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # API Key Validation
12
- api_key = os.getenv("GEMINI_API_KEY")
13
- if not api_key:
14
- raise ValueError("GEMINI_API_KEY environment variable not set")
15
- genai.configure(api_key=api_key)
16
-
17
- # Precomputed data and embeddings
18
- articles = [
19
- "Climate change accelerates, with 2024 as the hottest year. Rising sea levels threaten coastal cities.",
20
- "Renewable energy grows, but fossil fuels dominate in developing nations.",
21
- "UN report urges action to cut greenhouse gas emissions by 2030.",
22
- "Extreme weather like hurricanes and wildfires is linked to climate change.",
23
- "Amazon deforestation slows, but illegal logging harms carbon sinks.",
24
- "Electric vehicle adoption rises, reducing emissions globally.",
25
- "Coral reefs face bleaching from rising ocean temperatures."
26
- ]
27
-
28
- # Generate embeddings - with corrected model name
29
- embedding_model = "models/embedding-001" # Correct model name
30
- df = pd.DataFrame({"article": articles})
31
-
32
- @retry(tries=3, delay=2, backoff=2)
33
- def get_embedding(text):
34
  try:
35
- result = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
36
- # Correct way to access embedding
37
- embedding = result.embedding
38
- return embedding
39
  except Exception as e:
40
- print(f"Embedding error: {e}")
41
- raise
42
 
43
- # Generate embeddings and ensure they're in the correct format
44
- df["embedding"] = df["article"].apply(get_embedding)
 
 
 
 
 
45
 
46
- # Initialize ChromaDB with proper error handling
47
- client_db = chromadb.Client()
48
- collection = client_db.get_or_create_collection("news_articles")
 
 
 
 
 
 
 
49
 
50
- # Clear existing data to avoid duplicates
51
- try:
52
- collection.delete(ids=[str(i) for i in range(len(df))])
53
- except Exception:
54
- pass # Collection might be empty
55
 
56
- # Add documents with error handling
57
- for idx, row in df.iterrows():
 
 
 
 
58
  try:
59
- collection.add(
60
- documents=[row["article"]],
61
- embeddings=[row["embedding"]],
62
- ids=[str(idx)]
 
 
63
  )
 
 
 
 
 
 
 
 
 
 
 
64
  except Exception as e:
65
- print(f"Error adding document {idx}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- # Semantic Search
68
- @retry(tries=3, delay=2, backoff=2)
69
- def search_articles(query, top_k=3):
 
 
70
  try:
71
- query_embedding = get_embedding(query)
72
- results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
73
- if not results["ids"][0]:
74
- return []
75
- indices = [int(id) for id in results["ids"][0]]
76
- return df.iloc[indices]["article"].tolist()
77
  except Exception as e:
78
- print(f"Search error: {e}")
79
- return []
80
 
81
- # RAG and Structured Q&A
82
- generation_model = genai.GenerativeModel("gemini-1.5-pro") # Corrected model name
83
 
84
- @retry(tries=3, delay=2, backoff=2)
85
- def generate_response(query, articles, system_message):
86
- if not articles:
87
- return "No relevant articles found.", json.dumps({"error": "No relevant articles found."})
88
-
89
- context = "\n".join(articles)
90
- prompt = f"""
91
- {system_message}
92
- Based on the following articles, provide a concise summary (under 100 words) and a structured JSON response with 'question', 'answer', and 'source'. Use only the provided context.
93
-
94
- Articles:
95
- {context}
96
-
97
- Query: {query}
98
-
99
- Response:
100
- - Summary:
101
- - JSON:
102
- """
103
-
104
  try:
105
- response = generation_model.generate_content(
106
- prompt,
107
- generation_config={
108
- "temperature": 0.7,
109
- "top_p": 0.95,
110
- "max_output_tokens": 1024,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  },
112
- stream=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  )
114
-
115
- full_text = response.text
116
-
117
- # Robust parsing
118
- summary = "Summary not generated."
119
- if "- Summary:" in full_text:
120
- summary_start = full_text.find("- Summary:") + len("- Summary:")
121
- summary_end = full_text.find("- JSON:", summary_start)
122
- if summary_end > summary_start:
123
- summary = full_text[summary_start:summary_end].strip()
124
-
125
- qa_json = "{}"
126
- if "- JSON:" in full_text:
127
- json_start = full_text.find("- JSON:") + len("- JSON:")
128
- qa_json_text = full_text[json_start:].strip()
129
- # Clean up the JSON string - remove markdown code blocks
130
- qa_json_text = qa_json_text.replace("``````", "").strip()
131
-
132
- try:
133
- qa = json.loads(qa_json_text)
134
- qa_json = json.dumps(qa, indent=2)
135
- except json.JSONDecodeError:
136
- qa_json = json.dumps({"error": "Failed to parse JSON response.", "raw_text": qa_json_text})
137
-
138
- return summary, qa_json
139
  except Exception as e:
140
- print(f"RAG error: {e}")
141
- return f"Error generating response: {str(e)}", json.dumps({"error": f"Failed to generate response: {str(e)}"})
142
-
143
- def respond(message, history, system_message="You are a news summarizer and Q&A assistant.", max_tokens=512, temperature=0.7, top_p=0.95):
144
- articles = search_articles(message)
145
- summary, qa = generate_response(message, articles, system_message)
146
-
147
- # Format articles for display
148
- articles_text = "\n".join([f"- {article}" for article in articles]) if articles else "None found"
149
-
150
- response = (
151
- "**Relevant Articles:**\n"
152
- f"{articles_text}\n\n"
153
- "**Summary:**\n"
154
- f"{summary}\n\n"
155
- "**Structured Q&A:**\n"
156
- f"{qa}"
157
- )
158
- yield response
159
-
160
- # Gradio ChatInterface
161
- demo = gr.ChatInterface(
162
- fn=respond,
163
- additional_inputs=[
164
- gr.Textbox(value="You are a news summarizer and Q&A assistant.", label="System message"),
165
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
166
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
167
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  ],
169
- title="Semantic News Summarizer and Q&A Chatbot",
170
- description="Ask about climate change (e.g., 'What are the latest impacts?') to get articles, a summary, and a structured JSON response. Built for the 5-day Gen AI Intensive Course with Google (March 31–April 4, 2025)."
171
  )
172
 
 
173
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
174
  demo.launch()
 
1
  import gradio as gr
 
 
 
 
 
 
2
  import os
3
  import json
4
+ import faiss
5
+ import numpy as np
6
+ import google.generativeai as genai
7
+ from newsapi import NewsApiClient
8
+ from sentence_transformers import SentenceTransformer
9
+ from typing import List, Dict, Any, Optional, Union
10
+
11
+ # --- Configuration ---
12
+ # !! IMPORTANT !! Set these as Hugging Face Space Secrets
13
+ # Go to your Space > Settings > Secrets > Add secret
14
+ NEWS_API_KEY = os.getenv('NEWS_API_KEY')
15
+ GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
16
 
17
+ if not NEWS_API_KEY:
18
+ print("Warning: NEWS_API_KEY secret not found.")
19
+ # Optionally raise an error or handle gracefully in the UI
20
+ if not GOOGLE_API_KEY:
21
+ print("Warning: GOOGLE_API_KEY secret not found.")
22
+ # Optionally raise an error or handle gracefully in the UI
23
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  try:
25
+ # Configure Google Generative AI only if the key is present
26
+ genai.configure(api_key=GOOGLE_API_KEY)
 
 
27
  except Exception as e:
28
+ print(f"Error configuring Google Generative AI: {e}")
29
+ # Handle configuration error
30
 
31
+ # --- Constants ---
32
+ EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2' # Lightweight embedding model
33
+ LLM_MODEL_NAME = 'gemini-1.5-flash' # Efficient Gemini model
34
+ MAX_ARTICLES_TO_FETCH = 15 # Fetch a bit more for better potential context
35
+ MAX_ARTICLES_TO_PROCESS = 7 # Process a reasonable number for context
36
+ CHUNK_SIZE = 500 # Approximate characters per text chunk
37
+ TOP_K_CHUNKS = 4 # Number of relevant chunks for LLM context
38
 
39
+ # --- Global Variables / Models (Load Once) ---
40
+ embedding_model = None
41
+ if GOOGLE_API_KEY: # Only load models if keys are likely set
42
+ try:
43
+ print("Loading embedding model...")
44
+ embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
45
+ print("Embedding model loaded.")
46
+ except Exception as e:
47
+ print(f"Error loading Sentence Transformer model '{EMBEDDING_MODEL_NAME}': {e}")
48
+ # The app might still run but RAG will fail
49
 
50
+ # --- Helper Functions (Adapted from previous script) ---
 
 
 
 
51
 
52
+ def fetch_news(topic: str) -> List[Dict[str, Any]]:
53
+ """Fetches recent news articles for a given topic using NewsAPI."""
54
+ if not NEWS_API_KEY:
55
+ print("News API key missing.")
56
+ return []
57
+ print(f"Fetching news for topic: {topic}...")
58
  try:
59
+ newsapi = NewsApiClient(api_key=NEWS_API_KEY)
60
+ top_headlines = newsapi.get_everything(
61
+ q=topic,
62
+ language='en',
63
+ sort_by='relevancy',
64
+ page_size=MAX_ARTICLES_TO_FETCH
65
  )
66
+ articles = top_headlines.get('articles', [])
67
+ valid_articles = [
68
+ {
69
+ "title": article.get("title"),
70
+ "content": article.get("content") or article.get("description", ""),
71
+ "url": article.get("url")
72
+ }
73
+ for article in articles if article.get("content") or article.get("description")
74
+ ][:MAX_ARTICLES_TO_PROCESS] # Limit here
75
+ print(f"Fetched {len(valid_articles)} valid articles.")
76
+ return valid_articles
77
  except Exception as e:
78
+ print(f"Error fetching news: {e}")
79
+ return []
80
+
81
+ def chunk_text(text: str, size: int) -> List[str]:
82
+ """Splits text into chunks."""
83
+ chunks = []
84
+ start = 0
85
+ while start < len(text):
86
+ end = start + size
87
+ pos = text.rfind('.', start, min(end + 50, len(text)))
88
+ if pos != -1 and pos > start + size // 2:
89
+ end = pos + 1
90
+ chunks.append(text[start:end].strip())
91
+ start = end
92
+ return [chunk for chunk in chunks if chunk]
93
+
94
+ def build_vector_store(articles: List[Dict[str, Any]], model: SentenceTransformer):
95
+ """Creates embeddings and builds an in-memory FAISS index."""
96
+ if model is None:
97
+ print("Embedding model not loaded. Cannot build vector store.")
98
+ return None, [], []
99
+ print("Building vector store...")
100
+ all_chunks = []
101
+ metadata = []
102
+ for i, article in enumerate(articles):
103
+ if article.get('content'):
104
+ chunks = chunk_text(article['content'], CHUNK_SIZE)
105
+ for chunk in chunks:
106
+ all_chunks.append(chunk)
107
+ metadata.append({"article_index": i, "url": article.get('url'), "title": article.get('title')})
108
 
109
+ if not all_chunks:
110
+ print("No text content found to build vector store.")
111
+ return None, [], []
112
+
113
+ print(f"Generated {len(all_chunks)} chunks. Creating embeddings...")
114
  try:
115
+ embeddings = model.encode(all_chunks, show_progress_bar=False) # Progress bar can be messy in logs
116
+ dimension = embeddings.shape[1]
117
+ index = faiss.IndexFlatL2(dimension)
118
+ index.add(np.array(embeddings).astype('float32'))
119
+ print("Vector store built successfully.")
120
+ return index, all_chunks, metadata
121
  except Exception as e:
122
+ print(f"Error creating embeddings or FAISS index: {e}")
123
+ return None, [], []
124
 
 
 
125
 
126
+ def retrieve_context(query: str, index: faiss.Index, chunks: List[str], metadata: List[Dict], model: SentenceTransformer, top_k: int) -> str:
127
+ """Retrieves the most relevant text chunks."""
128
+ if model is None or index is None or index.ntotal == 0:
129
+ return "No relevant context found (vector store/model unavailable)."
130
+
131
+ print(f"Retrieving top {top_k} relevant chunks for query: '{query}'...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  try:
133
+ query_embedding = model.encode([query], show_progress_bar=False)
134
+ query_embedding_np = np.array(query_embedding).astype('float32')
135
+ distances, indices = index.search(query_embedding_np, min(top_k, index.ntotal)) # Ensure k <= index.ntotal
136
+
137
+ context_parts = []
138
+ seen_urls = set()
139
+ retrieved_sources = [] # Track sources used in context
140
+
141
+ for i, idx in enumerate(indices[0]):
142
+ if 0 <= idx < len(chunks):
143
+ chunk_text = chunks[idx]
144
+ meta = metadata[idx]
145
+ source_info = f"(Source: {meta.get('url', 'N/A')})"
146
+ full_info = ""
147
+ if meta.get('url') and meta['url'] not in seen_urls:
148
+ full_info = f"From '{meta.get('title', 'Untitled')}':\n{chunk_text}\n{source_info}"
149
+ seen_urls.add(meta['url'])
150
+ if meta.get('url'): retrieved_sources.append(meta['url'])
151
+ else:
152
+ full_info = f"{chunk_text}\n{source_info}"
153
+ # Add source URL if available and not already added from this chunk group
154
+ if meta.get('url') and meta['url'] not in seen_urls:
155
+ seen_urls.add(meta['url'])
156
+ if meta.get('url'): retrieved_sources.append(meta['url'])
157
+
158
+ context_parts.append(full_info)
159
+
160
+ if not context_parts:
161
+ return "No relevant context found matching the query."
162
+
163
+ print(f"Retrieved {len(context_parts)} context parts.")
164
+ # Return context and the list of sources used in that context
165
+ return "\n\n".join(context_parts), list(set(retrieved_sources)) # Use set for uniqueness
166
+ except Exception as e:
167
+ print(f"Error during context retrieval: {e}")
168
+ return "Error retrieving context.", []
169
+
170
+
171
+ def generate_structured_summary(context: str, topic: str) -> Optional[Dict[str, Any]]:
172
+ """Generates a summary using Gemini with structured output."""
173
+ if not GOOGLE_API_KEY:
174
+ print("Google API Key missing. Cannot generate summary.")
175
+ return None
176
+ print("Generating structured summary with LLM...")
177
+ try:
178
+ model = genai.GenerativeModel(LLM_MODEL_NAME)
179
+ json_schema = {
180
+ "type": "object",
181
+ "properties": {
182
+ "topic": {"type": "string"},
183
+ "summary_points": {"type": "array", "items": {"type": "string"}},
184
+ "mentioned_sources": {"type": "array", "items": {"type": "string", "format": "uri"}}
185
  },
186
+ "required": ["topic", "summary_points", "mentioned_sources"]
187
+ }
188
+ prompt = f"""
189
+ Analyze the following retrieved context about '{topic}'. Create a concise summary highlighting the key information.
190
+ Extract the main points and list the unique source URLs mentioned ONLY in the provided context below.
191
+ Respond ONLY with a valid JSON object matching this schema:
192
+
193
+ Schema:
194
+ {json.dumps(json_schema, indent=2)}
195
+
196
+ Retrieved Context:
197
+ ---
198
+ {context}
199
+ ---
200
+
201
+ JSON Output:
202
+ """
203
+
204
+ response = model.generate_content(
205
+ prompt,
206
+ generation_config=genai.types.GenerationConfig(
207
+ response_mime_type="application/json"
208
+ )
209
  )
210
+ summary_json = json.loads(response.text)
211
+ print("LLM generation successful.")
212
+ return summary_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  except Exception as e:
214
+ print(f"Error during LLM generation or JSON parsing: {e}")
215
+ try:
216
+ # Try to log the raw response if possible for debugging
217
+ print(f"LLM Raw Response Text (if available): {response.text}")
218
+ except:
219
+ pass
220
+ return None
221
+
222
+ # --- Main Gradio Function ---
223
+ def summarize_news_interface(topic: str) -> Union[Dict, str]:
224
+ """Orchestrates the news summarization process for the Gradio interface."""
225
+ print(f"\n--- Processing request for topic: {topic} ---")
226
+ if not topic:
227
+ return {"error": "Please enter a topic."}
228
+ if not NEWS_API_KEY or not GOOGLE_API_KEY:
229
+ return {"error": "API Key secrets are not configured correctly in this Space."}
230
+ if embedding_model is None:
231
+ return {"error": "Embedding model could not be loaded. RAG is disabled."}
232
+
233
+ # 1. Fetch News
234
+ articles = fetch_news(topic)
235
+ if not articles:
236
+ return {"error": f"Could not fetch any news articles for '{topic}'. Please try a different topic or check NewsAPI key."}
237
+
238
+ # 2. Build Vector Store (RAG - Embeddings & Indexing)
239
+ vector_index, text_chunks, chunk_metadata = build_vector_store(articles, embedding_model)
240
+ if vector_index is None:
241
+ # Fallback or error - here we'll indicate RAG failed but might proceed without it later if desired
242
+ return {"error": "Could not build vector store (likely no usable article content). RAG step failed."}
243
+
244
+ # 3. Retrieve Relevant Context (RAG - Retrieval)
245
+ context_result = retrieve_context(topic, vector_index, text_chunks, chunk_metadata, embedding_model, TOP_K_CHUNKS)
246
+
247
+ # Check if retrieve_context returned a tuple (context, sources) or an error string
248
+ if isinstance(context_result, tuple):
249
+ retrieved_context, sources_in_context = context_result
250
+ print(f"Context retrieved successfully. Sources in context: {len(sources_in_context)}")
251
+ else: # Handle error string case
252
+ retrieved_context = context_result # Contains the error message
253
+ sources_in_context = []
254
+ print(f"Context retrieval issue: {retrieved_context}")
255
+ # Decide how to proceed. For now, we'll try generating without specific context.
256
+ # A better approach might be to summarize top articles directly, or just show the error.
257
+ # For simplicity, we will show an error JSON
258
+ return {"error": "Failed to retrieve relevant context via RAG.", "details": retrieved_context}
259
+
260
+
261
+ # 4. Generate Structured Summary (Document Understanding + Structured Output)
262
+ # Pass only the sources found in the *retrieved context* to the LLM if needed,
263
+ # but the current prompt asks it to extract from the context itself.
264
+ summary_output = generate_structured_summary(retrieved_context, topic)
265
+
266
+ if summary_output:
267
+ # Ensure the sources list in the JSON only contains those from the context
268
+ # The LLM should ideally handle this based on the prompt, but we can double-check/override.
269
+ # summary_output['mentioned_sources'] = sources_in_context # Optional override
270
+ print("--- Request processing complete ---")
271
+ return summary_output
272
+ else:
273
+ print("--- Request processing failed at LLM step ---")
274
+ # Provide specific error if LLM failed
275
+ return {"error": "Failed to generate summary using the LLM.", "details": "Check logs for potential API errors or LLM issues."}
276
+
277
+ # --- Gradio Interface Definition ---
278
+ demo = gr.Interface(
279
+ fn=summarize_news_interface,
280
+ inputs=gr.Textbox(
281
+ label="Enter News Topic",
282
+ placeholder="e.g., latest advancements in renewable energy, Premier League results, space exploration updates..."
283
+ ),
284
+ outputs=gr.JSON(label="News Digest Summary"),
285
+ title="📰 AI News Digest Generator",
286
+ description=(
287
+ "Enter a topic to get a structured summary of recent news articles.\n"
288
+ "This app uses RAG (Retrieval Augmented Generation) with FAISS/SentenceTransformers "
289
+ "and Google Gemini for summarization.\n"
290
+ "**Requires NEWS_API_KEY and GOOGLE_API_KEY secrets set in the Space settings.**"
291
+ ),
292
+ examples=[
293
+ ["AI in healthcare"],
294
+ ["Electric vehicle market trends"],
295
+ ["Recent archaeological discoveries"]
296
  ],
297
+ allow_flagging='never',
298
+ # theme=gr.themes.Soft() # Optional: adds a theme
299
  )
300
 
301
+ # --- Launch the App ---
302
  if __name__ == "__main__":
303
+ # Check for keys on launch locally (won't hurt on Spaces)
304
+ if not NEWS_API_KEY or not GOOGLE_API_KEY:
305
+ print("\n*** WARNING: API Keys not found as environment variables. ***")
306
+ print("*** Please set NEWS_API_KEY and GOOGLE_API_KEY if running locally. ***")
307
+ print("*** In Hugging Face Spaces, set them as Secrets in Settings. ***\n")
308
+ elif embedding_model is None:
309
+ print("\n*** WARNING: Embedding model failed to load. RAG features will not work. ***\n")
310
+
311
  demo.launch()