Spaces:
Sleeping
Sleeping
| # Install dependencies | |
| # Make sure to run this in your environment if you haven't already | |
| # !pip install openai anthropic google-generativeai gradio transformers torch gliner numpy pandas --quiet | |
| # Imports | |
| import openai | |
| import anthropic | |
| import google.generativeai as genai | |
| import gradio as gr | |
| from gliner import GLiNER | |
| from collections import defaultdict | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import tempfile | |
# Supported models: dropdown display name mapped to the internal provider id.
# The provider id is also the prefix of the matching "{provider_id}_key" entry
# used when looking up API keys.
MODEL_OPTIONS = {
    "OpenAI (GPT-4o)": "openai",
    "Anthropic (Claude 3 Opus)": "anthropic",
    "Google (Gemini 1.5 Pro)": "google",
}

# GLiNER checkpoint used by the Extraction AI.
GLINER_MODEL_NAME = "urchade/gliner_large-v2.1"

# Load the extraction model a single time at start-up. A failure is reported on
# stdout and leaves `gliner_model` as None instead of stopping the app.
try:
    print("Loading Extraction AI (GLiNER model)... This may take a moment.")
    gliner_model = GLiNER.from_pretrained(GLINER_MODEL_NAME)
    print("Extraction AI loaded successfully.")
except Exception as load_error:
    print(
        "FATAL ERROR: Could not load GLiNER model. "
        f"The app will not be able to find entities. Error: {load_error}"
    )
    gliner_model = None

# Prompt sent to the Conceptual AI. `{topic}` is filled in with str.format; the
# reply is expected as "###" category headers followed by "-" label lists.
FRAMEWORK_PROMPT_TEMPLATE = """
You are an expert research assistant specializing in history. For the provided topic: **"{topic}"**, your task is to generate a conceptual research framework.
**Instructions:**
1. Identify 4-6 high-level **Conceptual Categories** relevant to analyzing this historical topic (e.g., 'Key Figures', 'Core Ideologies', 'Significant Events').
2. For each category, list specific, searchable **Labels** that would appear in a primary or secondary source document.
3. **Crucial Rule for Labels:** Use concise, singular, and fundamental terms (e.g., use `Treaty` not `Diplomatic Treaties`). Use Title Case (e.g. `Working Class`).
**Output Format:**
Use Markdown. Each category must be a Level 3 Header (###), followed by a comma-separated list of its labels.
### Example Category: Political Actions
- Petition, Charter, Protest, Rally, Legislation
### Example Category: Social Groups
- Working Class, Aristocracy, Clergy
"""
# Generator Function (the "Conceptual AI")
def generate_from_prompt(prompt, provider, key_dict):
    """Send ``prompt`` to the selected LLM provider and return its text reply.

    Args:
        prompt: Full prompt text, sent as a single user message.
        provider: Display name of the model; must be a key of ``MODEL_OPTIONS``.
        key_dict: Mapping of ``"{provider_id}_key"`` to the API key string.

    Returns:
        The reply text with surrounding whitespace stripped. For OpenAI an
        absent message content is returned as ``""``.

    Raises:
        ValueError: If ``provider`` is not a supported model, or no API key is
            available for it.
    """
    provider_id = MODEL_OPTIONS.get(provider)
    if provider_id is None:
        # Previously this surfaced as a misleading "API key ... not found".
        raise ValueError(f"Unsupported model provider: {provider!r}.")

    api_key = key_dict.get(f"{provider_id}_key")
    if not api_key:
        raise ValueError(f"API key for {provider} not found.")

    if provider_id == "openai":
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        # Guard against a reply whose content is None before calling .strip().
        return (response.choices[0].message.content or "").strip()

    if provider_id == "anthropic":
        client = anthropic.Anthropic(api_key=api_key)
        response = client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.content[0].text.strip()

    if provider_id == "google":
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-pro-latest')
        response = model.generate_content(prompt)
        return response.text.strip()

    # A provider id listed in MODEL_OPTIONS without a branch above is a
    # programming error; fail loudly instead of returning an empty framework.
    raise ValueError(f"No client implementation for provider id {provider_id!r}.")
