File size: 15,662 Bytes
2f5ff37
 
f9a5a03
2f5ff37
 
 
 
 
 
 
d51705e
5f5b923
1773b6d
80cecba
1773b6d
2f5ff37
 
 
 
 
 
 
 
 
 
 
 
 
5f5b923
2f5ff37
5f5b923
2f5ff37
 
 
 
5f5b923
 
 
e9738aa
5f5b923
 
d51705e
e9738aa
5f5b923
 
 
 
 
e9738aa
 
5f5b923
e9738aa
 
 
d56dd19
e9738aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502c63f
d51705e
502c63f
d51705e
e9738aa
5f5b923
e9738aa
502c63f
 
d56dd19
f7e349e
3406a99
502c63f
 
 
 
3406a99
 
e9738aa
502c63f
2f5ff37
f7e349e
502c63f
2f5ff37
80cecba
 
 
502c63f
e9738aa
502c63f
5f5b923
e9738aa
2f5ff37
 
5f5b923
 
2f5ff37
5f5b923
 
 
e9738aa
5f5b923
2f5ff37
d56dd19
e9738aa
 
 
 
5f5b923
2f5ff37
3406a99
e9738aa
5f5b923
d56dd19
3406a99
d51705e
5f5b923
e9738aa
5f5b923
3406a99
2f5ff37
 
 
e9738aa
2f5ff37
3406a99
27a2304
 
d56dd19
5f5b923
 
2f5ff37
 
35ef54e
2f5ff37
d51705e
2f5ff37
5f5b923
2f5ff37
d56dd19
1773b6d
5f5b923
2f5ff37
1773b6d
2f5ff37
 
 
 
27a2304
1773b6d
 
 
d56dd19
1773b6d
2f5ff37
 
 
d56dd19
2f5ff37
d56dd19
2f5ff37
d56dd19
 
 
 
2f5ff37
27a2304
1773b6d
3406a99
2f5ff37
 
3406a99
2f5ff37
 
5f5b923
80cecba
d51705e
f9a5a03
 
d56dd19
80cecba
 
f9a5a03
2f5ff37
e9738aa
2f5ff37
e9738aa
2f5ff37
 
e9738aa
d56dd19
2f5ff37
e9738aa
d51705e
e9738aa
 
f9a5a03
d56dd19
 
e9738aa
80cecba
e9738aa
2f5ff37
e9738aa
 
2f5ff37
 
 
5f5b923
2f5ff37
 
d56dd19
f9a5a03
e9738aa
 
 
 
 
 
 
f9a5a03
2f5ff37
d51705e
e9738aa
 
d56dd19
d51705e
f9a5a03
d56dd19
e9738aa
27a2304
 
 
 
 
 
 
80cecba
d51705e
e9738aa
 
f9a5a03
d56dd19
27a2304
80cecba
f9a5a03
e9738aa
d56dd19
e9738aa
f9a5a03
e9738aa
d56dd19
80cecba
5f5b923
d51705e
5f5b923
e9738aa
 
5f5b923
 
27a2304
 
f9a5a03
 
2f5ff37
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# 📚 Install dependencies
# Make sure to run this in your environment if you haven't already
# !pip install openai anthropic google-generativeai gradio transformers torch gliner numpy pandas --quiet

# βš™οΈ Imports
import bisect
import os
import tempfile
from collections import defaultdict

import anthropic
import google.generativeai as genai
import gradio as gr
import numpy as np
import openai
import pandas as pd
from gliner import GLiNER

# 🧠 Supported models: the display name shown in the "Choose AI Model" dropdown,
# mapped to the short provider id used to pick an API key ("<id>_key") and the
# matching client code. Insertion order is the dropdown order.
MODEL_OPTIONS = dict(
    [
        ("OpenAI (GPT-4o)", "openai"),
        ("Anthropic (Claude 3 Opus)", "anthropic"),
        ("Google (Gemini 1.5 Pro)", "google"),
    ]
)

# 🔧 GLiNER Model Configuration
GLINER_MODEL_NAME = "urchade/gliner_large-v2.1"


def _load_extraction_model(model_name):
    """Load the GLiNER checkpoint *model_name*, returning None if anything fails.

    Progress and failure messages are printed to stdout. Any exception raised
    while loading is reported instead of propagated, so the UI can still start
    (entity extraction is then unavailable).
    """
    try:
        print("Loading Extraction AI (GLiNER model)... This may take a moment.")
        model = GLiNER.from_pretrained(model_name)
        print("Extraction AI loaded successfully.")
    except Exception as load_error:
        print(f"FATAL ERROR: Could not load GLiNER model. The app will not be able to find entities. Error: {load_error}")
        return None
    return model


