File size: 8,390 Bytes
1b0659e 1885bb3 1b0659e 32c52f2 1b0659e 77bc720 1b0659e 42eee89 1b0659e 77bc720 1b0659e 119e15a 1b0659e 77bc720 42eee89 77bc720 42eee89 1b0659e 119e15a 1b0659e 77bc720 1b0659e 119e15a 42eee89 1b0659e 119e15a 1b0659e 42eee89 1b0659e 77bc720 119e15a 42eee89 77bc720 c7e6f79 1b0659e 119e15a 1b0659e c7e6f79 77bc720 1b0659e 119e15a 77bc720 c7e6f79 1885bb3 77bc720 1b0659e 42eee89 1b0659e 42eee89 1b0659e 119e15a 1b0659e 42eee89 1b0659e 119e15a 1b0659e 119e15a 1b0659e 119e15a 1b0659e 119e15a 1b0659e 119e15a 1b0659e 1885bb3 1b0659e 42eee89 1b0659e 6c5dad7 1b0659e 42eee89 1b0659e 42eee89 1b0659e 42eee89 1b0659e 42eee89 6c5dad7 42eee89 6c5dad7 1b0659e 42eee89 6c5dad7 1b0659e 119e15a 1b0659e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | import os
import json
import re
import gradio as gr
from typing import List, TypedDict, Annotated
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# ==========================================================
# 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
# ==========================================================
# Load the GCMD keyword hierarchy once at import time.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default encoding.
with open("gcmd_hierarchy.json", "r", encoding="utf-8") as f:
    gcmd_data = json.load(f)
def _collect_keyword_paths(node, parent_path, collected):
    """Append the path of each variable-level or childless node under *node*.

    Paths are built by joining node names with ``" > "``. A node is recorded
    when its ``level`` contains ``"Variable"`` or when it has no children;
    the walk always continues into the children afterwards.
    """
    node_name = node.get("name") or ""
    node_path = f"{parent_path} > {node_name}" if parent_path else node_name
    # ``or`` also covers keys that are present but null in the JSON.
    children = node.get("children") or []
    if "Variable" in (node.get("level") or "") or not children:
        collected.append(node_path)
    for child in children:
        _collect_keyword_paths(child, node_path, collected)


def build_gcmd_indices(gcmd_json):
    """Build the topic list and per-topic keyword path index.

    Args:
        gcmd_json: Root mapping of the hierarchy; its ``children`` are the
            topic nodes, each a mapping with ``name``, optional ``level``
            and optional ``children``.

    Returns:
        A ``(topic_list, sub_tree_indices)`` tuple. ``topic_list`` holds the
        upper-cased topic names in input order. ``sub_tree_indices`` maps
        each upper-cased topic name to a newline-joined string of
        ``"- <path>"`` bullets, where paths start below the topic node.
    """
    topic_list = []
    sub_tree_indices = {}
    for topic_node in gcmd_json.get("children") or []:
        topic_name = (topic_node.get("name") or "").upper()
        topic_list.append(topic_name)
        collected_paths = []
        for term_node in topic_node.get("children") or []:
            # Paths are relative to the topic, so each term starts a new path.
            _collect_keyword_paths(term_node, "", collected_paths)
        sub_tree_indices[topic_name] = "\n".join(
            f"- {path}" for path in collected_paths
        )
    return topic_list, sub_tree_indices
# Module-level indices built once at import time: the list of upper-cased topic
# names and, per topic, the newline-joined "- <path>" bullet list of keyword paths.
VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)
# ==========================================================
# 2. LANGGRAPH WORKFLOW WITH DUAL STRUCTURED OUTPUTS
# ==========================================================
def merge_lists(left: list, right: list) -> list:
    """Concatenate two lists and drop duplicates, keeping first-seen order.

    Either argument may be ``None`` (treated as empty). Elements must be
    hashable. ``dict.fromkeys`` is used instead of ``set`` so the result
    order is deterministic rather than dependent on hash ordering.
    """
    return list(dict.fromkeys((left or []) + (right or [])))
class MultiTopicState(TypedDict):
    # State dictionary shared by every node of the graph.
    # Inputs supplied by the caller when the graph is invoked.
    title: str
    abstract: str
    # Topic names selected by the routing step (route_multi_topic).
    chosen_topics: List[str]
    # Both keyword lists carry merge_lists as their Annotated metadata, so the
    # partial updates returned by the per-topic nodes are combined and
    # de-duplicated through that function.
    predicted_keywords: Annotated[List[str], merge_lists]
    invalid_keywords: Annotated[List[str], merge_lists]
