Kakaarot's picture
Update app.py
dc2e285 verified
raw
history blame
6.33 kB
import gradio as gr
import google.generativeai as genai
import numpy as np
import pandas as pd
import chromadb
from sklearn.metrics.pairwise import cosine_similarity
from retry import retry
import os
import json
# --- API key validation ---
# Fail fast at startup if the Gemini credential is not configured.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    raise ValueError("GEMINI_API_KEY environment variable not set")
genai.configure(api_key=api_key)
# --- Precomputed corpus ---
# Small in-memory set of climate-news snippets the app retrieves over.
articles = [
    "Climate change accelerates, with 2024 as the hottest year. Rising sea levels threaten coastal cities.",
    "Renewable energy grows, but fossil fuels dominate in developing nations.",
    "UN report urges action to cut greenhouse gas emissions by 2030.",
    "Extreme weather like hurricanes and wildfires is linked to climate change.",
    "Amazon deforestation slows, but illegal logging harms carbon sinks.",
    "Electric vehicle adoption rises, reducing emissions globally.",
    "Coral reefs face bleaching from rising ocean temperatures."
]
# Gemini embedding model used for both documents and queries.
embedding_model = "models/embedding-001"
# One row per article; an "embedding" column is added below at startup.
df = pd.DataFrame({"article": articles})
@retry(tries=3, delay=2, backoff=2)
def get_embedding(text):
    """Embed a single text with the Gemini embedding model.

    Args:
        text: Document (or query) string to embed.

    Returns:
        list[float]: The embedding vector for ``text``.

    Raises:
        Exception: Any API error is logged and re-raised so the
            @retry decorator can retry (3 tries, exponential backoff).
    """
    try:
        result = genai.embed_content(
            model=embedding_model,
            content=text,
            task_type="RETRIEVAL_DOCUMENT",
        )
        # BUG FIX: genai.embed_content returns a dict, not an object with
        # attributes — the vector lives under the "embedding" key.
        # `result.embedding` raised AttributeError on every call.
        return result["embedding"]
    except Exception as e:
        print(f"Embedding error: {e}")
        raise
# Embed every article once at startup (stored as a list-of-floats column).
df["embedding"] = df["article"].apply(get_embedding)
# --- Vector store setup ---
# In-memory ChromaDB instance; the index is rebuilt on every process start.
client_db = chromadb.Client()
collection = client_db.get_or_create_collection("news_articles")
# Clear any pre-existing rows so re-running in the same process does not
# produce duplicate ids (get_or_create_collection may return a populated one).
try:
    collection.delete(ids=[str(i) for i in range(len(df))])
except Exception:
    pass  # Collection might be empty
# Index each article, using its DataFrame row index as the document id
# (search_articles relies on this id -> row-index mapping).
for idx, row in df.iterrows():
    try:
        collection.add(
            documents=[row["article"]],
            embeddings=[row["embedding"]],
            ids=[str(idx)]
        )
    except Exception as e:
        print(f"Error adding document {idx}: {e}")
# Semantic Search
@retry(tries=3, delay=2, backoff=2)
def search_articles(query, top_k=3):
    """Return up to ``top_k`` articles most similar to ``query``.

    Embeds the query, queries the ChromaDB collection, and maps the
    returned string ids back to rows of the module-level ``df``.
    Returns an empty list when nothing matches or any error occurs.
    (Note: because all exceptions are caught here, the @retry decorator
    never actually sees a failure.)
    """
    try:
        query_vector = get_embedding(query)
        hits = collection.query(query_embeddings=[query_vector], n_results=top_k)
        matched_ids = hits["ids"][0]
        if not matched_ids:
            return []
        row_positions = [int(doc_id) for doc_id in matched_ids]
        return df.iloc[row_positions]["article"].tolist()
    except Exception as exc:
        print(f"Search error: {exc}")
        return []