# --- UI Definitions ---
# General-purpose entity labels offered (all pre-ticked) in the "Standard Labels" group.
STANDARD_LABELS = [
    "Person", "Organisation", "Location", "Country", "City", "State",
    "Nationality", "Group", "Date", "Event", "Law", "Legal Document",
    "Product", "Facility", "Work of Art", "Language", "Time", "Percentage",
    "Money", "Currency", "Quantity", "Ordinal Number", "Cardinal Number",
]
# Number of hidden accordion slots created up front for AI-suggested categories.
MAX_CATEGORIES = 8

with gr.Blocks(title="Historical Text Analyser", css=".prose { word-break: break-word; }") as demo:
    gr.Markdown("# Historical Text Analyser")
    gr.Markdown("""
First, a **Conceptual AI**, powered by a generative AI Large Language Model (LLM) such as OpenAI's GPT-4, Anthropic's Claude, or Google's Gemini, suggests labels based on your chosen historical topic. These labels are grouped into broader categories (e.g. Economic Policies, Significant Events) to help focus your research.
Second, an **Extraction AI**, powered by the GLiNER model, scans your source text to find and highlight matching entities - instances where those labels appear in the document - with high accuracy.
### Understanding Entities and Labels ###
In text analysis, this process is often called Named Entity Recognition (NER).
- An **Entity** is a specific piece of text in your document, such as a name, a place, or a date (e.g. Queen Victoria, 1848).
- A **Label** is the category that the entity belongs to (e.g. Person, Date, Location).
This tool helps you to define your labels and then finds the corresponding entities in your text.
""")

    # Step 1: topic, model choice and API keys.
    gr.Markdown("--- \n## Step 1: Generate Labels")
    with gr.Row():
        topic = gr.Textbox(label="Enter a Historical Topic", placeholder="e.g. Britain during the Second World War")
        provider = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose AI Model")
    with gr.Row():
        openai_key = gr.Textbox(label="OpenAI API Key", type="password")
        anthropic_key = gr.Textbox(label="Anthropic API Key", type="password")
        google_key = gr.Textbox(label="Google API Key", type="password")
    generate_btn = gr.Button("Generate Labels", variant="primary")

    # Step 2: label selection.
    gr.Markdown("--- \n## Step 2: Confirm Labels and Analyse Source Text")
    gr.Markdown("#### 1. AI-Suggested Labels")
    # One (accordion, checkbox group, select button, deselect button) tuple per
    # slot; all slots start hidden and are filled in by the generate handler.
    dynamic_components = []
    with gr.Column():
        for i in range(MAX_CATEGORIES):
            with gr.Accordion(f"Suggested Category {i+1}", visible=False) as acc:
                cg = gr.CheckboxGroup(label="Labels in this category", interactive=True)
                with gr.Row():
                    select_btn = gr.Button("Select All", size="sm")
                    deselect_btn = gr.Button("Deselect All", size="sm")
            dynamic_components.append((acc, cg, select_btn, deselect_btn))

    gr.Markdown("#### 2. Standard Labels (Optional)")
    with gr.Group():
        standard_labels_checkbox = gr.CheckboxGroup(choices=STANDARD_LABELS, value=STANDARD_LABELS, label="Standard Entity Labels")
        with gr.Row():
            select_all_std_btn = gr.Button("Select All", size="sm")
            deselect_all_std_btn = gr.Button("Deselect All", size="sm")

    gr.Markdown("#### 3. Custom Labels (Optional)")
    with gr.Group():
        custom_labels_textbox = gr.Textbox(label="Enter Custom Labels (separate with commas)", placeholder="e.g., Technology, Weapon, Secret Society...")

    # Step 3: source text and extraction settings.
    gr.Markdown("--- \n## Step 3: Run Analysis")
    threshold_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
    text_input = gr.Textbox(label="Paste Your Text Here for Analysis", lines=15)
    analyze_btn = gr.Button("Find Entities", variant="primary")
    analysis_status = gr.Markdown(visible=False)

    # Step 4: results.
    gr.Markdown("--- \n## Step 4: Review Results")
    # The sparkle emoji was mis-decoded in this string; restored to the intended character.
    gr.Markdown("✨ **Pro Tip:** In the **\"Highlighted Text\"** view, you can click and drag to highlight text and create your own labels!")
    with gr.Tabs():
        with gr.TabItem("Highlighted Text"):
            highlighted_text_output = gr.HighlightedText(label="Found Entities", interactive=True)
        with gr.TabItem("Detailed Results"):
            gr.Markdown("You can sort the table by clicking on column headers. The download link for the full table will appear here after analysis.")
            # Hidden until an analysis has produced a CSV to download.
            csv_file_output = gr.File(label="Download Results as CSV", visible=False)
            detailed_results_output = gr.DataFrame(headers=["Label", "Text Found", "Instances", "Confidence Score"], datatype=["str", "str", "number", "number"], label="Aggregated List of Found Entities")
        with gr.TabItem("Debug Log"):
            debug_output = gr.Textbox(label="Extraction Process Log", interactive=False, lines=8)