# Pydantic schema for Step 1
class TopicsChoice(BaseModel):
    # Structured-output schema used by route_multi_topic.
    # NOTE(review): the Field description is presumably forwarded to the model
    # as part of the schema, so treat it as prompt text - confirm before editing.
    topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
# Pydantic schema for Step 2: requests the keywords as a structured list of pathway strings rather than free-form text
class KeywordExtraction(BaseModel):
    # Structured-output schema used by the per-topic classifier nodes.
    # NOTE(review): the Field description is presumably forwarded to the model
    # as part of the schema, so treat it as prompt text - confirm before editing.
    keywords: List[str] = Field(description="List of exact matching keyword pathways from the provided text block. Return an empty list if nothing matches.")
def route_multi_topic(state: MultiTopicState):
    """Step 1: select every high-level topic that applies to the paper.

    Args:
        state: Graph state; only ``title`` and ``abstract`` are read.

    Returns:
        A partial state update ``{"chosen_topics": [...]}``. The list holds
        the upper-cased, stripped topics returned by the model that are
        present in ``VALID_TOPICS`` (duplicates removed, model order kept),
        or ``["FALLBACK"]`` when none pass that check.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    structured_llm = llm.with_structured_output(TopicsChoice)
    # The topic list is passed as a template variable rather than spliced in
    # with an f-string, so a brace inside a topic name cannot be mistaken for
    # a template placeholder.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert science cataloger. Identify ALL relevant major topic areas that apply to this paper. Choose ONLY from: {valid_topics}"),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect all relevant Topics as a structured list."),
    ])
    # format_messages() keeps the separate system/user messages; format()
    # renders the whole prompt to one string, which would be sent as a single
    # user message.
    messages = prompt.format_messages(
        valid_topics=", ".join(VALID_TOPICS),
        title=state["title"],
        abstract=state["abstract"],
    )
    result = structured_llm.invoke(messages)
    # Treat a missing structured result as "no topics" so the fallback branch
    # is taken instead of raising AttributeError.
    proposed = result.topics if result is not None else []
    allowed = set(VALID_TOPICS)
    normalized = (t.upper().strip() for t in proposed)
    valid_selected = list(dict.fromkeys(t for t in normalized if t in allowed))
    if not valid_selected:
        valid_selected = ["FALLBACK"]
    return {"chosen_topics": valid_selected}
def classify_individual_topic(topic_name: str):
    """Create the graph node function that extracts keywords for one topic.

    Args:
        topic_name: Key into ``SUB_TREE_LOOKUP`` identifying the topic whose
            keyword paths the node may return.

    Returns:
        A function ``node_runner(state)`` that asks the model for matching
        keyword paths and returns
        ``{"predicted_keywords": [...], "invalid_keywords": [...]}``, where
        the first list holds the returned paths found verbatim in the topic's
        path list and the second holds every other returned string.
    """
    def node_runner(state: MultiTopicState):
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
        structured_llm = llm.with_structured_output(KeywordExtraction)
        target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
        # topic_name is a template variable (not an f-string splice) so braces
        # in a topic name cannot be parsed as placeholders.
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a specialist in {topic_name} data mapping. Extract all exact matching keyword pathways present in the provided valid path list. Do not modify or truncate the pathway strings."),
            ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nExtract matching pathways as a structured array list."),
        ])
        # format_messages() preserves the system/user roles; format() would
        # collapse the prompt into one string sent as a single user message.
        messages = prompt.format_messages(
            topic_name=topic_name,
            title=state["title"],
            abstract=state["abstract"],
            sub_tree=target_sub_tree,
        )
        result = structured_llm.invoke(messages)
        returned = result.keywords if result is not None else []
        raw_keywords = [k.strip() for k in returned if k.strip()]

        # Rebuild the set of valid paths from the bullet list. Only the
        # leading "- " bullet is removed; a blanket replace("- ", "") would
        # also delete "- " sequences inside a path and make that path
        # impossible to match.
        valid_set = {
            line.removeprefix("- ").strip()
            for line in target_sub_tree.split("\n")
            if line.strip()
        }

        valid_kws, invalid_kws = [], []
        for kw in raw_keywords:
            (valid_kws if kw in valid_set else invalid_kws).append(kw)
        return {"predicted_keywords": valid_kws, "invalid_keywords": invalid_kws}

    return node_runner
def parallel_router(state: MultiTopicState) -> List[str]:
    """Map the chosen topics to the names of the classifier nodes to run.

    Each topic becomes ``"classify_<topic lower-cased>"``. Repeated topics
    yield a single destination, and an empty or missing ``chosen_topics``
    routes to ``"classify_fallback"`` so the graph always has a next node.
    """
    chosen = state.get("chosen_topics") or []
    branches = list(dict.fromkeys(f"classify_{topic.lower()}" for topic in chosen))
    return branches or ["classify_fallback"]
# Assemble the Graph
workflow = StateGraph(MultiTopicState)

# Nodes: the router, one classifier per known topic, and a no-op fallback
# that contributes an empty keyword list.
workflow.add_node("top_router", route_multi_topic)
for topic in VALID_TOPICS:
    node_id = "classify_" + topic.lower()
    workflow.add_node(node_id, classify_individual_topic(topic))
workflow.add_node("classify_fallback", lambda state: {"predicted_keywords": []})

# Edges: START -> router -> (selected classifier nodes) -> END.
workflow.add_edge(START, "top_router")

# Every name parallel_router may return maps to the node of the same name.
branch_targets = {
    "classify_" + t.lower(): "classify_" + t.lower() for t in VALID_TOPICS
}
branch_targets["classify_fallback"] = "classify_fallback"
workflow.add_conditional_edges("top_router", parallel_router, branch_targets)

for topic in VALID_TOPICS:
    workflow.add_edge("classify_" + topic.lower(), END)
workflow.add_edge("classify_fallback", END)

app = workflow.compile()
# ==========================================================
# 3. GRADIO USER INTERFACE
# ==========================================================
def run_agent_classifier(title, abstract):
    """Run the classification graph and format its output for the UI.

    Args:
        title: Article title entered by the user.
        abstract: Abstract/body text entered by the user.

    Returns:
        A ``(topics_str, keywords_str, invalid_str)`` tuple of display
        strings: the comma-joined chosen topics, one ``"• TOPIC > path"``
        line per validated keyword, and one line per rejected keyword. When
        either input is missing or blank, a prompt message and two ``"N/A"``
        placeholders are returned without invoking the graph.
    """
    # Whitespace-only text is treated the same as an empty field.
    if not (title or "").strip() or not (abstract or "").strip():
        return "Please fill out both Title and Abstract fields.", "N/A", "N/A"

    output = app.invoke({"title": title, "abstract": abstract})
    chosen_topics = output.get("chosen_topics", [])
    predicted_kws = output.get("predicted_keywords", [])
    invalid_kws = output.get("invalid_keywords", [])
    topics_str = ", ".join(chosen_topics)

    # Exact path sets per chosen topic. A substring test against the whole
    # bullet text could attribute a keyword to a topic that merely contains
    # it inside a longer path.
    topic_paths = {
        topic: {
            line.removeprefix("- ").strip()
            for line in SUB_TREE_LOOKUP.get(topic, "").split("\n")
            if line.strip()
        }
        for topic in chosen_topics
    }

    formatted_keywords = []
    for kw in predicted_kws:
        matched_topic = next(
            (topic for topic in chosen_topics if kw in topic_paths[topic]),
            "UNKNOWN",
        )
        formatted_keywords.append(f"• {matched_topic} > {kw}")

    if formatted_keywords:
        keywords_str = "\n".join(formatted_keywords)
    else:
        keywords_str = "No explicit keywords mapped."

    if invalid_kws:
        invalid_str = "\n".join(f"⚠ Caught & Removed: {ikw}" for ikw in invalid_kws)
    else:
        invalid_str = "None! The agent validation pass achieved 100% data integrity."

    return topics_str, keywords_str, invalid_str
# Gradio front end: two text inputs feed run_agent_classifier, and the three
# output boxes receive its (topics, keywords, invalid keywords) strings in
# that order.
demo = gr.Interface(
    fn=run_agent_classifier,
    inputs=[
        gr.Textbox(label="Journal Article Title", placeholder="Enter article title here...", lines=1),
        gr.Textbox(label="Abstract / Body Text", placeholder="Paste abstract description here...", lines=5)
    ],
    outputs=[
        gr.Textbox(label="Routed Multi-Topic Domains"),
        gr.Textbox(label="Verified GCMD Keywords Extracted (Formatted with Topics)"),
        gr.Textbox(label="Hallucinated/Invalid Keywords Caught and Removed")
    ],
    title="GCMD Science Keyword Classifier Agent",
    description="Proof of Concept using LangGraph and LangChain. Routes articles concurrently across science domains and runs isolated self-validation routines.",
    # Each example row supplies one value per input: [title, abstract].
    examples=[
        ["El Niño Southern Oscillation Driving Anomalous Atmospheric Evaporation", "We observe how rising sea surface temperatures across equatorial waters directly trigger increased low-level cloud formation and accelerated surface winds."]
    ]
)
# Launch the UI only when the file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()
|