| | import gradio as gr |
| | from typing import Dict, List, Any |
| | import pandas as pd |
| | import json |
| | import re |
| | import html as html_lib |
| | from tasks.ner import named_entity_recognition |
| | from utils.ner_helpers import NER_ENTITY_TYPES, DEFAULT_SELECTED_ENTITIES, is_llm_model |
| |
|
| | |
| |
|
def ner_ui():
    """Assemble the named-entity-recognition widgets and wire their callbacks."""
    # Every entity label known to the helper module; offered as the checkbox choices.
    DEFAULT_ENTITY_TYPES = list(NER_ENTITY_TYPES)
| | |
def ner(text: str, model: str, entity_types: List[str]) -> Dict[str, Any]:
    """Extract named entities, automatically using LLM for supported models.

    Args:
        text: Raw input text. ``None``, empty and whitespace-only input
            short-circuit without calling the backend.
        model: Model identifier forwarded to ``named_entity_recognition``.
        entity_types: Labels to keep. They are handed to the backend for LLM
            models and applied as a post-filter for every other model.

    Returns:
        A dict whose ``"entities"`` list holds dicts with the keys ``entity``,
        ``word``, ``start``, ``end``, ``score`` and ``description``. The
        empty-input result additionally carries ``"text": ""``. Any exception
        raised while extracting is printed and turned into ``{"entities": []}``.
    """
    if not text or not text.strip():
        return {"text": "", "entities": []}

    try:
        use_llm = is_llm_model(model)

        entities = named_entity_recognition(
            text=text,
            model=model,
            use_llm=use_llm,
            entity_types=entity_types if use_llm else None,
        )

        if not isinstance(entities, list):
            entities = []

        # Non-LLM backends are not given the label list, so filter afterwards.
        if not use_llm and entity_types:
            entities = [
                e for e in entities
                if e.get("type", "") in entity_types
                or e.get("entity", "") in entity_types
            ]

        # The filter above accepts both the "type" and the "entity" key, so the
        # normalisation must fall back the same way; otherwise records keyed by
        # "entity"/"word"/"score" would come out with a blank type and text.
        return {
            "entities": [
                {
                    "entity": e.get("type") or e.get("entity", ""),
                    "word": e.get("text") or e.get("word", ""),
                    "start": e.get("start", 0),
                    "end": e.get("end", 0),
                    "score": e.get("confidence", e.get("score", 1.0)),
                    "description": e.get("description", ""),
                }
                for e in entities
            ]
        }

    except Exception as e:
        # Deliberate best effort: the UI shows "no entities" instead of crashing.
        print(f"Error in NER: {str(e)}")
        return {"entities": []}
| | |
def render_ner_html(text, entities):
    """Render ``text`` as an HTML block with the given entities highlighted.

    Args:
        text: The string the entities were extracted from.
        entities: Iterable of dicts. The label is read from ``type`` or
            ``entity``, the surface form from ``text`` or ``word``, and the
            optional character offsets from ``start`` / ``end``.

    Returns:
        An HTML string. A centred "no entities" message is returned when
        ``text`` is empty/blank or ``entities`` is empty. Otherwise the escaped
        source text is returned inside a ``ner-highlight`` div, with each kept
        entity wrapped in a coloured span followed by its label.

    Entities without a label or surface form, and entities whose surface form
    cannot be found in ``text``, are skipped. When two entities overlap, the
    one that starts first wins.
    """
    if not text or not text.strip() or not entities:
        return "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"

    COLORS = [
        '#e3f2fd', '#e8f5e9', '#fff8e1', '#f3e5f5', '#e8eaf6', '#e0f7fa',
        '#f1f8e9', '#fce4ec', '#e8f5e9', '#f5f5f5', '#fafafa', '#e1f5fe',
        '#fff3e0', '#d7ccc8', '#f9fbe7', '#fbe9e7', '#ede7f6', '#e0f2f1'
    ]

    def locate(entity_text, start, end):
        """Return a ``(start, end)`` span with ``text[start:end] == entity_text``, or ``None``."""
        offsets_usable = (
            isinstance(start, int)
            and isinstance(end, int)
            and 0 <= start < end <= len(text)
        )
        if offsets_usable:
            span_text = text[start:end]
            if span_text == entity_text:
                return start, end
            if span_text.strip().startswith(entity_text):
                # The reported span is wider than the entity (padding or a
                # trailing suffix). Tighten it so the surplus characters stay
                # in the output as plain text instead of being dropped.
                start += span_text.find(entity_text)
                return start, start + len(entity_text)
        # Offsets are missing or wrong: fall back to the first occurrence.
        found_at = text.find(entity_text)
        if found_at < 0:
            return None
        return found_at, found_at + len(entity_text)

    clean_entities = []
    label_colors = {}

    for ent in entities:
        label = ent.get('type') or ent.get('entity')
        entity_text = ent.get('text') or ent.get('word')
        if not label or not entity_text:
            continue

        span = locate(entity_text, ent.get('start'), ent.get('end'))
        if span is None:
            # Rendering at a bogus position would insert text that is not in
            # the source or (with negative offsets) swallow real text.
            continue

        # Colours are handed out in order of first appearance and wrap around.
        if label not in label_colors:
            label_colors[label] = COLORS[len(label_colors) % len(COLORS)]

        clean_entities.append({
            'text': entity_text,
            'label': label,
            'color': label_colors[label],
            'start': span[0],
            'end': span[1]
        })

    clean_entities.sort(key=lambda x: x['start'])

    # Keep the earliest-starting entity of every overlapping group.
    non_overlapping = []
    for current in clean_entities:
        if non_overlapping and current['start'] < non_overlapping[-1]['end']:
            continue
        non_overlapping.append(current)

    parts = ["<div class='ner-highlight' style='line-height:1.6;padding:15px;border:1px solid #e0e0e0;border-radius:4px;background:#f9f9f9;white-space:pre-wrap;'>"]

    last_pos = 0
    for entity in non_overlapping:
        start = entity['start']
        end = entity['end']

        # Plain text between the previous entity and this one.
        if start > last_pos:
            parts.append(html_lib.escape(text[last_pos:start]))

        parts.append(f"<span style='background:{entity['color']};border-radius:3px;padding:2px 4px;margin:0 1px;border:1px solid rgba(0,0,0,0.1);'>")
        parts.append(f"{html_lib.escape(entity['text'])} ")
        parts.append(f"<span style='font-size:0.8em;font-weight:bold;color:#555;border-radius:2px;padding:0 2px;background:rgba(255,255,255,0.7);'>{html_lib.escape(entity['label'])}</span>")
        parts.append("</span>")

        last_pos = end

    # Trailing text after the last entity.
    if last_pos < len(text):
        parts.append(html_lib.escape(text[last_pos:]))

    parts.append("</div>")
    return "".join(parts)