# --- Backend Functions ---
def handle_generate(topic, provider, openai_k, anthropic_k, google_k):
    """Ask the Conceptual AI for a label framework and fill the category slots.

    Generator event handler. It first disables the generate button, then
    yields one dict of component updates that shows one accordion per parsed
    category (up to ``MAX_CATEGORIES``; further categories are ignored) and
    hides the unused slots.

    Args:
        topic: Historical topic typed by the user.
        provider: Display name of the chosen model (a key of ``MODEL_OPTIONS``).
        openai_k: OpenAI API key typed into the UI.
        anthropic_k: Anthropic API key typed into the UI.
        google_k: Google API key typed into the UI.

    Yields:
        Dicts mapping Gradio components to their updates.

    Raises:
        gr.Error: If required input is missing, the model call fails, or no
            category could be parsed from the reply. The button is re-enabled
            before the error is raised.
    """
    yield {generate_btn: gr.update(value="Generating...", interactive=False)}
    try:
        # An environment variable wins over the typed key, but an empty
        # variable no longer masks what the user entered.
        key_dict = {
            "openai_key": os.environ.get("OPENAI_API_KEY") or openai_k,
            "anthropic_key": os.environ.get("ANTHROPIC_API_KEY") or anthropic_k,
            "google_key": os.environ.get("GOOGLE_API_KEY") or google_k,
        }
        provider_id = MODEL_OPTIONS.get(provider)
        if not topic or not provider or not key_dict.get(f"{provider_id}_key"):
            raise gr.Error("A topic, provider, and valid API Key for that provider are required.")

        prompt = FRAMEWORK_PROMPT_TEMPLATE.format(topic=topic)
        raw_framework = generate_from_prompt(prompt, provider, key_dict)

        # Parse "### Category" headers, each followed by "- a, b, c" label lines.
        framework = defaultdict(list)
        current_category = None
        for line in raw_framework.splitlines():
            line = line.strip()
            if line.startswith("###"):
                # Strip only the surrounding header markers.
                current_category = line.strip("#").strip()
            elif line.startswith("-") and current_category:
                # Remove the bullet marker only; hyphens inside labels
                # (e.g. "Anglo-Saxon") must survive.
                entities_string = line.lstrip("-").strip()
                framework[current_category].extend(
                    e.strip() for e in entities_string.split(',') if e.strip()
                )
        if not framework:
            raise gr.Error("The AI failed to generate categories. Please try again or rephrase your topic.")

        updates = {}
        categories = list(framework.items())
        for i, (acc, cg, sel, desel) in enumerate(dynamic_components):
            if i < len(categories):
                cat_name, entities = categories[i]
                sorted_entities = sorted(set(entities))
                updates[acc] = gr.update(label=f"Category: {cat_name}", visible=True)
                updates[cg] = gr.update(choices=sorted_entities, value=sorted_entities, label="Suggested Labels", visible=True)
                updates[sel] = gr.update(visible=True)
                updates[desel] = gr.update(visible=True)
            else:
                # Unused slot: hide it and clear anything left from a previous run.
                updates[acc] = gr.update(visible=False)
                updates[cg] = gr.update(choices=[], value=[], visible=False)
                updates[sel] = gr.update(visible=False)
                updates[desel] = gr.update(visible=False)
        updates[generate_btn] = gr.update(value="Generate Labels", interactive=True)
        yield updates
    except gr.Error:
        # Already a user-facing error: restore the button and pass it on unchanged.
        yield {generate_btn: gr.update(value="Generate Labels", interactive=True)}
        raise
    except Exception as e:
        yield {generate_btn: gr.update(value="Generate Labels", interactive=True)}
        raise gr.Error(str(e)) from e
