File size: 8,390 Bytes
1b0659e
 
 
 
 
 
 
 
 
 
 
1885bb3
1b0659e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32c52f2
1b0659e
 
 
 
 
 
 
 
 
 
 
 
77bc720
1b0659e
42eee89
1b0659e
 
 
 
 
 
 
 
 
 
77bc720
1b0659e
119e15a
1b0659e
77bc720
 
 
 
42eee89
77bc720
42eee89
1b0659e
119e15a
1b0659e
77bc720
 
1b0659e
119e15a
42eee89
1b0659e
119e15a
1b0659e
 
42eee89
1b0659e
 
 
77bc720
119e15a
42eee89
 
77bc720
c7e6f79
1b0659e
119e15a
1b0659e
c7e6f79
77bc720
1b0659e
119e15a
77bc720
c7e6f79
 
1885bb3
77bc720
1b0659e
 
 
 
 
 
42eee89
1b0659e
42eee89
1b0659e
119e15a
1b0659e
 
 
 
 
42eee89
1b0659e
 
119e15a
1b0659e
119e15a
 
1b0659e
119e15a
1b0659e
119e15a
1b0659e
 
 
 
 
119e15a
1b0659e
 
 
 
 
 
 
1885bb3
1b0659e
42eee89
1b0659e
6c5dad7
1b0659e
 
42eee89
1b0659e
42eee89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b0659e
 
42eee89
1b0659e
42eee89
 
6c5dad7
42eee89
6c5dad7
 
1b0659e
 
 
 
 
 
 
 
 
42eee89
6c5dad7
1b0659e
 
 
 
119e15a
1b0659e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import json
import re
import gradio as gr
from typing import List, TypedDict, Annotated
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# ==========================================================
# 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
# ==========================================================
# Load the GCMD keyword hierarchy once, at import time. The encoding is pinned
# because the default text encoding is locale-dependent (e.g. cp1252 on many
# Windows installs), which can mis-decode or reject a UTF-8 JSON file.
with open("gcmd_hierarchy.json", "r", encoding="utf-8") as f:
    gcmd_data = json.load(f)

def _collect_keyword_paths(node, current_path, out):
    """Append to *out* the " > "-joined path of every keyword node under *node*.

    A node is recorded when its ``level`` contains "Variable" or when it has no
    children. Recording happens before descending, so a parent's path always
    precedes its descendants' paths in *out*.
    """
    node_name = node.get("name") or ""
    node_path = f"{current_path} > {node_name}" if current_path else node_name
    # Treat missing *or null* fields as empty so sparse JSON exports do not
    # raise TypeError ("in None" / iterating None).
    children = node.get("children") or []
    if "Variable" in (node.get("level") or "") or not children:
        out.append(node_path)
    for child in children:
        _collect_keyword_paths(child, node_path, out)


def build_gcmd_indices(gcmd_json):
    """Index the GCMD hierarchy by top-level topic.

    Args:
        gcmd_json: Root node of the hierarchy. Its ``children`` are the topics;
            every node is a dict with optional ``name``, ``level`` and
            ``children`` keys.

    Returns:
        ``(topic_list, sub_tree_indices)``. ``topic_list`` holds the
        upper-cased topic names in input order. ``sub_tree_indices`` maps each
        of those names to a newline-joined listing of ``"- <path>"`` lines.
        The paths are relative to the topic, i.e. they omit the topic name.
    """
    topic_list = []
    sub_tree_indices = {}

    for topic_node in gcmd_json.get("children") or []:
        topic_name = (topic_node.get("name") or "").upper()
        topic_list.append(topic_name)

        collected_paths = []
        for term_node in topic_node.get("children") or []:
            _collect_keyword_paths(term_node, "", collected_paths)

        sub_tree_indices[topic_name] = "\n".join(f"- {path}" for path in collected_paths)
    return topic_list, sub_tree_indices

# Derived once at import time from the loaded hierarchy:
#   VALID_TOPICS    - upper-cased topic names; the allow-list the router filters
#                     against and the source of the "classify_<topic>" node ids.
#   SUB_TREE_LOOKUP - topic name -> newline-joined "- <path>" keyword listing,
#                     embedded in each per-topic prompt and used to validate
#                     the paths the model returns.
VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)

# ==========================================================
# 2. LANGGRAPH WORKFLOW WITH DUAL STRUCTURED OUTPUTS
# ==========================================================

def merge_lists(left: list, right: list) -> list:
    """Reduce two list-valued state updates into their de-duplicated union.

    Items keep first-seen order (``left`` before ``right``), so the result is
    deterministic. A set-based union would order strings by hash, which varies
    between interpreter runs, so the displayed keyword order would change.

    Either argument may be ``None``, which is treated as an empty list. Items
    must be hashable.
    """
    return list(dict.fromkeys([*(left or []), *(right or [])]))

class MultiTopicState(TypedDict):
    """Shared LangGraph state for one classification run."""

    title: str  # article title, passed in by the caller of app.invoke
    abstract: str  # article abstract / body text, passed in by the caller
    chosen_topics: List[str]  # set by route_multi_topic; read by parallel_router and the UI
    # These two channels are annotated with the merge_lists reducer, so the
    # updates written by several classifier nodes are combined into a
    # de-duplicated union instead of overwriting one another.
    predicted_keywords: Annotated[List[str], merge_lists]
    invalid_keywords: Annotated[List[str], merge_lists]