| | |
def update_ui(model_id: str) -> Dict:
    """Return the component update that shows or hides the entity-type picker.

    The picker group is made visible exactly when ``is_llm_model(model_id)``
    is truthy.
    """
    show_entity_picker = is_llm_model(model_id)
    updates = {}
    updates[entity_types_group] = gr.Group(visible=show_entity_picker)
    return updates
| |
|
# NOTE(review): the nesting below is reconstructed from the widget order - confirm against the running app.
with gr.Row():
    # Left column: input text, examples, model choice and entity-type picker.
    with gr.Column(scale=2):
        input_text = gr.Textbox(
            label="Input Text",
            lines=8,
            placeholder="Enter text to analyze for named entities..."
        )

        gr.Examples(
            examples=[
                ["Barack Obama was born in Hawaii and became the 44th President of the United States."],
                ["Google is headquartered in Mountain View, California."]
            ],
            inputs=[input_text],
            label="Examples"
        )
        model_dropdown = gr.Dropdown(
            ["gemini-2.0-flash"],
            value="gemini-2.0-flash",
            label="Model"
        )

        with gr.Group() as entity_types_group:
            entity_types = gr.CheckboxGroup(
                label="Entity Types to Extract",
                choices=DEFAULT_ENTITY_TYPES,
                value=DEFAULT_SELECTED_ENTITIES,
                interactive=True
            )
            with gr.Row():
                select_all_btn = gr.Button("Select All", size="sm")
                clear_all_btn = gr.Button("Clear All", size="sm")

        btn = gr.Button("Extract Entities", variant="primary")

        def select_all_entities():
            """Tick every entity type offered by the checkbox group."""
            return gr.CheckboxGroup(value=DEFAULT_ENTITY_TYPES)

        def clear_all_entities():
            """Untick every entity type."""
            return gr.CheckboxGroup(value=[])

        select_all_btn.click(
            fn=select_all_entities,
            outputs=[entity_types]
        )

        clear_all_btn.click(
            fn=clear_all_entities,
            outputs=[entity_types]
        )

    # Right column: highlighted text and tabular results, each with a placeholder.
    with gr.Column(scale=3):

        with gr.Tabs() as output_tabs:
            with gr.Tab("Tagged View", id="tagged-view-ner"):
                no_results_html = gr.HTML(
                    "<div style='text-align: center; color: #666; padding: 20px;'>"
                    "Enter text and click 'Extract Entities' to get results.</div>",
                    visible=True
                )
                output_html = gr.HTML(
                    label="NER Highlighted",
                    elem_id="ner-output-html",
                    visible=False
                )

                # The `.pos-tag` rule was left unclosed, which swallowed every
                # rule after it; the closing brace below restores them.
                gr.HTML("""
                <style>
                    #ner-output-html .pos-highlight {
                        white-space: pre-wrap;
                        line-height: 1.8;
                        font-size: 14px;
                        padding: 15px;
                        border: 1px solid #e0e0e0;
                        border-radius: 4px;
                        background: #f9f9f9;
                    }
                    #ner-output-html .pos-token {
                        display: inline-block;
                        margin: 0 2px 4px 0;
                        vertical-align: top;
                        text-align: center;
                    }
                    #ner-output-html .token-text {
                        display: block;
                        padding: 2px 8px;
                        background: #f0f4f8;
                        border-radius: 4px 4px 0 0;
                        border: 1px solid #dbe4ed;
                        border-bottom: none;
                        font-size: 0.9em;
                    }
                    #ner-output-html .pos-tag {
                        display: block;
                        padding: 2px 8px;
                        border-radius: 0 0 4px 4px;
                    }
                    #ner-output-html .WORK_OF_ART { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; }
                    #ner-output-html .LAW { background-color: #fce4ec; border-color: #f8bbd0; color: #880e4f; }
                    #ner-output-html .LANGUAGE { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; font-weight: bold; }
                    #ner-output-html .DATE { background-color: #f5f5f5; border-color: #e0e0e0; color: #424242; }
                    #ner-output-html .TIME { background-color: #fafafa; border-color: #f5f5f5; color: #616161; }
                    #ner-output-html .PERCENT { background-color: #e1f5fe; border-color: #b3e5fc; color: #01579b; font-weight: bold; }
                    #ner-output-html .MONEY { background-color: #f3e5f5; border-color: #e1bee7; color: #6a1b9a; }
                    #ner-output-html .QUANTITY { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; font-style: italic; }
                    #ner-output-html .ORDINAL { background-color: #fff3e0; border-color: #ffe0b2; color: #e65100; }
                    #ner-output-html .CARDINAL { background-color: #ede7f6; border-color: #d1c4e9; color: #4527a0; }
                </style>
                """)
            with gr.Tab("Table View", id="table-view-ner"):
                no_results_table = gr.HTML(
                    "<div style='text-align: center; color: #666; padding: 20px;'>"
                    "Enter text and click 'Extract Entities' to get results.</div>",
                    visible=True
                )
                output_table = gr.Dataframe(
                    label="Extracted Entities",
                    headers=["Type", "Text", "Confidence", "Description"],
                    datatype=["str", "str", "number", "str"],
                    interactive=False,
                    wrap=True,
                    visible=False
                )
