# entity-analysis / app.py
# Lakshitha-Vithanage's Hugging Face Space — "Create app.py" (commit 43ea66b, verified)
import gradio as gr
import os
import json
import pandas as pd
import requests
from google.oauth2 import service_account
from google.cloud import language_v1
import tempfile
from collections import defaultdict
# --- Configuration & Authentication ---
# Credentials arrive as a service-account JSON blob in the GOOGLE_API_KEY_JSON
# secret. Failures are recorded (not raised) so the UI can surface the message
# later via analyze_text instead of crashing the Space at import time.
api_key_json_string = os.environ.get("GOOGLE_API_KEY_JSON")
language_client = None
auth_success = False
auth_error_message = ""

if not api_key_json_string:
    auth_error_message = "ERROR: The `GOOGLE_API_KEY_JSON` secret is not set..."
else:
    try:
        credentials_info = json.loads(api_key_json_string)
        credentials = service_account.Credentials.from_service_account_info(credentials_info)
        language_client = language_v1.LanguageServiceClient(credentials=credentials)
        auth_success = True
    except Exception as e:
        # Broad catch is deliberate at this startup boundary: any auth problem
        # (bad JSON, bad key, client error) becomes a user-visible message.
        auth_error_message = f"ERROR: Failed to authenticate with Google Cloud: {e}"
# --- Color Mapping for Highlighting ---
# A fixed colour per entity type keeps the HighlightedText legend stable
# between runs.
ENTITY_TYPE_COLORS = {
    "PERSON": "#ffc107",         # Amber
    "LOCATION": "#f44336",       # Red
    "ORGANIZATION": "#2196f3",   # Blue
    "EVENT": "#4caf50",          # Green
    "WORK_OF_ART": "#9c27b0",    # Purple
    "CONSUMER_GOOD": "#ff9800",  # Orange
    "OTHER": "#9e9e9e",          # Grey
    "PHONE_NUMBER": "#795548",   # Brown
    "ADDRESS": "#607d8b",        # Blue Grey
    "DATE": "#00bcd4",           # Cyan
    "NUMBER": "#cddc39",         # Lime
    "PRICE": "#e91e63",          # Pink
}
# --- Helper Functions ---
def find_wikidata_entity(entity_name: str):
    """Search the Wikidata API for *entity_name* and return the item URL.

    Returns the URL of the top-ranked match, or None when there is no match
    or the lookup fails (network error, timeout, malformed response).
    Failures are deliberately swallowed: a missing Wikidata link is cosmetic
    and must not abort the whole analysis.
    """
    params = {
        "action": "wbsearchentities",
        "format": "json",
        "language": "en",
        "search": entity_name,
    }
    try:
        response = requests.get("https://www.wikidata.org/w/api.php", params=params, timeout=3)
        response.raise_for_status()
        # FIX: .json() raises ValueError (json.JSONDecodeError) on non-JSON
        # bodies, which the original except tuple did not cover.
        data = response.json()
        if data.get("search"):
            return f"https://www.wikidata.org/wiki/{data['search'][0]['id']}"
    except (requests.RequestException, ValueError, KeyError, IndexError):
        return None
    return None
def create_entity_dataframe(entities):
    """Convert processed entity dicts into a pandas DataFrame for display.

    Each entity dict carries 'name', 'salience' (float) and three optional
    link URLs ('google_kg_url', 'wikipedia_url', 'wikidata_url'). Missing
    links render as "N/A"; present ones as markdown links. Returns an empty
    DataFrame when *entities* is empty.
    """
    if not entities:
        return pd.DataFrame()

    def _link(url, label):
        # Markdown link, or the N/A placeholder when the URL is absent.
        return f"[{label}]({url})" if url else "N/A"

    rows = [
        {
            "Entity": entity['name'],
            "Salience": f"{entity['salience']:.3f}",
            "Google KG": _link(entity['google_kg_url'], "Search"),
            "Wikipedia": _link(entity['wikipedia_url'], "Link"),
            "Wikidata": _link(entity['wikidata_url'], "Link"),
        }
        for entity in entities
    ]
    return pd.DataFrame(rows)