# Structured-output schema for Step 1: route_multi_topic calls
# llm.with_structured_output(TopicsChoice) and reads `.topics` from the parsed
# reply, which it then filters against VALID_TOPICS.
# NOTE(review): intentionally no class docstring. Pydantic typically uses it as
# the JSON-schema description, which would presumably be forwarded to the
# model. Confirm before adding one.
class TopicsChoice(BaseModel):
    # The Field description is presumably part of the schema shown to the LLM.
    topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")

# Structured-output schema for Step 2, used by the per-topic nodes built in
# classify_individual_topic. It constrains only the reply's shape (a list of
# strings). It does not guarantee the strings are real paths, which is why each
# node re-checks them against the topic's listing afterwards.
# NOTE(review): no class docstring, for the same schema-description reason as
# TopicsChoice.
class KeywordExtraction(BaseModel):
    # The Field description is presumably part of the schema shown to the LLM.
    keywords: List[str] = Field(description="List of exact matching keyword pathways from the provided text block. Return an empty list if nothing matches.")

def route_multi_topic(state: MultiTopicState):
    """Step 1: select every high-level GCMD topic that applies to the paper.

    The model answers through the ``TopicsChoice`` schema. Its topics are then
    upper-cased and stripped, filtered to ``VALID_TOPICS``, and de-duplicated
    in first-seen order. If no valid topic survives, the single pseudo-topic
    ``"FALLBACK"`` is returned so the graph still reaches its fallback node.

    Returns:
        Partial state update ``{"chosen_topics": [...]}``.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    structured_llm = llm.with_structured_output(TopicsChoice)

    prompt = ChatPromptTemplate.from_messages([
        ("system", f"You are an expert science cataloger. Identify ALL relevant major topic areas that apply to this paper. Choose ONLY from: {', '.join(VALID_TOPICS)}"),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect all relevant Topics as a structured list.")
    ])

    # format_messages keeps the system and user turns as separate chat messages.
    # format() would flatten both into one plain-text prompt, losing the
    # system role.
    messages = prompt.format_messages(title=state["title"], abstract=state["abstract"])
    result = structured_llm.invoke(messages)

    normalized = (t.upper().strip() for t in result.topics)
    # dict.fromkeys de-duplicates while preserving order, so a topic the model
    # repeats is routed (and classified) only once.
    valid_selected = list(dict.fromkeys(t for t in normalized if t in VALID_TOPICS))

    if not valid_selected:
        valid_selected = ["FALLBACK"]

    return {"chosen_topics": valid_selected}

def classify_individual_topic(topic_name: str):
    """Build the LangGraph node that extracts keyword paths for one topic.

    The topic's keyword listing is looked up in ``SUB_TREE_LOOKUP`` and parsed
    into a set of valid paths once, when the node is built, rather than on
    every invocation.

    Args:
        topic_name: Upper-cased topic name. It is used as the
            ``SUB_TREE_LOOKUP`` key and inside the system prompt.

    Returns:
        A node function ``node_runner(state)``. It asks the LLM for matching
        paths and returns ``{"predicted_keywords": [...],
        "invalid_keywords": [...]}``. Strings that are verbatim entries of the
        topic's listing are accepted; everything else is reported as invalid.
    """
    target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
    # Listing lines look like "- <path>". Strip only that leading bullet: a
    # blanket replace("- ", "") would also delete any " - " inside a path,
    # so such genuine paths could never validate.
    valid_set = {
        line.removeprefix("- ").strip()
        for line in target_sub_tree.split("\n")
        if line.strip()
    }

    def node_runner(state: MultiTopicState):
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
        # Constrain the reply to the KeywordExtraction schema (a list of strings).
        structured_llm = llm.with_structured_output(KeywordExtraction)

        prompt = ChatPromptTemplate.from_messages([
            ("system", f"You are a specialist in {topic_name} data mapping. Extract all exact matching keyword pathways present in the provided valid path list. Do not modify or truncate the pathway strings."),
            ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nExtract matching pathways as a structured array list.")
        ])

        # format_messages keeps the system and user turns as separate chat
        # messages. format() would flatten both into one plain-text prompt.
        messages = prompt.format_messages(
            title=state["title"], abstract=state["abstract"], sub_tree=target_sub_tree
        )
        result = structured_llm.invoke(messages)

        # The listing is shown to the model as "- " bullets, so tolerate the
        # bullet being echoed back before the exact-match check.
        cleaned = (k.strip().removeprefix("- ").strip() for k in result.keywords)
        raw_keywords = [k for k in cleaned if k]

        # Validation pass: accept only verbatim paths from this topic's listing.
        valid_kws = [kw for kw in raw_keywords if kw in valid_set]
        invalid_kws = [kw for kw in raw_keywords if kw not in valid_set]

        return {"predicted_keywords": valid_kws, "invalid_keywords": invalid_kws}

    return node_runner

def parallel_router(state: MultiTopicState) -> List[str]:
    """Return the classifier node id for every topic chosen by the router.

    Each topic maps to ``"classify_<topic lowercased>"``, the naming used when
    the classifier nodes are registered. Ids are de-duplicated in first-seen
    order, so a topic listed twice in ``chosen_topics`` yields one destination.
    """
    return list(dict.fromkeys(f"classify_{topic.lower()}" for topic in state["chosen_topics"]))

# Assemble the graph:
#   START -> top_router -> (one classifier node per routed topic) -> END
workflow = StateGraph(MultiTopicState)
workflow.add_node("top_router", route_multi_topic)

# Register one classifier node per topic. Every registered id is also a
# permitted destination of the router's conditional edge.
classifier_routes = {}
for topic in VALID_TOPICS:
    node_id = f"classify_{topic.lower()}"
    workflow.add_node(node_id, classify_individual_topic(topic))
    classifier_routes[node_id] = node_id

# The router emits "FALLBACK" when no valid topic was selected; this node
# contributes no keywords.
workflow.add_node("classify_fallback", lambda state: {"predicted_keywords": []})
classifier_routes["classify_fallback"] = "classify_fallback"

workflow.add_edge(START, "top_router")
workflow.add_conditional_edges("top_router", parallel_router, classifier_routes)

# Every classifier (including the fallback) terminates the run.
for route_id in classifier_routes:
    workflow.add_edge(route_id, END)

app = workflow.compile()

# ==========================================================
# 3. GRADIO USER INTERFACE
# ==========================================================
def run_agent_classifier(title, abstract):
    """Run the classification graph for one article and format the results.

    Args:
        title: Article title from the first textbox.
        abstract: Abstract / body text from the second textbox.

    Returns:
        A 3-tuple of display strings:

        * the routed topics, comma-separated;
        * the accepted keywords as ``"• TOPIC > path"`` lines;
        * the rejected keywords.

        If either input is missing or whitespace-only, a prompt message and
        two ``"N/A"`` placeholders are returned without invoking the graph.
    """
    # Whitespace-only fields would otherwise be sent to the LLM as an article.
    if not (title and title.strip()) or not (abstract and abstract.strip()):
        return "Please fill out both Title and Abstract fields.", "N/A", "N/A"

    output = app.invoke({"title": title, "abstract": abstract})

    chosen_topics = output.get("chosen_topics", [])
    predicted_kws = output.get("predicted_keywords", [])
    invalid_kws = output.get("invalid_keywords", [])

    topics_str = ", ".join(chosen_topics)

    # Exact path sets per routed topic. A substring test against the raw
    # listing text would also match a keyword that merely occurs inside a
    # longer path of an earlier topic, attributing it to the wrong topic.
    topic_paths = {
        topic: {
            line.removeprefix("- ").strip()
            for line in SUB_TREE_LOOKUP.get(topic, "").split("\n")
            if line.strip()
        }
        for topic in chosen_topics
    }

    formatted_keywords = []
    for kw in predicted_kws:
        matched_topic = next(
            (topic for topic in chosen_topics if kw in topic_paths[topic]), "UNKNOWN"
        )
        formatted_keywords.append(f"• {matched_topic} > {kw}")
    # The keyword lists come out of a merge reducer whose order is not
    # guaranteed. Sort so identical results always display identically, grouped
    # by topic.
    formatted_keywords.sort()

    if not formatted_keywords:
        keywords_str = "No explicit keywords mapped."
    else:
        keywords_str = "\n".join(formatted_keywords)

    if not invalid_kws:
        invalid_str = "None! The agent validation pass achieved 100% data integrity."
    else:
        invalid_str = "\n".join(f"⚠ Caught & Removed: {ikw}" for ikw in sorted(invalid_kws))

    return topics_str, keywords_str, invalid_str

# The two text inputs are passed positionally to run_agent_classifier, and its
# three returned strings fill the three output boxes in the same order.
article_inputs = [
    gr.Textbox(label="Journal Article Title", placeholder="Enter article title here...", lines=1),
    gr.Textbox(label="Abstract / Body Text", placeholder="Paste abstract description here...", lines=5),
]
result_outputs = [
    gr.Textbox(label="Routed Multi-Topic Domains"),
    gr.Textbox(label="Verified GCMD Keywords Extracted (Formatted with Topics)"),
    gr.Textbox(label="Hallucinated/Invalid Keywords Caught and Removed"),
]
# One clickable [title, abstract] example row.
sample_articles = [
    [
        "El Niño Southern Oscillation Driving Anomalous Atmospheric Evaporation",
        "We observe how rising sea surface temperatures across equatorial waters directly trigger increased low-level cloud formation and accelerated surface winds.",
    ],
]

demo = gr.Interface(
    fn=run_agent_classifier,
    inputs=article_inputs,
    outputs=result_outputs,
    title="GCMD Science Keyword Classifier Agent",
    description="Proof of Concept using LangGraph and LangChain. Routes articles concurrently across science domains and runs isolated self-validation routines.",
    examples=sample_articles,
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    demo.launch()