import gradio as gr
import os
import json
import faiss
import numpy as np
import google.generativeai as genai
from newsapi import NewsApiClient
from sentence_transformers import SentenceTransformer
from typing import Any, Dict, List, Optional, Tuple, Union
# --- Configuration ---
NEWS_API_KEY = os.getenv('NEWS_API_KEY')
GOOGLE_API_KEY = os.getenv('GEMINI_API_KEY')  # Note: the secret itself is named GEMINI_API_KEY
if not NEWS_API_KEY:
    print("Warning: NEWS_API_KEY secret not found.")
    # Optionally raise an error or handle gracefully in the UI
if not GOOGLE_API_KEY:
    print("Warning: GEMINI_API_KEY secret not found.")
    # Optionally raise an error or handle gracefully in the UI
else:
    try:
        # Configure Google Generative AI only if the key is present
        genai.configure(api_key=GOOGLE_API_KEY)
    except Exception as e:
        print(f"Error configuring Google Generative AI: {e}")
        # Handle configuration error
# --- Constants ---
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2' # Lightweight embedding model
LLM_MODEL_NAME = 'gemini-1.5-flash' # Efficient Gemini model
MAX_ARTICLES_TO_FETCH = 15 # Fetch a bit more for better potential context
MAX_ARTICLES_TO_PROCESS = 7 # Process a reasonable number for context
CHUNK_SIZE = 500 # Approximate characters per text chunk
TOP_K_CHUNKS = 4 # Number of relevant chunks for LLM context
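# Pipeline overview: fetch_news (NewsAPI) -> chunk_text -> build_vector_store
# (SentenceTransformer embeddings in a FAISS index) -> retrieve_context (top-k
# chunks nearest the topic query) -> generate_structured_summary (Gemini, JSON mode).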
# --- Global Variables / Models (Load Once) ---
embedding_model = None
if GOOGLE_API_KEY:  # Only load models if keys are likely set
    try:
        print("Loading embedding model...")
        embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        print("Embedding model loaded.")
    except Exception as e:
        print(f"Error loading Sentence Transformer model '{EMBEDDING_MODEL_NAME}': {e}")
        # The app might still run but RAG will fail
# --- Helper Functions (Adapted from previous script) ---
def fetch_news(topic: str) -> List[Dict[str, Any]]:
    """Fetches recent news articles for a given topic using NewsAPI."""
    if not NEWS_API_KEY:
        print("News API key missing.")
        return []
    print(f"Fetching news for topic: {topic}...")
    try:
        newsapi = NewsApiClient(api_key=NEWS_API_KEY)
        top_headlines = newsapi.get_everything(
            q=topic,
            language='en',
            sort_by='relevancy',
            page_size=MAX_ARTICLES_TO_FETCH
        )
        articles = top_headlines.get('articles', [])
        valid_articles = [
            {
                "title": article.get("title"),
                "content": article.get("content") or article.get("description", ""),
                "url": article.get("url")
            }
            for article in articles if article.get("content") or article.get("description")
        ][:MAX_ARTICLES_TO_PROCESS]  # Limit here
        print(f"Fetched {len(valid_articles)} valid articles.")
        return valid_articles
    except Exception as e:
        print(f"Error fetching news: {e}")
        return []
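# Caveat: NewsAPI's free/developer plan truncates the `content` field (typically
# to ~200 characters plus a "[+N chars]" marker), so `description` often carries
# as much usable text; both are accepted above for that reason.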
def chunk_text(text: str, size: int) -> List[str]:
    """Splits text into roughly `size`-character chunks, preferring sentence boundaries."""
    chunks = []
    start = 0
    while start < len(text):
        end = start + size
        # Prefer to end the chunk at a sentence boundary ('.') near the target size.
        pos = text.rfind('.', start, min(end + 50, len(text)))
        if pos != -1 and pos > start + size // 2:
            end = pos + 1
        chunks.append(text[start:end].strip())
        start = end
    return [chunk for chunk in chunks if chunk]
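# Behavior note: rfind picks the LAST '.' within [start, start + size + 50), so a
# chunk may overshoot `size` by up to ~50 characters to land on a sentence
# boundary; if no period falls past the chunk's halfway point, the cut happens at
# exactly `size` characters, possibly mid-word.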
def build_vector_store(articles: List[Dict[str, Any]], model: SentenceTransformer):
    """Creates embeddings and builds an in-memory FAISS index."""
    if model is None:
        print("Embedding model not loaded. Cannot build vector store.")
        return None, [], []
    print("Building vector store...")
    all_chunks = []
    metadata = []
    for i, article in enumerate(articles):
        if article.get('content'):
            chunks = chunk_text(article['content'], CHUNK_SIZE)
            for chunk in chunks:
                all_chunks.append(chunk)
                metadata.append({"article_index": i, "url": article.get('url'), "title": article.get('title')})
    if not all_chunks:
        print("No text content found to build vector store.")
        return None, [], []
    print(f"Generated {len(all_chunks)} chunks. Creating embeddings...")
    try:
        embeddings = model.encode(all_chunks, show_progress_bar=False)  # Progress bar can be messy in logs
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(np.array(embeddings).astype('float32'))
        print("Vector store built successfully.")
        return index, all_chunks, metadata
    except Exception as e:
        print(f"Error creating embeddings or FAISS index: {e}")
        return None, [], []
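# Design note: IndexFlatL2 performs exact (brute-force) L2 search, which is plenty
# at this scale (at most a few dozen chunks). For cosine similarity one could
# instead L2-normalize the embeddings and use faiss.IndexFlatIP; approximate
# indexes only pay off at much larger corpus sizes.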
def retrieve_context(query: str, index: faiss.Index, chunks: List[str], metadata: List[Dict], model: SentenceTransformer, top_k: int) -> Union[str, Tuple[str, List[str]]]:
    """Retrieves the most relevant text chunks.

    Returns a (context, source_urls) tuple on success, or a plain error string
    on failure; the caller distinguishes the two cases by type.
    """
    if model is None or index is None or index.ntotal == 0:
        return "No relevant context found (vector store/model unavailable)."
    print(f"Retrieving top {top_k} relevant chunks for query: '{query}'...")
    try:
        query_embedding = model.encode([query], show_progress_bar=False)
        query_embedding_np = np.array(query_embedding).astype('float32')
        _distances, indices = index.search(query_embedding_np, min(top_k, index.ntotal))  # Ensure k <= index.ntotal
        context_parts = []
        seen_urls = set()
        retrieved_sources = []  # Track sources used in context
        for idx in indices[0]:
            if 0 <= idx < len(chunks):
                chunk = chunks[idx]  # Renamed from `chunk_text` to avoid shadowing the helper above
                meta = metadata[idx]
                source_info = f"(Source: {meta.get('url', 'N/A')})"
                if meta.get('url') and meta['url'] not in seen_urls:
                    full_info = f"From '{meta.get('title', 'Untitled')}':\n{chunk}\n{source_info}"
                    seen_urls.add(meta['url'])
                    retrieved_sources.append(meta['url'])
                else:
                    # URL missing or already cited: skip the title header and source bookkeeping
                    full_info = f"{chunk}\n{source_info}"
                context_parts.append(full_info)
        if not context_parts:
            return "No relevant context found matching the query."
        print(f"Retrieved {len(context_parts)} context parts.")
        # Return the joined context and the unique source URLs used in it
        return "\n\n".join(context_parts), list(set(retrieved_sources))
    except Exception as e:
        print(f"Error during context retrieval: {e}")
        return "Error retrieving context."  # Plain string signals failure to the caller
def generate_structured_summary(context: str, topic: str) -> Optional[Dict[str, Any]]:
    """Generates a summary using Gemini with structured JSON output."""
    if not GOOGLE_API_KEY:
        print("Google API Key missing. Cannot generate summary.")
        return None
    print("Generating structured summary with LLM...")
    response = None  # Defined up front so the except block can reference it safely
    try:
        model = genai.GenerativeModel(LLM_MODEL_NAME)
        json_schema = {
            "type": "object",
            "properties": {
                "topic": {"type": "string"},
                "summary_points": {"type": "array", "items": {"type": "string"}},
                "mentioned_sources": {"type": "array", "items": {"type": "string", "format": "uri"}}
            },
            "required": ["topic", "summary_points", "mentioned_sources"]
        }
        prompt = f"""
Analyze the following retrieved context about '{topic}'. Create a concise summary highlighting the key information.
Extract the main points and list the unique source URLs mentioned ONLY in the provided context below.
Respond ONLY with a valid JSON object matching this schema:

Schema:
{json.dumps(json_schema, indent=2)}

Retrieved Context:
---
{context}
---

JSON Output:
"""
        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                response_mime_type="application/json"
            )
        )
        summary_json = json.loads(response.text)
        print("LLM generation successful.")
        return summary_json
    except Exception as e:
        print(f"Error during LLM generation or JSON parsing: {e}")
        try:
            # Log the raw response if available, for debugging
            if response is not None:
                print(f"LLM Raw Response Text (if available): {response.text}")
        except Exception:
            pass
        return None
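# Illustrative output shape, matching the schema above (values are invented):
# {
#   "topic": "AI in healthcare",
#   "summary_points": ["Hospitals are piloting...", "Regulators have proposed..."],
#   "mentioned_sources": ["https://example.com/article-1"]
# }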
# --- Main Gradio Function ---
def summarize_news_interface(topic: str) -> Union[Dict, str]:
    """Orchestrates the news summarization process for the Gradio interface."""
    print(f"\n--- Processing request for topic: {topic} ---")
    if not topic:
        return {"error": "Please enter a topic."}
    if not NEWS_API_KEY or not GOOGLE_API_KEY:
        return {"error": "API Key secrets are not configured correctly in this Space."}
    if embedding_model is None:
        return {"error": "Embedding model could not be loaded. RAG is disabled."}

    # 1. Fetch News
    articles = fetch_news(topic)
    if not articles:
        return {"error": f"Could not fetch any news articles for '{topic}'. Please try a different topic or check the NewsAPI key."}

    # 2. Build Vector Store (RAG - Embeddings & Indexing)
    vector_index, text_chunks, chunk_metadata = build_vector_store(articles, embedding_model)
    if vector_index is None:
        return {"error": "Could not build vector store (likely no usable article content). RAG step failed."}

    # 3. Retrieve Relevant Context (RAG - Retrieval)
    context_result = retrieve_context(topic, vector_index, text_chunks, chunk_metadata, embedding_model, TOP_K_CHUNKS)
    # retrieve_context returns a (context, sources) tuple on success, or an error string on failure
    if isinstance(context_result, tuple):
        retrieved_context, sources_in_context = context_result
        print(f"Context retrieved successfully. Sources in context: {len(sources_in_context)}")
    else:
        # A better approach might summarize the top articles directly; for simplicity, return an error JSON.
        print(f"Context retrieval issue: {context_result}")
        return {"error": "Failed to retrieve relevant context via RAG.", "details": context_result}

    # 4. Generate Structured Summary (Document Understanding + Structured Output)
    # The prompt asks the LLM to extract source URLs from the retrieved context itself.
    summary_output = generate_structured_summary(retrieved_context, topic)
    if summary_output:
        # summary_output['mentioned_sources'] = sources_in_context  # Optional override with the retrieved sources
        print("--- Request processing complete ---")
        return summary_output
    else:
        print("--- Request processing failed at LLM step ---")
        return {"error": "Failed to generate summary using the LLM.", "details": "Check logs for potential API errors or LLM issues."}
# --- Gradio Interface Definition ---
demo = gr.Interface(
    fn=summarize_news_interface,
    inputs=gr.Textbox(
        label="Enter News Topic",
        placeholder="e.g., latest advancements in renewable energy, Premier League results, space exploration updates..."
    ),
    outputs=gr.JSON(label="News Digest Summary"),
    title="📰 AI News Digest Generator",
    description=(
        "Enter a topic to get a structured summary of recent news articles.\n"
        "This app uses RAG (Retrieval-Augmented Generation) with FAISS/SentenceTransformers "
        "and Google Gemini for summarization.\n"
    ),
    examples=[
        ["AI in healthcare"],
        ["Electric vehicle market trends"],
        ["Recent archaeological discoveries"]
    ],
    allow_flagging='never',
    # theme=gr.themes.Soft()  # Optional: adds a theme
)
# --- Launch the App ---
if __name__ == "__main__":
    # Check for keys on launch locally (won't hurt on Spaces)
    if not NEWS_API_KEY or not GOOGLE_API_KEY:
        print("\n*** WARNING: API Keys not found as environment variables. ***")
        print("*** Please set NEWS_API_KEY and GEMINI_API_KEY if running locally. ***")
        print("*** In Hugging Face Spaces, set them as Secrets in Settings. ***\n")
    elif embedding_model is None:
        print("\n*** WARNING: Embedding model failed to load. RAG features will not work. ***\n")
    demo.launch()
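# Local usage sketch (assuming valid keys; env var names match the os.getenv calls above):
#   NEWS_API_KEY=... GEMINI_API_KEY=... python app.py
# Gradio serves on http://127.0.0.1:7860 by default.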