# Hugging Face Space: Advanced Entity Analyzer (Gradio app)
import json
import os
import tempfile
from collections import defaultdict

import gradio as gr
import pandas as pd
import requests
from google.cloud import language_v1
from google.oauth2 import service_account
# --- Configuration & Authentication ---
# Credentials are supplied as a service-account JSON key string in the
# GOOGLE_API_KEY_JSON environment variable (a Spaces secret).
api_key_json_string = os.environ.get("GOOGLE_API_KEY_JSON")

language_client = None   # LanguageServiceClient, set only on successful auth
auth_success = False     # checked by analyze_text() before every API call
auth_error_message = ""  # surfaced to the user via gr.Error when auth failed

if not api_key_json_string:
    auth_error_message = "ERROR: The `GOOGLE_API_KEY_JSON` secret is not set..."
else:
    try:
        credentials_info = json.loads(api_key_json_string)
        credentials = service_account.Credentials.from_service_account_info(credentials_info)
        language_client = language_v1.LanguageServiceClient(credentials=credentials)
        auth_success = True
    except Exception as e:
        # Broad catch is deliberate: covers malformed JSON as well as
        # credentials rejected by Google Cloud; the message is shown in the UI.
        auth_error_message = f"ERROR: Failed to authenticate with Google Cloud: {e}"
# --- Color Mapping for Highlighting ---
# One fixed color per entity type, used as the color_map of the
# gr.HighlightedText component so the legend stays consistent across runs.
ENTITY_TYPE_COLORS = {
    "PERSON": "#ffc107",         # Amber
    "LOCATION": "#f44336",       # Red
    "ORGANIZATION": "#2196f3",   # Blue
    "EVENT": "#4caf50",          # Green
    "WORK_OF_ART": "#9c27b0",    # Purple
    "CONSUMER_GOOD": "#ff9800",  # Orange
    "OTHER": "#9e9e9e",          # Grey
    "PHONE_NUMBER": "#795548",   # Brown
    "ADDRESS": "#607d8b",        # Blue Grey
    "DATE": "#00bcd4",           # Cyan
    "NUMBER": "#cddc39",         # Lime
    "PRICE": "#e91e63",          # Pink
}
# --- Helper Functions ---
def find_wikidata_entity(entity_name: str):
    """Search the Wikidata API for *entity_name* and return the entity URL.

    Returns the URL of the first search hit, or None when nothing matches
    or the request fails (this is best-effort enrichment, so all network
    and parsing errors are swallowed rather than propagated).
    """
    try:
        params = {"action": "wbsearchentities", "format": "json", "language": "en", "search": entity_name}
        # Short timeout: this is called once per entity, so a slow Wikidata
        # response must not stall the whole analysis.
        response = requests.get("https://www.wikidata.org/w/api.php", params=params, timeout=3)
        response.raise_for_status()
        data = response.json()
        if data.get("search"):
            return f"https://www.wikidata.org/wiki/{data['search'][0]['id']}"
    except (requests.RequestException, ValueError, KeyError, IndexError):
        # ValueError also covers a non-JSON body from response.json().
        return None
    return None
def create_entity_dataframe(entities):
    """Convert a list of processed entity dicts into a display DataFrame.

    Each dict is expected to carry the keys: name, salience, google_kg_url,
    wikipedia_url, wikidata_url (as built in analyze_text). Link columns are
    rendered as markdown; a missing link becomes the literal "N/A".
    Returns an empty DataFrame for an empty/None input.
    """
    if not entities:
        return pd.DataFrame()
    display_data = [
        {
            "Entity": entity['name'],
            # Fixed 3-decimal string so the column sorts/reads consistently.
            "Salience": f"{entity['salience']:.3f}",
            "Google KG": f"[Search]({entity['google_kg_url']})" if entity['google_kg_url'] else "N/A",
            "Wikipedia": f"[Link]({entity['wikipedia_url']})" if entity['wikipedia_url'] else "N/A",
            "Wikidata": f"[Link]({entity['wikidata_url']})" if entity['wikidata_url'] else "N/A",
        }
        for entity in entities
    ]
    return pd.DataFrame(display_data)