def analyze_text(text, standard_labels, custom_label_text, threshold, *suggested_labels_from_groups):
    """Run the Extraction AI over ``text`` and publish highlights, a table and a CSV.

    Generator event handler. It first puts the UI into a busy state, then
    yields the final component updates. The text is scanned in overlapping
    character windows; an entity reported by two windows for the same span and
    label is kept once, with its highest score.

    Args:
        text: Source text to analyse.
        standard_labels: Labels ticked in the standard-labels checkbox group.
        custom_label_text: Comma-separated custom labels typed by the user.
        threshold: Confidence threshold passed to ``predict_entities``.
        *suggested_labels_from_groups: Ticked labels of each AI-suggested
            category checkbox group.

    Yields:
        Dicts mapping Gradio components to their updates.

    Raises:
        gr.Error: If the GLiNER model is not loaded or the analysis fails. The
            button is re-enabled before the error is raised.
    """
    def _idle_controls():
        # Fresh updates that return the button and status line to their resting state.
        return {
            analyze_btn: gr.update(value="Find Entities", interactive=True),
            analysis_status: gr.update(visible=False),
        }

    yield {
        analyze_btn: gr.update(value="Finding Entities...", interactive=False),
        analysis_status: gr.update(value="The Extraction AI is scanning your text...", visible=True),
        highlighted_text_output: None,
        detailed_results_output: None,
        debug_output: "Starting analysis...",
        csv_file_output: gr.update(visible=False, value=None),
    }

    if gliner_model is None:
        # Restore the controls first; otherwise the button stays disabled.
        yield _idle_controls()
        raise gr.Error("Extraction AI (GLiNER model) is not loaded.")

    # Union of suggested, standard and custom labels.
    labels_to_use = set()
    for group in suggested_labels_from_groups:
        if group:
            labels_to_use.update(group)
    if standard_labels:
        labels_to_use.update(standard_labels)
    labels_to_use.update(
        label.strip() for label in (custom_label_text or "").split(',') if label.strip()
    )
    final_labels = sorted(labels_to_use)

    if not text or not final_labels:
        yield {
            **_idle_controls(),
            highlighted_text_output: {"text": text, "entities": []},
            detailed_results_output: None,
            debug_output: "Analysis stopped: No text or no labels provided.",
            csv_file_output: gr.update(visible=False, value=None),
        }
        return

    try:
        chunk_size, overlap = 1024, 100
        # Keyed by (start, end, label) so a hit seen in two overlapping windows
        # (with slightly different scores) is not counted twice.
        best_by_span = {}
        for offset in range(0, len(text), chunk_size - overlap):
            chunk = text[offset:offset + chunk_size]
            for ent in gliner_model.predict_entities(chunk, final_labels, threshold=threshold):
                # Shift window-relative offsets to positions in the full text.
                start, end = ent['start'] + offset, ent['end'] + offset
                span_key = (start, end, ent['label'])
                known = best_by_span.get(span_key)
                if known is None or ent['score'] > known['score']:
                    best_by_span[span_key] = {**ent, 'start': start, 'end': end}
            if offset + chunk_size >= len(text):
                # The text is fully covered; another window would only re-scan the overlap.
                break
        # Text order makes the highlights and the "first seen" casing deterministic.
        unique_entities = sorted(best_by_span.values(), key=lambda ent: (ent['start'], ent['end']))

        highlighted_output_data = {
            "text": text,
            "entities": [
                {"start": ent["start"], "end": ent["end"], "entity": ent["label"]}
                for ent in unique_entities
            ],
        }

        # Group case-insensitively per label; display the first casing encountered.
        aggregated_matches = defaultdict(lambda: {'count': 0, 'scores': [], 'original_casing': ''})
        for ent in unique_entities:
            match_text = text[ent['start']:ent['end']]
            bucket = aggregated_matches[(ent['label'], match_text.lower())]
            bucket['count'] += 1
            bucket['scores'].append(ent['score'])
            if not bucket['original_casing']:
                bucket['original_casing'] = match_text

        table_rows = [
            {
                "Label": label,
                "Text Found": data['original_casing'],
                "Instances": data['count'],
                "Confidence Score": round(float(np.mean(data['scores'])), 2),
            }
            for (label, _), data in aggregated_matches.items()
        ]
        results_df = pd.DataFrame(table_rows, columns=["Label", "Text Found", "Instances", "Confidence Score"])
        if not results_df.empty:
            results_df = results_df.sort_values(by=["Label", "Instances"], ascending=[True, False])

        csv_file_path = None
        if not results_df.empty:
            # delete=False: the file must outlive this call so it can be downloaded.
            # Write through the open handle (newline='' as pandas expects for
            # file objects) instead of reopening the path while it is still open.
            with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.csv', encoding='utf-8', newline='') as tmpfile:
                results_df.to_csv(tmpfile, index=False)
                csv_file_path = tmpfile.name
    except Exception as e:
        yield {**_idle_controls(), debug_output: f"Analysis failed: {e}"}
        raise gr.Error(f"Analysis failed: {e}") from e

    yield {
        **_idle_controls(),
        highlighted_text_output: highlighted_output_data,
        detailed_results_output: results_df,
        debug_output: "Analysis complete.",
        csv_file_output: gr.update(value=csv_file_path, visible=bool(csv_file_path)),
    }