# --- Load the model only once at startup ---
gliner_model = _load_extraction_model(GLINER_MODEL_NAME)

# 🧠 Prompt for the Conceptual AI to generate a research framework.
# The reply is parsed line by line: a "###" header starts a category and a line
# beginning with "-" beneath it carries that category's comma-separated labels.
# The output-format rules below spell that contract out explicitly, and the
# example uses bare category names so the model does not copy an
# "Example Category:" prefix into its own headers.
# Only {topic} is a format placeholder; any literal braces added here must be doubled.
FRAMEWORK_PROMPT_TEMPLATE = """
You are an expert research assistant specializing in history. For the provided topic: **"{topic}"**, your task is to generate a conceptual research framework.
**Instructions:**
1.  Identify 4-6 high-level **Conceptual Categories** relevant to analyzing this historical topic (e.g., 'Key Figures', 'Core Ideologies', 'Significant Events').
2.  For each category, list specific, searchable **Labels** that would appear in a primary or secondary source document.
3.  **Crucial Rule for Labels:** Use concise, singular, and fundamental terms (e.g., use `Treaty` not `Diplomatic Treaties`). Use Title Case (e.g. `Working Class`).
**Output Format:**
Use Markdown. Each category must be a Level 3 Header (###) containing only the category name, followed by a single line that starts with "- " and contains a comma-separated list of its labels. Do not add any other text, numbering, or formatting.
**Example Output:**
### Political Actions
- Petition, Charter, Protest, Rally, Legislation
### Social Groups
- Working Class, Aristocracy, Clergy
"""

# 🧠 Generator Function (The "Conceptual AI")
# Model ID used for each provider id when generate_from_prompt() is not given
# an explicit `model=`. Edit here to move to a newer model.
DEFAULT_PROVIDER_MODELS = {
    "openai": "gpt-4o",
    "anthropic": "claude-3-opus-20240229",
    "google": "gemini-1.5-pro-latest",
}


def generate_from_prompt(prompt, provider, key_dict, *, model=None):
    """Send *prompt* to the LLM behind *provider* and return its reply text.

    Args:
        prompt: Full prompt text, sent as a single user message.
        provider: A display name from MODEL_OPTIONS (e.g. "OpenAI (GPT-4o)").
        key_dict: Mapping holding "<provider_id>_key" entries,
            e.g. {"openai_key": "sk-..."}.
        model: Optional model ID overriding DEFAULT_PROVIDER_MODELS for this
            provider (keyword-only).

    Returns:
        The reply text with surrounding whitespace stripped. Returns "" when
        the provider returned no text, or when the provider id has no branch
        below.

    Raises:
        ValueError: If key_dict holds no key for the provider.
        Exceptions from the provider SDKs propagate unchanged.
    """
    provider_id = MODEL_OPTIONS.get(provider)
    api_key = key_dict.get(f"{provider_id}_key")
    if not api_key:
        raise ValueError(f"API key for {provider} not found.")
    model_id = model or DEFAULT_PROVIDER_MODELS.get(provider_id)

    if provider_id == "openai":
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        # `message.content` is Optional in the OpenAI SDK (e.g. on refusals),
        # so guard against None instead of calling .strip() on it directly.
        text = response.choices[0].message.content if response.choices else None
    elif provider_id == "anthropic":
        client = anthropic.Anthropic(api_key=api_key)
        response = client.messages.create(
            model=model_id,
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        )
        # Concatenate every text block rather than indexing [0], which raises
        # IndexError when the reply has no content blocks.
        text = "".join(
            block.text for block in response.content if getattr(block, "type", None) == "text"
        )
    elif provider_id == "google":
        genai.configure(api_key=api_key)
        response = genai.GenerativeModel(model_id).generate_content(prompt)
        text = response.text
    else:
        return ""
    return (text or "").strip()