def format_text_for_highlighting(text, entities):
    """Build the (segment, label) list for Gradio's HighlightedText component.

    Flattens every mention of every API entity into (start, end, label)
    spans, then walks the original text emitting unlabeled gaps and labeled
    mentions in order.

    NOTE(review): the API is called with EncodingType.UTF8, so begin_offset
    is presumably a byte offset while this function indexes `text` by
    characters — for non-ASCII input the highlights may drift; confirm
    against the API docs. Overlapping mentions are not de-duplicated either;
    an overlap would make `start` < last_index and duplicate text.
    """
    mentions = []
    for entity in entities:
        entity_type_name = language_v1.Entity.Type(entity.type_).name
        for mention in entity.mentions:
            mentions.append({
                "start": mention.text.begin_offset,
                "end": mention.text.begin_offset + len(mention.text.content),
                "text": mention.text.content,
                "label": entity_type_name,
            })
    # Emit spans in document order regardless of entity ordering.
    mentions.sort(key=lambda x: x['start'])

    highlighted_data = []
    last_index = 0
    for mention in mentions:
        if mention['start'] > last_index:
            # Plain text between mentions gets a None label (no highlight).
            highlighted_data.append((text[last_index:mention['start']], None))
        highlighted_data.append((mention['text'], mention['label']))
        last_index = mention['end']
    if last_index < len(text):
        highlighted_data.append((text[last_index:], None))
    return highlighted_data
# --- Core Logic Functions ---
def analyze_text(text_input: str):
    """Run entity analysis and return all data the UI needs.

    Returns a 4-tuple matching the click handler's outputs:
    (summary string, entities-by-type dict for gr.State, highlight data for
    gr.HighlightedText, gr.update toggling the results section visibility).

    Raises gr.Error when authentication failed or the API call errors, so
    Gradio surfaces the message to the user.
    """
    if not auth_success:
        raise gr.Error(auth_error_message)
    if not text_input or not text_input.strip():
        # Keep the results section hidden for empty input.
        return "Please enter text to analyze.", {}, None, gr.update(visible=False)
    try:
        document = language_v1.Document(content=text_input, type_=language_v1.Document.Type.PLAIN_TEXT)
        encoding_type = language_v1.EncodingType.UTF8
        response = language_client.analyze_entities(request={"document": document, "encoding_type": encoding_type})

        highlight_data = format_text_for_highlighting(text_input, response.entities)

        # Group entities by type; each record carries its enrichment links
        # and a link_count used by the "prioritize identified" sort.
        all_entities = defaultdict(list)
        for entity in response.entities:
            # 'mid' is the Knowledge Graph machine ID, when Google resolved one.
            google_kg_url = f"https://www.google.com/search?kgmid={entity.metadata['mid']}" if 'mid' in entity.metadata else None
            wikipedia_url = entity.metadata.get("wikipedia_url", None)
            wikidata_url = find_wikidata_entity(entity.name)  # best-effort network lookup
            link_count = sum(1 for link in [google_kg_url, wikipedia_url, wikidata_url] if link)
            entity_type_name = language_v1.Entity.Type(entity.type_).name
            all_entities[entity_type_name].append({
                'name': entity.name, 'type': entity_type_name, 'salience': entity.salience,
                'google_kg_url': google_kg_url, 'wikipedia_url': wikipedia_url,
                'wikidata_url': wikidata_url, 'link_count': link_count
            })

        summary = f"Analysis complete. Found {len(response.entities)} total entities across {len(all_entities)} types."
        # Return an update to make the results section visible
        return summary, all_entities, highlight_data, gr.update(visible=True)
    except Exception as e:
        raise gr.Error(f"An error occurred during API call: {e}")