def format_text_for_highlighting(text, entities):
    """Turn API entities into the (span, label) list gr.HighlightedText needs.

    Produces an ordered list of (text_fragment, label_or_None) tuples covering
    the whole input; unlabelled gaps get a None label.

    NOTE(review): mention.text.begin_offset is used directly as an index into
    the Python str. That only holds when the API call used an encoding whose
    offsets count code points (UTF32) or the text is ASCII — confirm the
    encoding_type used by the caller.
    """
    mentions = []
    for entity in entities:
        entity_type_name = language_v1.Entity.Type(entity.type_).name
        for mention in entity.mentions:
            mentions.append({
                "start": mention.text.begin_offset,
                "end": mention.text.begin_offset + len(mention.text.content),
                "text": mention.text.content,
                "label": entity_type_name
            })
    mentions.sort(key=lambda x: x['start'])

    highlighted_data = []
    last_index = 0
    for mention in mentions:
        # FIX: skip mentions overlapping an already-emitted span; the
        # original appended them anyway, duplicating text in the output.
        if mention['start'] < last_index:
            continue
        if mention['start'] > last_index:
            highlighted_data.append((text[last_index:mention['start']], None))
        highlighted_data.append((mention['text'], mention['label']))
        last_index = mention['end']
    if last_index < len(text):
        highlighted_data.append((text[last_index:], None))
    return highlighted_data
# --- Core Logic Functions ---
def analyze_text(text_input: str):
    """Run Google NL entity analysis and prepare every UI output.

    Returns a 4-tuple: (summary string, entities grouped by type, highlight
    data for HighlightedText, visibility update for the results section).
    Raises gr.Error when authentication failed at startup or the API call
    fails.
    """
    if not auth_success:
        raise gr.Error(auth_error_message)
    if not text_input or not text_input.strip():
        # Nothing to analyse: keep the results section hidden.
        return "Please enter text to analyze.", {}, None, gr.update(visible=False)
    try:
        document = language_v1.Document(content=text_input, type_=language_v1.Document.Type.PLAIN_TEXT)
        # FIX: UTF32 (not UTF8) so begin_offset counts Unicode code points,
        # matching Python str indexing in format_text_for_highlighting.
        # UTF8 offsets are byte offsets and corrupt highlighting for any
        # non-ASCII input.
        encoding_type = language_v1.EncodingType.UTF32
        response = language_client.analyze_entities(request={"document": document, "encoding_type": encoding_type})

        highlight_data = format_text_for_highlighting(text_input, response.entities)

        all_entities = defaultdict(list)
        for entity in response.entities:
            google_kg_url = f"https://www.google.com/search?kgmid={entity.metadata['mid']}" if 'mid' in entity.metadata else None
            wikipedia_url = entity.metadata.get("wikipedia_url", None)
            wikidata_url = find_wikidata_entity(entity.name)
            # How many external knowledge bases recognise this entity; used
            # by the "prioritize identified" sort in sort_and_update_ui.
            link_count = sum(1 for link in [google_kg_url, wikipedia_url, wikidata_url] if link)
            entity_type_name = language_v1.Entity.Type(entity.type_).name
            all_entities[entity_type_name].append({
                'name': entity.name, 'type': entity_type_name, 'salience': entity.salience,
                'google_kg_url': google_kg_url, 'wikipedia_url': wikipedia_url,
                'wikidata_url': wikidata_url, 'link_count': link_count
            })
        summary = f"Analysis complete. Found {len(response.entities)} total entities across {len(all_entities)} types."
        # Return an update to make the results section visible
        return summary, all_entities, highlight_data, gr.update(visible=True)
    except Exception as e:
        raise gr.Error(f"An error occurred during API call: {e}")