| | |
| | |
# Whenever another model is picked, let update_ui decide the visibility of the entity-type group.
model_dropdown.change(
    update_ui,
    inputs=[model_dropdown],
    outputs=[entity_types_group],
)
| | |
def process_and_show_results(text: str, model: str, entity_types: List[str]):
    """Process NER and return both the results and UI state.

    Args:
        text: Text to analyse. ``None`` or blank input yields a
            "please enter some text" message.
        model: Model identifier passed on to ``ner``.
        entity_types: Selected labels. An empty selection falls back to every
            key of ``NER_ENTITY_TYPES``.

    Returns:
        A list of four component updates, in this order: highlighted HTML,
        tagged-view placeholder, results table, table-view placeholder.
    """
    if not text or not text.strip():
        msg = "<div style='text-align: center; color: #f44336; padding: 20px;'>Please enter some text to analyze.</div>"
        return [
            gr.HTML(visible=False),
            gr.HTML(msg, visible=True),
            gr.DataFrame(visible=False),
            gr.HTML(msg, visible=True)
        ]
    if not entity_types:
        entity_types = list(NER_ENTITY_TYPES.keys())
    result = ner(text, model, entity_types)
    entities = result["entities"] if result and "entities" in result else []

    if entities:
        df = pd.DataFrame(entities)
        if not df.empty:
            # Sort by position BEFORE dropping the helper columns: once the
            # frame is reduced to the display columns 'start' is gone, so a
            # later sort would silently never run.
            if 'start' in df.columns:
                df = df.sort_values('start', kind='stable')
            df = df.rename(columns={
                "entity": "Type",
                "word": "Text",
                "score": "Confidence",
                "description": "Description"
            })
            display_columns = ["Type", "Text", "Confidence", "Description"]
            df = df[[col for col in display_columns if col in df.columns]]
        html = render_ner_html(text, entities)
        return [
            gr.HTML(html, visible=True),
            gr.HTML(visible=False),
            gr.DataFrame(value=df, visible=True),
            gr.HTML(visible=False)
        ]

    msg = "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"
    return [
        gr.HTML(msg, visible=True),
        gr.HTML(visible=False),
        gr.DataFrame(visible=False),
        gr.HTML(msg, visible=True)
    ]
| | |
| | |
# Run the extraction on click; the four returned updates map onto the four outputs in order.
btn.click(
    process_and_show_results,
    inputs=[input_text, model_dropdown, entity_types],
    outputs=[output_html, no_results_html, output_table, no_results_table],
)

# NOTE(review): the mapping returned here is discarded, so this call does not
# change the initial visibility of the entity-type group - confirm the intent.
update_ui(model_dropdown.value)
| | |
| | return None |
| |
|