# --- UI Definitions ---
# Generic NER labels offered in the "Standard Labels" checkbox group. All of
# them are pre-selected, and they appear in the order listed here.
STANDARD_LABELS = [
    # People, organisations and places
    "Person",
    "Organisation",
    "Location",
    "Country",
    "City",
    "State",
    "Nationality",
    "Group",
    # Events, documents and things
    "Date",
    "Event",
    "Law",
    "Legal Document",
    "Product",
    "Facility",
    "Work of Art",
    "Language",
    # Times, amounts and numbers
    "Time",
    "Percentage",
    "Money",
    "Currency",
    "Quantity",
    "Ordinal Number",
    "Cardinal Number",
]
# Number of (initially hidden) accordion slots pre-built for AI-suggested
# categories; categories beyond this count are not displayed.
MAX_CATEGORIES = 8

# Character-window size and overlap used to feed long texts to GLiNER piecewise.
_CHUNK_SIZE = 1024
_CHUNK_OVERLAP = 100
# Column order of the aggregated results table and its CSV export.
_RESULT_COLUMNS = ["Label", "Text Found", "Instances", "Confidence Score"]


def _resolve_api_key(ui_value, env_var):
    """Return the API key typed into the UI, else the value of *env_var*.

    A key the user typed takes priority, so it is never silently replaced by a
    server-side environment variable. Surrounding whitespace from pasting is
    removed. Returns "" when neither source provides a key.
    """
    typed = (ui_value or "").strip()
    return typed or os.environ.get(env_var, "").strip()


def _parse_framework(raw_framework):
    """Parse the Conceptual AI's Markdown reply into ``{category: [labels]}``.

    A line starting with ``###`` opens a category. A following line starting
    with ``-`` (or ``* ``) holds that category's comma-separated labels. Only
    the leading bullet marker is removed, so hyphenated labels such as
    ``Anti-Corn Law League`` stay intact. Backticks and asterisks wrapped
    around a name or label are stripped. Lines that appear before the first
    header, or that match neither pattern, are ignored. Categories keep the
    order in which they appear.
    """
    framework = defaultdict(list)
    current_category = None
    for raw_line in (raw_framework or "").splitlines():
        line = raw_line.strip()
        if line.startswith("###"):
            current_category = line.strip("#").strip().strip("*").strip() or None
        elif current_category and (line.startswith("-") or line.startswith("* ")):
            for label in line.lstrip("-*").split(","):
                cleaned = label.strip().strip("`*").strip()
                if cleaned:
                    framework[current_category].append(cleaned)
    return framework


def _iter_chunks(text, chunk_size=_CHUNK_SIZE, overlap=_CHUNK_OVERLAP):
    """Yield ``(offset, chunk)`` windows of *text* that overlap by *overlap* chars.

    Iteration stops once a window reaches the end of *text*, so no trailing
    window lies entirely inside the previous one (which would only repeat
    work). Yields nothing for an empty string.

    Raises:
        ValueError: If *overlap* is not smaller than *chunk_size*.
    """
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("chunk_size must be larger than overlap")
    offset = 0
    while offset < len(text):
        yield offset, text[offset:offset + chunk_size]
        if offset + chunk_size >= len(text):
            break
        offset += step


def _resolve_entities(entities):
    """Merge per-chunk predictions into one flat, non-overlapping list.

    Each entity is a dict with absolute ``start``/``end`` offsets plus ``label``
    and ``score``.

    1. Copies of the same ``(start, end, label)`` span are collapsed to the
       highest-scoring one. Such copies appear when an entity sits in the
       overlap of two chunks, often with slightly different scores.
    2. Spans that still overlap are resolved greedily by descending score.
       This covers, for example, a span truncated at one chunk's edge and
       found whole in the next.

    The returned list is sorted by position.
    """
    best = {}
    for ent in entities:
        key = (ent["start"], ent["end"], ent["label"])
        if key not in best or ent["score"] > best[key]["score"]:
            best[key] = ent

    starts, ends, kept = [], [], []
    for ent in sorted(best.values(), key=lambda e: (-e["score"], e["start"], e["end"], e["label"])):
        start, end = ent["start"], ent["end"]
        idx = bisect.bisect_right(starts, start)
        # Accepted spans are disjoint and kept sorted by start, so only the
        # neighbours on either side of the insertion point can overlap.
        if idx > 0 and ends[idx - 1] > start:
            continue
        if idx < len(starts) and starts[idx] < end:
            continue
        starts.insert(idx, start)
        ends.insert(idx, end)
        kept.insert(idx, ent)
    return kept