# --- RAG and structured Q&A ---
# Chat model used to summarize retrieved articles and emit structured JSON.
generation_model = genai.GenerativeModel("gemini-1.5-pro")
def _strip_code_fences(text):
    """Remove a surrounding markdown code fence (``` or ```json) from text."""
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        # The opening fence may carry a language tag, e.g. ```json
        if text.lstrip().lower().startswith("json"):
            text = text.lstrip()[4:]
    if text.rstrip().endswith("```"):
        text = text.rstrip()[:-3]
    return text.strip()


@retry(tries=3, delay=2, backoff=2)
def generate_response(query, articles, system_message):
    """Generate a summary and structured JSON answer from retrieved articles.

    Args:
        query: The user's question.
        articles: Retrieved article strings used as grounding context.
        system_message: System prompt prepended to the generation request.

    Returns:
        tuple[str, str]: (summary text, JSON string). On failure the first
        element is an error message and the second a JSON error payload.
    """
    if not articles:
        return "No relevant articles found.", json.dumps({"error": "No relevant articles found."})
    context = "\n".join(articles)
    prompt = f"""
{system_message}
Based on the following articles, provide a concise summary (under 100 words) and a structured JSON response with 'question', 'answer', and 'source'. Use only the provided context.
Articles:
{context}
Query: {query}
Response:
- Summary:
- JSON:
"""
    try:
        response = generation_model.generate_content(
            prompt,
            generation_config={
                "temperature": 0.7,
                "top_p": 0.95,
                "max_output_tokens": 1024,
            },
            stream=False
        )
        full_text = response.text
        # Extract the "- Summary:" section (everything up to "- JSON:").
        summary = "Summary not generated."
        if "- Summary:" in full_text:
            summary_start = full_text.find("- Summary:") + len("- Summary:")
            summary_end = full_text.find("- JSON:", summary_start)
            if summary_end > summary_start:
                summary = full_text[summary_start:summary_end].strip()
        # Extract and parse the "- JSON:" section.
        qa_json = "{}"
        if "- JSON:" in full_text:
            json_start = full_text.find("- JSON:") + len("- JSON:")
            qa_json_text = full_text[json_start:].strip()
            # BUG FIX: the original called replace("``````", ""), removing a
            # literal six-backtick string that never occurs. Models typically
            # wrap JSON in ```json ... ``` fences, so strip those properly
            # before parsing.
            qa_json_text = _strip_code_fences(qa_json_text)
            try:
                qa = json.loads(qa_json_text)
                qa_json = json.dumps(qa, indent=2)
            except json.JSONDecodeError:
                qa_json = json.dumps({"error": "Failed to parse JSON response.", "raw_text": qa_json_text})
        return summary, qa_json
    except Exception as e:
        print(f"RAG error: {e}")
        return f"Error generating response: {str(e)}", json.dumps({"error": f"Failed to generate response: {str(e)}"})
def respond(message, history, system_message="You are a news summarizer and Q&A assistant.", max_tokens=512, temperature=0.7, top_p=0.95):
    """Gradio chat handler: retrieve articles, summarize, and yield one reply.

    NOTE(review): max_tokens, temperature and top_p are accepted to match
    the ChatInterface sliders but are not currently forwarded to the model —
    generation settings are hard-coded inside generate_response.
    """
    matched = search_articles(message)
    summary, qa = generate_response(message, matched, system_message)
    # Render retrieved articles as a bullet list, or a placeholder if none.
    if matched:
        articles_text = "\n".join(f"- {article}" for article in matched)
    else:
        articles_text = "None found"
    yield (
        "**Relevant Articles:**\n"
        f"{articles_text}\n\n"
        "**Summary:**\n"
        f"{summary}\n\n"
        "**Structured Q&A:**\n"
        f"{qa}"
    )
# --- Gradio UI ---
# Chat interface; additional_inputs map positionally onto respond()'s
# system_message, max_tokens, temperature and top_p parameters.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a news summarizer and Q&A assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Semantic News Summarizer and Q&A Chatbot",
    description="Ask about climate change (e.g., 'What are the latest impacts?') to get articles, a summary, and a structured JSON response. Built for the 5-day Gen AI Intensive Course with Google (March 31–April 4, 2025)."
)

if __name__ == "__main__":
    # Launch the local Gradio web server when run as a script.
    demo.launch()