igerasimov's picture
perf: keywords nodes use Pydantic structured output
c7e6f79
Raw
History Blame Contribute Delete
8.39 kB
import os
import json
import re
import gradio as gr
from typing import List, TypedDict, Annotated
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# ==========================================================
# 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
# ==========================================================
# Read the GCMD keyword hierarchy once at import time. JSON text is UTF-8 by
# specification, so decode it explicitly rather than relying on the platform's
# locale encoding (which can fail on non-ASCII names, e.g. under cp1252).
with open("gcmd_hierarchy.json", "r", encoding="utf-8") as f:
    gcmd_data = json.load(f)
def build_gcmd_indices(gcmd_json):
    """Flatten a GCMD keyword hierarchy into per-topic keyword-path listings.

    The root's ``children`` are treated as topics. Beneath each topic, every
    node whose ``level`` contains ``"Variable"`` and every leaf node is
    recorded as a ``" > "``-joined path of node names. The path starts below
    the topic itself, so the topic name is not part of it.

    Args:
        gcmd_json: Parsed hierarchy dict. Each node may carry ``name``,
            ``level`` and ``children`` keys; missing or ``null`` values are
            treated as empty.

    Returns:
        A ``(topic_list, sub_tree_indices)`` tuple:

        - ``topic_list`` holds the upper-cased topic names in input order.
        - ``sub_tree_indices`` maps each upper-cased topic name to a
          newline-separated bullet list (one ``"- <path>"`` per line) of its
          collected paths.

        If topic names repeat, the last topic's listing wins in
        ``sub_tree_indices``.
    """
    topic_list = []
    sub_tree_indices = {}

    def collect_paths(node, current_path, out):
        # Depth-first walk; a node is recorded before its descendants.
        node_name = node.get("name") or ""
        node_path = f"{current_path} > {node_name}" if current_path else node_name
        children = node.get("children") or []
        # `or ""` / `or []` guard against explicit nulls in the JSON, which
        # would otherwise raise TypeError on the `in` test or the loop.
        if "Variable" in (node.get("level") or "") or not children:
            out.append(node_path)
        for child in children:
            collect_paths(child, node_path, out)

    for topic_node in gcmd_json.get("children") or []:
        topic_name = (topic_node.get("name") or "").upper()
        topic_list.append(topic_name)
        collected_paths = []
        for term_node in topic_node.get("children") or []:
            collect_paths(term_node, "", collected_paths)
        sub_tree_indices[topic_name] = "\n".join(f"- {path}" for path in collected_paths)

    return topic_list, sub_tree_indices
# Lookup tables derived once from the loaded hierarchy:
# - VALID_TOPICS: the upper-cased topic names, in order.
# - SUB_TREE_LOOKUP: maps each topic to its "- <path>" bullet listing.
VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_json=gcmd_data)
# ==========================================================
# 2. LANGGRAPH WORKFLOW WITH DUAL STRUCTURED OUTPUTS
# ==========================================================
def merge_lists(left: list, right: list) -> list:
    """Union two lists, dropping duplicates, for use as a state reducer.

    Used as the reducer on the keyword channels of ``MultiTopicState`` so
    that updates from several classifier branches are combined rather than
    overwritten.

    Order is first-seen: ``left`` items first, then new items from ``right``.
    Unlike ``list(set(...))``, this keeps the merged output deterministic
    across runs; set iteration order for strings depends on per-process hash
    randomization.

    Args:
        left: Current value, or ``None``.
        right: Incoming update, or ``None``.

    Returns:
        A new list containing each distinct item once.
    """
    return list(dict.fromkeys((left or []) + (right or [])))
class MultiTopicState(TypedDict):
    """Shared state passed between the router and classifier graph nodes."""

    # Inputs supplied when the graph is invoked.
    title: str
    abstract: str
    # Upper-cased topic names chosen by route_multi_topic, or ["FALLBACK"]
    # when none were recognised. Only that node writes this key, so it has
    # no reducer.
    chosen_topics: List[str]
    # Written by the per-topic classifier branches; several may run in the
    # same step, so merge_lists combines their updates instead of the last
    # write winning.
    predicted_keywords: Annotated[List[str], merge_lists]
    # Keywords the model returned that were not found in the topic's path
    # listing; merged the same way.
    invalid_keywords: Annotated[List[str], merge_lists]
# Structured-output schema for Step 1 (topic routing).
# It is passed to `llm.with_structured_output`, so its field description
# becomes part of the schema sent to the model. It deliberately has no class
# docstring: Pydantic would export a docstring into the schema's
# "description" too, which would change the prompt.
class TopicsChoice(BaseModel):
    # Topic names as returned by the model; route_multi_topic upper-cases,
    # trims, and filters them against VALID_TOPICS.
    topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