def _build_results_table(text, entities):
    """Aggregate *entities* into one row per (label, case-insensitive text).

    Each row keeps the casing of the earliest occurrence in *text*, the number
    of instances, and the mean score rounded to 2 decimal places. Rows are
    sorted by label, then by descending instance count, then by text. An empty
    *entities* list yields an empty frame that still has the result columns.
    """
    groups = {}
    for ent in sorted(entities, key=lambda e: (e["start"], e["end"])):
        match_text = text[ent["start"]:ent["end"]]
        group = groups.setdefault((ent["label"], match_text.lower()), {"casing": match_text, "scores": []})
        group["scores"].append(ent["score"])

    rows = [
        {
            "Label": label,
            "Text Found": group["casing"],
            "Instances": len(group["scores"]),
            "Confidence Score": round(float(np.mean(group["scores"])), 2),
        }
        for (label, _), group in groups.items()
    ]
    table = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
    if not table.empty:
        table = table.sort_values(
            by=["Label", "Instances", "Text Found"], ascending=[True, False, True], ignore_index=True
        )
    return table


def _write_results_csv(table):
    """Write *table* to a new temporary ``.csv`` file and return its path.

    The CSV goes through the handle that is already open. Reopening the file by
    name while it is still open is not supported on Windows. ``delete=False``
    keeps the file so Gradio can serve it, and this app never removes it.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", encoding="utf-8", newline="", delete=False) as tmp:
        table.to_csv(tmp, index=False)
    return tmp.name


with gr.Blocks(title="Historical Text Analyser", css=".prose { word-break: break-word; }") as demo:
    gr.Markdown("# Historical Text Analyser")
    gr.Markdown("""
        First, a **Conceptual AI**, powered by a generative AI Large Language Model (LLM) such as OpenAI's GPT-4, Anthropic's Claude, or Google's Gemini, suggests labels based on your chosen historical topic. These labels are grouped into broader categories (e.g. Economic Policies, Significant Events) to help focus your research.

        Second, an **Extraction AI**, powered by the GLiNER model, scans your source text to find and highlight matching entities - instances where those labels appear in the document - with high accuracy.

        ### Understanding Entities and Labels ###
        In text analysis, this process is often called Named Entity Recognition (NER).

        - An **Entity** is a specific piece of text in your document, such as a name, a place, or a date (e.g. Queen Victoria, 1848).
        - A **Label** is the category that the entity belongs to (e.g. Person, Date, Location).

        This tool helps you to define your labels and then finds the corresponding entities in your text.
    """)

    gr.Markdown("--- \n## Step 1: Generate Labels")
    with gr.Row():
        topic = gr.Textbox(label="Enter a Historical Topic", placeholder="e.g. Britain during the Second World War")
        provider = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose AI Model")
    with gr.Row():
        openai_key = gr.Textbox(label="OpenAI API Key", type="password")
        anthropic_key = gr.Textbox(label="Anthropic API Key", type="password")
        google_key = gr.Textbox(label="Google API Key", type="password")
    generate_btn = gr.Button("Generate Labels", variant="primary")

    gr.Markdown("--- \n## Step 2: Confirm Labels and Analyse Source Text")
    gr.Markdown("#### 1. AI-Suggested Labels")
    # One (accordion, checkbox group, select-all, deselect-all) tuple per slot.
    dynamic_components = []
    # Per-session copy of each slot's current choices. The component object's
    # own `.choices` keeps its build-time (empty) value after gr.update(), so
    # "Select All" must read the choices from session state instead.
    category_choices = []
    with gr.Column():
        for i in range(MAX_CATEGORIES):
            with gr.Accordion(f"Suggested Category {i+1}", visible=False) as acc:
                cg = gr.CheckboxGroup(label="Labels in this category", interactive=True)
                with gr.Row():
                    select_btn = gr.Button("Select All", size="sm")
                    deselect_btn = gr.Button("Deselect All", size="sm")
            dynamic_components.append((acc, cg, select_btn, deselect_btn))
            category_choices.append(gr.State([]))

    gr.Markdown("#### 2. Standard Labels (Optional)")
    with gr.Group():
        standard_labels_checkbox = gr.CheckboxGroup(choices=STANDARD_LABELS, value=STANDARD_LABELS, label="Standard Entity Labels")
        with gr.Row():
            select_all_std_btn = gr.Button("Select All", size="sm")
            deselect_all_std_btn = gr.Button("Deselect All", size="sm")

    gr.Markdown("#### 3. Custom Labels (Optional)")
    with gr.Group():
        custom_labels_textbox = gr.Textbox(label="Enter Custom Labels (separate with commas)", placeholder="e.g., Technology, Weapon, Secret Society...")

    gr.Markdown("--- \n## Step 3: Run Analysis")
    threshold_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
    text_input = gr.Textbox(label="Paste Your Text Here for Analysis", lines=15)
    analyze_btn = gr.Button("Find Entities", variant="primary")
    analysis_status = gr.Markdown(visible=False)

    gr.Markdown("--- \n## Step 4: Review Results")
    gr.Markdown("✨ **Pro Tip:** In the **\"Highlighted Text\"** view, you can click and drag to highlight text and create your own labels!")

    with gr.Tabs():
        with gr.TabItem("Highlighted Text"):
            highlighted_text_output = gr.HighlightedText(label="Found Entities", interactive=True)
        with gr.TabItem("Detailed Results"):
            gr.Markdown("You can sort the table by clicking on column headers. The download link for the full table will appear here after analysis.")
            # Hidden until an analysis produces at least one result row.
            csv_file_output = gr.File(label="Download Results as CSV", visible=False)
            detailed_results_output = gr.DataFrame(headers=["Label", "Text Found", "Instances", "Confidence Score"], datatype=["str", "str", "number", "number"], label="Aggregated List of Found Entities")
        with gr.TabItem("Debug Log"):
            debug_output = gr.Textbox(label="Extraction Process Log", interactive=False, lines=8)

    # --- Backend Functions ---

    def handle_generate(topic, provider, openai_k, anthropic_k, google_k):
        """Ask the chosen LLM for a label framework and fill the category slots.

        Yields dict updates. The first disables the button. The second shows one
        accordion per parsed category (up to MAX_CATEGORIES), hides the unused
        slots, and stores each slot's choices in its gr.State. On any error the
        button is re-enabled before a gr.Error is raised.
        """
        idle_button = gr.update(value="Generate Labels", interactive=True)
        yield {generate_btn: gr.update(value="Generating...", interactive=False)}
        try:
            key_dict = {
                "openai_key": _resolve_api_key(openai_k, "OPENAI_API_KEY"),
                "anthropic_key": _resolve_api_key(anthropic_k, "ANTHROPIC_API_KEY"),
                "google_key": _resolve_api_key(google_k, "GOOGLE_API_KEY"),
            }
            provider_id = MODEL_OPTIONS.get(provider)
            topic_text = (topic or "").strip()
            if not topic_text or not provider or not key_dict.get(f"{provider_id}_key"):
                raise gr.Error("A topic, provider, and valid API Key for that provider are required.")

            prompt = FRAMEWORK_PROMPT_TEMPLATE.format(topic=topic_text)
            framework = _parse_framework(generate_from_prompt(prompt, provider, key_dict))
            if not framework:
                raise gr.Error("The AI failed to generate categories. Please try again or rephrase your topic.")

            categories = list(framework.items())
            updates = {generate_btn: idle_button}
            for slot, ((acc, cg, sel, desel), choices_state) in enumerate(zip(dynamic_components, category_choices)):
                if slot < len(categories):
                    cat_name, labels = categories[slot]
                    sorted_labels = sorted(set(labels))
                    updates[acc] = gr.update(label=f"Category: {cat_name}", visible=True)
                    updates[cg] = gr.update(choices=sorted_labels, value=sorted_labels, label="Suggested Labels", visible=True)
                    updates[sel] = gr.update(visible=True)
                    updates[desel] = gr.update(visible=True)
                    updates[choices_state] = sorted_labels
                else:
                    updates[acc] = gr.update(visible=False)
                    updates[cg] = gr.update(choices=[], value=[], visible=False)
                    updates[sel] = gr.update(visible=False)
                    updates[desel] = gr.update(visible=False)
                    updates[choices_state] = []
            yield updates
        except Exception as err:
            yield {generate_btn: idle_button}
            # Re-raise gr.Error untouched; wrapping it via str() can alter its message.
            if isinstance(err, gr.Error):
                raise
            raise gr.Error(str(err)) from err

    def analyze_text(text, standard_labels, custom_label_text, threshold, *suggested_labels_from_groups):
        """Run GLiNER over *text* with the selected labels and publish the results.

        Labels are the union of the checked suggested, standard and custom
        (comma-separated) labels. The text is scanned in overlapping character
        windows. Per-window predictions are mapped to absolute offsets, and
        duplicate or overlapping spans are merged. Yields dict updates for the
        button, status, highlighted text, results table, debug log and CSV
        link. The button is re-enabled on every exit path.
        """
        idle_button = gr.update(value="Find Entities", interactive=True)
        hidden_status = gr.update(visible=False)
        # Fail before the "busy" update so the button is never left disabled.
        if gliner_model is None:
            raise gr.Error("Extraction AI (GLiNER model) is not loaded.")

        yield {
            analyze_btn: gr.update(value="Finding Entities...", interactive=False),
            analysis_status: gr.update(value="The Extraction AI is scanning your text...", visible=True),
            highlighted_text_output: None, detailed_results_output: None, debug_output: "Starting analysis...",
            csv_file_output: gr.update(visible=False, value=None),
        }

        labels_to_use = set()
        for group in suggested_labels_from_groups:
            if group:
                labels_to_use.update(group)
        labels_to_use.update(standard_labels or [])
        labels_to_use.update(label.strip() for label in (custom_label_text or "").split(",") if label.strip())
        final_labels = sorted(labels_to_use)

        if not text or not final_labels:
            yield {
                analyze_btn: idle_button,
                analysis_status: hidden_status,
                highlighted_text_output: {"text": text, "entities": []},
                detailed_results_output: None,
                debug_output: "Analysis stopped: No text or no labels provided.",
                csv_file_output: gr.update(visible=False, value=None),
            }
            return

        try:
            raw_entities = []
            for offset, chunk in _iter_chunks(text):
                for ent in gliner_model.predict_entities(chunk, final_labels, threshold=threshold):
                    # Convert chunk-relative offsets to positions in the full text.
                    ent["start"] += offset
                    ent["end"] += offset
                    raw_entities.append(ent)
            entities = _resolve_entities(raw_entities)
            results_df = _build_results_table(text, entities)
            csv_file_path = None if results_df.empty else _write_results_csv(results_df)
        except Exception as err:
            yield {analyze_btn: idle_button, analysis_status: hidden_status, debug_output: f"Analysis failed: {err}"}
            raise gr.Error(f"Analysis failed: {err}") from err

        highlighted_output_data = {
            "text": text,
            "entities": [{"start": ent["start"], "end": ent["end"], "entity": ent["label"]} for ent in entities],
        }
        yield {
            analyze_btn: idle_button,
            analysis_status: hidden_status,
            highlighted_text_output: highlighted_output_data,
            detailed_results_output: results_df,
            debug_output: "Analysis complete.",
            csv_file_output: gr.update(value=csv_file_path, visible=bool(csv_file_path)),
        }

    def deselect_all():
        """Clear every checkbox in the target group."""
        return gr.update(value=[])

    def select_all(choices):
        """Check every choice currently offered, as recorded in the slot's gr.State."""
        return gr.update(value=list(choices or []))

    # --- Wire up UI events ---
    generate_btn.click(
        fn=handle_generate,
        inputs=[topic, provider, openai_key, anthropic_key, google_key],
        outputs=[generate_btn] + [c for slot in dynamic_components for c in slot] + category_choices,
    )

    deselect_all_std_btn.click(fn=deselect_all, inputs=None, outputs=[standard_labels_checkbox])
    select_all_std_btn.click(lambda: gr.update(value=STANDARD_LABELS), inputs=None, outputs=[standard_labels_checkbox])

    for (_, cg, sel_btn, desel_btn), choices_state in zip(dynamic_components, category_choices):
        sel_btn.click(fn=select_all, inputs=[choices_state], outputs=[cg])
        desel_btn.click(fn=deselect_all, inputs=None, outputs=[cg])

    analyze_btn.click(
        fn=analyze_text,
        inputs=[text_input, standard_labels_checkbox, custom_labels_textbox, threshold_slider] + [cg for _, cg, _, _ in dynamic_components],
        outputs=[analyze_btn, analysis_status, highlighted_text_output, detailed_results_output, debug_output, csv_file_output],
    )

if __name__ == "__main__":
    # Start the server only when this file is executed directly (script or
    # notebook cell), not when it is imported. share=True requests a public
    # Gradio share link and debug=True keeps launch() blocking with errors
    # printed to the console (per gr.Blocks.launch docs).
    # TODO(review): confirm a public share link is wanted for every deployment.
    demo.launch(share=True, debug=True)