|
|
import gradio as gr |
|
|
import google.generativeai as genai |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import chromadb |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from retry import retry |
|
|
import os |
|
|
import json |
|
|
|
|
|
|
|
|
# Configure the Gemini SDK from the environment; fail fast when the key
# is absent so the app does not start half-configured.
api_key = os.environ.get("GEMINI_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("GEMINI_API_KEY environment variable not set")
genai.configure(api_key=api_key)
|
|
|
|
|
|
|
|
# Small in-memory corpus of climate-related news snippets; this is the
# entire retrieval source for the RAG pipeline below.
articles = [
    "Climate change accelerates, with 2024 as the hottest year. Rising sea levels threaten coastal cities.",
    "Renewable energy grows, but fossil fuels dominate in developing nations.",
    "UN report urges action to cut greenhouse gas emissions by 2030.",
    "Extreme weather like hurricanes and wildfires is linked to climate change.",
    "Amazon deforestation slows, but illegal logging harms carbon sinks.",
    "Electric vehicle adoption rises, reducing emissions globally.",
    "Coral reefs face bleaching from rising ocean temperatures."
]
|
|
|
|
|
|
|
|
# Gemini embedding model id used for all embed_content calls below.
embedding_model = "models/embedding-001"
# Corpus held as a DataFrame; an "embedding" column is added further down.
df = pd.DataFrame({"article": articles})
|
|
|
|
|
@retry(tries=3, delay=2, backoff=2)
def get_embedding(text, task_type="RETRIEVAL_DOCUMENT"):
    """Embed *text* with the Gemini embedding API.

    Args:
        text: String to embed.
        task_type: Gemini embedding task type. Defaults to
            "RETRIEVAL_DOCUMENT" (pass "RETRIEVAL_QUERY" when embedding
            a search query), keeping existing callers unchanged.

    Returns:
        The embedding vector (list of floats).

    Raises:
        Exception: re-raised after logging so @retry can back off and
            retry up to 3 times.
    """
    try:
        result = genai.embed_content(model=embedding_model, content=text, task_type=task_type)
        # BUG FIX: genai.embed_content returns a dict, not an object with
        # attributes — `result.embedding` raised AttributeError on every
        # call; the vector lives under the "embedding" key.
        return result["embedding"]
    except Exception as e:
        print(f"Embedding error: {e}")
        raise
|
|
|
|
|
|
|
|
# Embed every article once at startup (one API call per row).
df["embedding"] = df["article"].apply(get_embedding)

# In-process (non-persistent) Chroma client; the collection is recreated
# on every run of the script.
client_db = chromadb.Client()
collection = client_db.get_or_create_collection("news_articles")

# Remove any leftover entries with the same ids so re-runs in the same
# process don't raise duplicate-id errors; best-effort only.
try:
    collection.delete(ids=[str(i) for i in range(len(df))])
except Exception:
    pass

# Insert each article with its precomputed embedding. Ids are the
# stringified DataFrame row positions, which search_articles() relies on
# to map results back to rows.
for idx, row in df.iterrows():
    try:
        collection.add(
            documents=[row["article"]],
            embeddings=[row["embedding"]],
            ids=[str(idx)]
        )
    except Exception as e:
        print(f"Error adding document {idx}: {e}")
|
|
|
|
|
|
|
|
@retry(tries=3, delay=2, backoff=2)
def search_articles(query, top_k=3):
    """Retrieve the *top_k* articles semantically closest to *query*.

    Args:
        query: Free-text search query.
        top_k: Number of nearest articles to return (default 3).

    Returns:
        list[str]: Matching article texts, best match first; empty list
        when nothing matches or any step fails.
    """
    try:
        # NOTE(review): the query is embedded with the default document
        # task type; Gemini also offers "RETRIEVAL_QUERY" for queries,
        # which may improve retrieval quality — worth confirming.
        query_embedding = get_embedding(query)
        results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
        matched_ids = results["ids"][0]
        if not matched_ids:
            return []
        # Chroma ids were stored as stringified DataFrame row positions;
        # map them back to rows. (Loop variable renamed so it no longer
        # shadows the builtin `id`.)
        row_positions = [int(doc_id) for doc_id in matched_ids]
        return df.iloc[row_positions]["article"].tolist()
    except Exception as e:
        print(f"Search error: {e}")
        return []
|
|
|
|
|
|
|
|
# Gemini chat model used to draft the summary and structured JSON answer.
generation_model = genai.GenerativeModel("gemini-1.5-pro")
|
|
|
|
|
def _extract_summary(full_text):
    """Return the text between '- Summary:' and '- JSON:' in the model reply."""
    summary = "Summary not generated."
    if "- Summary:" in full_text:
        summary_start = full_text.find("- Summary:") + len("- Summary:")
        summary_end = full_text.find("- JSON:", summary_start)
        # find() returns -1 when '- JSON:' is absent, so the default
        # placeholder is kept in that case.
        if summary_end > summary_start:
            summary = full_text[summary_start:summary_end].strip()
    return summary


def _strip_markdown_fences(text):
    """Remove a wrapping Markdown code fence (``` or ```json) from *text*."""
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line, which may carry a language tag.
        text = text.split("\n", 1)[1] if "\n" in text else text[3:]
    if text.rstrip().endswith("```"):
        text = text.rstrip()[:-3]
    return text.strip()


def _extract_json(full_text):
    """Return the pretty-printed JSON after '- JSON:', or an error payload."""
    if "- JSON:" not in full_text:
        return "{}"
    json_start = full_text.find("- JSON:") + len("- JSON:")
    qa_json_text = full_text[json_start:].strip()
    # BUG FIX: the old code did replace("``````", ""), which never matches a
    # real Markdown fence such as ```json ... ``` — so fenced output always
    # failed to parse. Strip opening/closing fences properly instead.
    qa_json_text = _strip_markdown_fences(qa_json_text)
    try:
        qa = json.loads(qa_json_text)
        return json.dumps(qa, indent=2)
    except json.JSONDecodeError:
        return json.dumps({"error": "Failed to parse JSON response.", "raw_text": qa_json_text})


@retry(tries=3, delay=2, backoff=2)
def generate_response(query, articles, system_message):
    """Generate a summary and structured JSON answer from retrieved articles.

    Args:
        query: The user's question.
        articles: Retrieved article texts used as the only context.
        system_message: System prompt prepended to the generation prompt.

    Returns:
        tuple[str, str]: (summary text, JSON string). On failure both
        elements describe the error instead of raising to the caller.
    """
    if not articles:
        return "No relevant articles found.", json.dumps({"error": "No relevant articles found."})

    context = "\n".join(articles)
    prompt = f"""
{system_message}
Based on the following articles, provide a concise summary (under 100 words) and a structured JSON response with 'question', 'answer', and 'source'. Use only the provided context.

Articles:
{context}

Query: {query}

Response:
- Summary:
- JSON:
"""
    try:
        response = generation_model.generate_content(
            prompt,
            generation_config={
                "temperature": 0.7,
                "top_p": 0.95,
                "max_output_tokens": 1024,
            },
            stream=False
        )
        full_text = response.text
        return _extract_summary(full_text), _extract_json(full_text)
    except Exception as e:
        print(f"RAG error: {e}")
        return f"Error generating response: {str(e)}", json.dumps({"error": f"Failed to generate response: {str(e)}"})
|
|
|
|
|
def respond(message, history, system_message="You are a news summarizer and Q&A assistant.", max_tokens=512, temperature=0.7, top_p=0.95):
    """Gradio chat handler: retrieve, generate, and yield one formatted reply.

    Note: max_tokens, temperature and top_p are bound to the UI sliders
    but are not currently forwarded to the generation call.
    """
    # Renamed local (was `articles`) so it no longer shadows the
    # module-level corpus list.
    relevant = search_articles(message)
    summary, qa = generate_response(message, relevant, system_message)

    if relevant:
        articles_text = "\n".join(f"- {item}" for item in relevant)
    else:
        articles_text = "None found"

    sections = [
        "**Relevant Articles:**",
        articles_text,
        "",
        "**Summary:**",
        summary,
        "",
        "**Structured Q&A:**",
        qa,
    ]
    yield "\n".join(sections)
|
|
|
|
|
|
|
|
# Gradio chat UI. The additional inputs map positionally onto respond()'s
# keyword parameters (system_message, max_tokens, temperature, top_p).
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a news summarizer and Q&A assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Semantic News Summarizer and Q&A Chatbot",
    description="Ask about climate change (e.g., 'What are the latest impacts?') to get articles, a summary, and a structured JSON response. Built for the 5-day Gen AI Intensive Course with Google (March 31–April 4, 2025)."
)
|
|
|
|
|
# Launch the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|