def sort_and_update_ui(all_entities, prioritize_identified):
    """Produce the interleaved [accordion, dataframe, ...] update list.

    For each entity type (in fixed display order) the per-type list is
    sorted — by (link_count, salience) when *prioritize_identified* is
    truthy, otherwise by salience alone — and rendered as a DataFrame.
    Accordions with no entities are hidden.

    NOTE(review): the sort is in place, so it mutates the lists held in the
    gr.State dict; harmless here since only ordering changes.
    """
    display_order = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "WORK_OF_ART", "CONSUMER_GOOD", "OTHER", "PHONE_NUMBER", "ADDRESS", "DATE", "NUMBER", "PRICE"]
    updates = []
    for entity_type in display_order:
        entities_of_type = all_entities.get(entity_type, [])
        if prioritize_identified:
            entities_of_type.sort(key=lambda x: (x['link_count'], x['salience']), reverse=True)
        else:
            entities_of_type.sort(key=lambda x: x['salience'], reverse=True)
        df = create_entity_dataframe(entities_of_type)
        is_visible = len(entities_of_type) > 0
        accordion_label = f"{entity_type.replace('_', ' ')} ({len(entities_of_type)} entities)"
        # Order must match all_detailed_outputs: accordion, then its dataframe.
        updates.append(gr.Accordion(label=accordion_label, visible=is_visible))
        updates.append(df)
    return updates
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Advanced Entity Analyzer")
    # Holds the full analyze_text() result so the checkbox can re-sort
    # without re-calling the API.
    analysis_results_state = gr.State({})

    with gr.Row():
        text_input = gr.Textbox(lines=20, label="Enter Text for Analysis", placeholder="Paste your article content here...")
    analyze_button = gr.Button("Analyze Text", variant="primary")

    # Use gr.Column to group results, and make it initially invisible
    with gr.Column(visible=False) as results_section:
        gr.Markdown("### Analysis Results")
        summary_output = gr.Textbox(label="Summary", interactive=False)
        with gr.Tabs():
            with gr.TabItem("Visual Analysis"):
                highlighted_text_output = gr.HighlightedText(
                    label="Highlighted Entities",
                    color_map=ENTITY_TYPE_COLORS,
                    show_legend=True
                )
            with gr.TabItem("Detailed Breakdown"):
                prioritize_checkbox = gr.Checkbox(label="Prioritize identified entities (with links)", value=False)
                accordions = {}
                dataframes = {}
                # Must match display_order in sort_and_update_ui so updates
                # land on the right components.
                all_types = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "WORK_OF_ART", "CONSUMER_GOOD", "OTHER", "PHONE_NUMBER", "ADDRESS", "DATE", "NUMBER", "PRICE"]
                for entity_type in all_types:
                    with gr.Accordion(f"{entity_type.replace('_', ' ')}", visible=False) as acc:
                        accordions[entity_type] = acc
                        df = gr.Dataframe(
                            headers=["Entity", "Salience", "Google KG", "Wikipedia", "Wikidata"],
                            datatype=["str", "str", "markdown", "markdown", "markdown"],
                            wrap=True
                        )
                        dataframes[entity_type] = df

    # Flatten to [accordion, dataframe, accordion, dataframe, ...] — the
    # exact order sort_and_update_ui emits its updates in.
    all_detailed_outputs = [item for pair in zip(accordions.values(), dataframes.values()) for item in pair]

    # Analyze, then populate the per-type accordions from the stored state.
    analyze_button.click(
        fn=analyze_text,
        inputs=[text_input],
        outputs=[summary_output, analysis_results_state, highlighted_text_output, results_section]
    ).then(
        fn=sort_and_update_ui,
        inputs=[analysis_results_state, prioritize_checkbox],
        outputs=all_detailed_outputs
    )

    # Re-sort without a new API call when the priority toggle changes.
    prioritize_checkbox.change(
        fn=sort_and_update_ui,
        inputs=[analysis_results_state, prioritize_checkbox],
        outputs=all_detailed_outputs
    )

if __name__ == "__main__":
    iface.launch()