# --- Wire up UI events ---
def deselect_all():
    """Return an update that unticks every box of the target checkbox group."""
    return gr.update(value=[])


def _select_all_standard():
    """Return an update that ticks every standard label."""
    return gr.update(value=STANDARD_LABELS)


def _make_select_all(checkbox_group):
    """Build a no-argument handler that ticks everything in ``checkbox_group.choices``."""
    def _select_all():
        # NOTE(review): `choices` is read from the component object, which was
        # created without choices; confirm it reflects the choices later pushed
        # through gr.update by handle_generate, otherwise this selects nothing.
        return gr.update(value=checkbox_group.choices)
    return _select_all


# Event listeners have to be registered while a Blocks context is active, so
# the `demo` context is opened for the wiring.
with demo:
    generate_outputs = [generate_btn]
    for slot in dynamic_components:
        generate_outputs.extend(slot)
    generate_btn.click(
        fn=handle_generate,
        inputs=[topic, provider, openai_key, anthropic_key, google_key],
        outputs=generate_outputs,
    )

    deselect_all_std_btn.click(fn=deselect_all, inputs=None, outputs=[standard_labels_checkbox])
    select_all_std_btn.click(_select_all_standard, inputs=None, outputs=[standard_labels_checkbox])

    # Per-category "Select All" / "Deselect All" buttons.
    for _, category_group, pick_all_btn, clear_all_btn in dynamic_components:
        pick_all_btn.click(fn=_make_select_all(category_group), inputs=None, outputs=[category_group])
        clear_all_btn.click(fn=deselect_all, inputs=None, outputs=[category_group])

    # The suggested-label groups are appended after the four fixed inputs and
    # arrive in analyze_text as its *args.
    suggested_groups = [slot[1] for slot in dynamic_components]
    analyze_btn.click(
        fn=analyze_text,
        inputs=[text_input, standard_labels_checkbox, custom_labels_textbox, threshold_slider] + suggested_groups,
        outputs=[analyze_btn, analysis_status, highlighted_text_output, detailed_results_output, debug_output, csv_file_output],
    )

demo.launch(share=True, debug=True)