def sort_and_update_ui(all_entities, prioritize_identified):
    """Build the interleaved [accordion, dataframe, ...] updates per type.

    all_entities: dict mapping entity-type name -> list of entity dicts
    (as produced by analyze_text, held in gr.State).
    prioritize_identified: when True, sort by (link_count, salience) so
    entities recognised by external knowledge bases float to the top;
    otherwise sort by salience alone. Types with no entities get a hidden
    accordion.
    """
    display_order = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "WORK_OF_ART", "CONSUMER_GOOD", "OTHER", "PHONE_NUMBER", "ADDRESS", "DATE", "NUMBER", "PRICE"]
    if prioritize_identified:
        sort_key = lambda x: (x['link_count'], x['salience'])
    else:
        sort_key = lambda x: x['salience']
    updates = []
    for entity_type in display_order:
        # FIX: sorted() copy instead of list.sort() — the original mutated
        # the lists held inside gr.State in place on every checkbox toggle.
        entities_of_type = sorted(all_entities.get(entity_type, []), key=sort_key, reverse=True)
        df = create_entity_dataframe(entities_of_type)
        is_visible = len(entities_of_type) > 0
        accordion_label = f"{entity_type.replace('_', ' ')} ({len(entities_of_type)} entities)"
        updates.append(gr.Accordion(label=accordion_label, visible=is_visible))
        updates.append(df)
    return updates
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Advanced Entity Analyzer")

    # Holds the per-type entity data between the analyze step and the sort step.
    analysis_results_state = gr.State({})

    with gr.Row():
        text_input = gr.Textbox(lines=20, label="Enter Text for Analysis", placeholder="Paste your article content here...")
    analyze_button = gr.Button("Analyze Text", variant="primary")

    # Use gr.Column to group results, and make it initially invisible
    with gr.Column(visible=False) as results_section:
        gr.Markdown("### Analysis Results")
        summary_output = gr.Textbox(label="Summary", interactive=False)
        with gr.Tabs():
            with gr.TabItem("Visual Analysis"):
                highlighted_text_output = gr.HighlightedText(
                    label="Highlighted Entities",
                    color_map=ENTITY_TYPE_COLORS,
                    show_legend=True
                )
            with gr.TabItem("Detailed Breakdown"):
                prioritize_checkbox = gr.Checkbox(label="Prioritize identified entities (with links)", value=False)
                accordions = {}
                dataframes = {}
                all_types = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "WORK_OF_ART", "CONSUMER_GOOD", "OTHER", "PHONE_NUMBER", "ADDRESS", "DATE", "NUMBER", "PRICE"]
                for entity_type in all_types:
                    with gr.Accordion(f"{entity_type.replace('_', ' ')}", visible=False) as acc:
                        accordions[entity_type] = acc
                        df = gr.Dataframe(
                            headers=["Entity", "Salience", "Google KG", "Wikipedia", "Wikidata"],
                            datatype=["str", "str", "markdown", "markdown", "markdown"],
                            wrap=True
                        )
                        dataframes[entity_type] = df

    # Interleave accordion/dataframe components in the exact order
    # sort_and_update_ui emits its updates.
    all_detailed_outputs = [item for pair in zip(accordions.values(), dataframes.values()) for item in pair]

    analyze_button.click(
        fn=analyze_text,
        inputs=[text_input],
        outputs=[summary_output, analysis_results_state, highlighted_text_output, results_section]
    ).then(
        fn=sort_and_update_ui,
        inputs=[analysis_results_state, prioritize_checkbox],
        outputs=all_detailed_outputs
    )
    prioritize_checkbox.change(
        fn=sort_and_update_ui,
        inputs=[analysis_results_state, prioritize_checkbox],
        outputs=all_detailed_outputs
    )

if __name__ == "__main__":
    iface.launch()