# Structured-output schema for Step 2 (per-topic keyword extraction).
# Structured output constrains the reply to a list of strings. It does not
# guarantee the strings are real paths, so classify_individual_topic
# validates each one against the topic's path listing afterwards.
# It deliberately has no class docstring: Pydantic would export a docstring
# into the schema sent to the model.
class KeywordExtraction(BaseModel):
    # Candidate keyword paths as returned by the model.
    keywords: List[str] = Field(description="List of exact matching keyword pathways from the provided text block. Return an empty list if nothing matches.")
def route_multi_topic(state: MultiTopicState):
    """Step 1: pick every high-level topic relevant to the article.

    Asks the model for a structured ``TopicsChoice``. Each returned name is
    upper-cased and trimmed; only names found in ``VALID_TOPICS`` are kept,
    in the model's order, with repeats dropped.

    Args:
        state: Graph state; only ``title`` and ``abstract`` are read.

    Returns:
        ``{"chosen_topics": [...]}`` with the validated topic names, or
        ``["FALLBACK"]`` when none survive validation.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    structured_llm = llm.with_structured_output(TopicsChoice)
    prompt = ChatPromptTemplate.from_messages([
        ("system", f"You are an expert science cataloger. Identify ALL relevant major topic areas that apply to this paper. Choose ONLY from: {', '.join(VALID_TOPICS)}"),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect all relevant Topics as a structured list."),
    ])
    # format_messages keeps the system and user turns as separate messages.
    # ChatPromptTemplate.format() would flatten both into one plain string,
    # which is then sent as a single user message.
    messages = prompt.format_messages(title=state["title"], abstract=state["abstract"])
    result = structured_llm.invoke(messages)
    # The structured-output parser may return None (e.g. when the model makes
    # no tool call); treat that as "no topics" instead of crashing.
    raw_topics = result.topics if result is not None else []
    normalized = (t.upper().strip() for t in raw_topics)
    # dict.fromkeys drops repeats while keeping order, so each classifier
    # branch is requested at most once.
    valid_selected = list(dict.fromkeys(t for t in normalized if t in VALID_TOPICS))
    if not valid_selected:
        valid_selected = ["FALLBACK"]
    return {"chosen_topics": valid_selected}
def _parse_sub_tree_paths(sub_tree: str) -> set:
    """Return the set of paths in a listing with one ``"- <path>"`` per line."""
    # removeprefix strips only the leading bullet. str.replace("- ", "")
    # would also delete every "- " inside a path and corrupt it.
    return {
        line.strip().removeprefix("- ").strip()
        for line in sub_tree.split("\n")
        if line.strip()
    }


def classify_individual_topic(topic_name: str):
    """Build the graph node that extracts keyword paths for one topic.

    The returned node prompts the model with the topic's path listing from
    ``SUB_TREE_LOOKUP`` and asks for the matching paths as a
    ``KeywordExtraction``. Every returned path is then checked against that
    same listing.

    Args:
        topic_name: Upper-cased topic name, used as a key of
            ``SUB_TREE_LOOKUP``. An unknown name gives an empty listing, so
            every answer is classed invalid.

    Returns:
        A callable ``node_runner(state)`` that returns
        ``{"predicted_keywords": [...], "invalid_keywords": [...]}``.
    """
    # The listing and its parsed path set depend only on topic_name, so build
    # them once when the node is created rather than on every invocation.
    target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
    valid_set = _parse_sub_tree_paths(target_sub_tree)

    def node_runner(state: MultiTopicState):
        # Ask the model for matching paths, then split its answers into
        # listed (valid) and unlisted (invalid) ones.
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
        structured_llm = llm.with_structured_output(KeywordExtraction)
        prompt = ChatPromptTemplate.from_messages([
            ("system", f"You are a specialist in {topic_name} data mapping. Extract all exact matching keyword pathways present in the provided valid path list. Do not modify or truncate the pathway strings."),
            ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nExtract matching pathways as a structured array list."),
        ])
        # format_messages keeps the system and user turns separate;
        # format() would collapse them into a single string.
        messages = prompt.format_messages(
            title=state["title"], abstract=state["abstract"], sub_tree=target_sub_tree
        )
        result = structured_llm.invoke(messages)
        # The structured-output parser may return None; treat it as no answers.
        returned = result.keywords if result is not None else []
        # The listing shows each path as "- <path>". Drop a bullet the model
        # copied so that an otherwise exact answer is not rejected.
        cleaned = (k.strip().removeprefix("- ").strip() for k in returned)
        raw_keywords = [k for k in cleaned if k]
        valid_kws = [kw for kw in raw_keywords if kw in valid_set]
        invalid_kws = [kw for kw in raw_keywords if kw not in valid_set]
        return {"predicted_keywords": valid_kws, "invalid_keywords": invalid_kws}

    return node_runner
def parallel_router(state: MultiTopicState) -> List[str]:
    """Map the chosen topics to the classifier node ids to fan out to.

    Topic ``T`` maps to ``"classify_" + T.lower()``, the same naming used when
    the classifier nodes are registered, so ``"FALLBACK"`` targets
    ``"classify_fallback"``. Repeated topics are collapsed, keeping the first
    occurrence, so no classifier branch is scheduled twice in one step.
    """
    node_ids = (f"classify_{topic.lower()}" for topic in state["chosen_topics"])
    return list(dict.fromkeys(node_ids))
# Assemble the graph: START -> top_router -> fan-out to one classifier node
# per topic (plus a fallback) -> END. Classifier ids follow the pattern
# "classify_<topic lower-cased>".
workflow = StateGraph(MultiTopicState)
workflow.add_node("top_router", route_multi_topic)

classifier_nodes = [(f"classify_{t.lower()}", t) for t in VALID_TOPICS]
for node_id, topic in classifier_nodes:
    workflow.add_node(node_id, classify_individual_topic(topic))
workflow.add_node("classify_fallback", lambda state: {"predicted_keywords": []})

workflow.add_edge(START, "top_router")
# The path map names every node the router is allowed to return.
route_targets = {node_id: node_id for node_id, _ in classifier_nodes}
workflow.add_conditional_edges(
    "top_router",
    parallel_router,
    route_targets | {"classify_fallback": "classify_fallback"},
)

for node_id, _ in classifier_nodes:
    workflow.add_edge(node_id, END)
workflow.add_edge("classify_fallback", END)

app = workflow.compile()
# ==========================================================
# 3. GRADIO USER INTERFACE
# ==========================================================
def run_agent_classifier(title, abstract):
    """Gradio handler: run the classification graph and format its output.

    Args:
        title: Article title from the UI.
        abstract: Abstract/body text from the UI.

    Returns:
        A 3-tuple of display strings ``(topics, keywords, invalid)``:

        - ``topics``: the chosen topics, comma-separated.
        - ``keywords``: one ``"• TOPIC > path"`` line per verified keyword,
          or ``"No explicit keywords mapped."`` when there are none.
        - ``invalid``: one warning line per rejected keyword, or a success
          message when there are none.

        If either input is missing or blank, a prompt message and two
        ``"N/A"`` placeholders are returned without invoking the graph.
    """
    # Reject whitespace-only input as well, so no model calls are spent on it.
    if not (title or "").strip() or not (abstract or "").strip():
        return "Please fill out both Title and Abstract fields.", "N/A", "N/A"
    output = app.invoke({"title": title, "abstract": abstract})
    chosen_topics = output.get("chosen_topics", [])
    predicted_kws = output.get("predicted_keywords", [])
    invalid_kws = output.get("invalid_keywords", [])
    topics_str = ", ".join(chosen_topics)

    # Parse each chosen topic's "- <path>" listing once into a set of paths.
    paths_by_topic = {
        topic: {
            line.strip().removeprefix("- ").strip()
            for line in SUB_TREE_LOOKUP.get(topic, "").split("\n")
        }
        for topic in chosen_topics
    }
    formatted_keywords = []
    for kw in predicted_kws:
        # Match whole paths, not substrings of the listing text, so a keyword
        # that merely occurs inside a longer path of another topic is not
        # credited to that topic. The first matching topic wins.
        matched_topic = next(
            (topic for topic in chosen_topics if kw in paths_by_topic[topic]),
            "UNKNOWN",
        )
        formatted_keywords.append(f"• {matched_topic} > {kw}")

    if not formatted_keywords:
        keywords_str = "No explicit keywords mapped."
    else:
        keywords_str = "\n".join(formatted_keywords)
    if not invalid_kws:
        invalid_str = "None! The agent validation pass achieved 100% data integrity."
    else:
        invalid_str = "\n".join([f"⚠ Caught & Removed: {ikw}" for ikw in invalid_kws])
    return topics_str, keywords_str, invalid_str
# Gradio front end. The two inputs map onto run_agent_classifier's
# (title, abstract) parameters, and the three outputs receive its
# (topics, keywords, invalid) return tuple in order.
_title_box = gr.Textbox(label="Journal Article Title", placeholder="Enter article title here...", lines=1)
_abstract_box = gr.Textbox(label="Abstract / Body Text", placeholder="Paste abstract description here...", lines=5)
_output_boxes = [
    gr.Textbox(label="Routed Multi-Topic Domains"),
    gr.Textbox(label="Verified GCMD Keywords Extracted (Formatted with Topics)"),
    gr.Textbox(label="Hallucinated/Invalid Keywords Caught and Removed"),
]
_example_rows = [
    [
        "El Niño Southern Oscillation Driving Anomalous Atmospheric Evaporation",
        "We observe how rising sea surface temperatures across equatorial waters directly trigger increased low-level cloud formation and accelerated surface winds.",
    ]
]

demo = gr.Interface(
    fn=run_agent_classifier,
    inputs=[_title_box, _abstract_box],
    outputs=_output_boxes,
    title="GCMD Science Keyword Classifier Agent",
    description="Proof of Concept using LangGraph and LangChain. Routes articles concurrently across science domains and runs isolated self-validation routines.",
    examples=_example_rows,
)

if __name__ == "__main__":
    demo.launch()