| import os |
| import json |
| import re |
| import gradio as gr |
| from typing import List, TypedDict, Annotated |
| from pydantic import BaseModel, Field |
| from langgraph.graph import StateGraph, START, END |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_openai import ChatOpenAI |
|
|
| |
| |
| |
# Load the GCMD keyword hierarchy once at import time. JSON text is UTF-8 by
# specification, so decode explicitly instead of relying on the platform's
# locale encoding (e.g. cp1252 on Windows), which can garble or reject
# non-ASCII keyword names.
with open("gcmd_hierarchy.json", encoding="utf-8") as f:
    gcmd_data = json.load(f)
|
|
def _collect_gcmd_paths(node, current_path, out):
    """Append the selectable paths at and below ``node`` to ``out`` (depth-first, pre-order)."""
    node_name = node.get("name") or ""
    node_path = f"{current_path} > {node_name}" if current_path else node_name
    # ``or []`` / ``or ""`` also cover explicit JSON nulls, which the
    # ``dict.get`` default argument does not.
    children = node.get("children") or []
    # Variable-level nodes are selectable even when they have children; any
    # leaf node is selectable regardless of its level.
    if "Variable" in (node.get("level") or "") or not children:
        out.append(node_path)
    for child in children:
        _collect_gcmd_paths(child, node_path, out)


def build_gcmd_indices(gcmd_json):
    """Build per-topic lookup tables from a GCMD keyword hierarchy.

    Each node of the hierarchy is a dict that may carry ``name``, ``level``
    and ``children`` keys; the root's direct children are treated as topics.
    Missing keys and explicit ``null`` values are both treated as empty.

    Args:
        gcmd_json: Root node of the hierarchy (a dict parsed from JSON).

    Returns:
        A ``(topic_list, sub_tree_indices)`` tuple:

        * ``topic_list`` -- upper-cased topic names in input order.
        * ``sub_tree_indices`` -- maps each upper-cased topic name to a
          newline-separated block of ``"- <path>"`` lines. Each path is the
          ``" > "``-joined chain of node names below the topic (the topic
          itself is not included). A line is emitted for every node whose
          ``level`` contains ``"Variable"`` and for every leaf, in
          depth-first pre-order.
    """
    topic_list = []
    sub_tree_indices = {}

    for topic_node in gcmd_json.get("children") or []:
        topic_name = (topic_node.get("name") or "").upper()
        topic_list.append(topic_name)

        collected_paths = []
        for term_node in topic_node.get("children") or []:
            _collect_gcmd_paths(term_node, "", collected_paths)

        sub_tree_indices[topic_name] = "\n".join(f"- {path}" for path in collected_paths)
    return topic_list, sub_tree_indices
|
|
# Upper-cased topic names plus, per topic, its "- path" block; computed once
# at import time from the loaded hierarchy.
(VALID_TOPICS, SUB_TREE_LOOKUP) = build_gcmd_indices(gcmd_json=gcmd_data)
|
|
| |
| |
| |
|
|
def merge_lists(left: list, right: list) -> list:
    """Reducer for list-valued state channels: concatenate, then de-duplicate.

    Either argument may be ``None`` and is then treated as empty. The first
    occurrence of each item is kept, so the result order is deterministic.
    A ``set`` round-trip would instead yield an order that varies between
    interpreter runs because of string hash randomization.

    Args:
        left: First list to merge (may be ``None``).
        right: Second list to merge (may be ``None``).

    Returns:
        A new list of the unique (hashable) items in first-seen order.
    """
    return list(dict.fromkeys([*(left or []), *(right or [])]))
|
|
class MultiTopicState(TypedDict):
    # Shared graph state read and written by the router and classifier nodes.
    title: str  # article title supplied by the caller
    abstract: str  # article abstract / body text supplied by the caller
    # Upper-cased topic names chosen by route_multi_topic, or ["FALLBACK"].
    # Written by a single node, so no reducer is attached.
    chosen_topics: List[str]
    # Several classify_* nodes can write these in the same fan-out, so
    # merge_lists is attached as the reducer to combine their partial lists.
    predicted_keywords: Annotated[List[str], merge_lists]  # paths that passed validation
    invalid_keywords: Annotated[List[str], merge_lists]  # returned paths not in the topic's list
|
|
| |
# Structured-output schema for the router step (route_multi_topic).
# NOTE(review): deliberately documented with comments, not a docstring --
# with_structured_output likely uses a model docstring as the schema
# description sent to the LLM, which would change the prompt. Confirm before
# adding one.
class TopicsChoice(BaseModel):
    # The Field description is part of the schema shown to the LLM.
    topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
|
|
| |
# Structured-output schema for each per-topic classifier node.
# NOTE(review): deliberately documented with comments, not a docstring --
# with_structured_output likely uses a model docstring as the schema
# description sent to the LLM, which would change the prompt. Confirm before
# adding one.
class KeywordExtraction(BaseModel):
    # The Field description is part of the schema shown to the LLM; the
    # returned strings are later checked against the topic's path list.
    keywords: List[str] = Field(description="List of exact matching keyword pathways from the provided text block. Return an empty list if nothing matches.")
|
|
def route_multi_topic(state: MultiTopicState):
    """Step 1: identify ALL relevant high-level GCMD topics for the article.

    Asks the LLM (structured output ``TopicsChoice``) to choose from
    ``VALID_TOPICS``. Only answers that match a known topic after
    upper-casing and trimming are kept. Repeats are dropped in first-seen
    order, so each classifier node is scheduled at most once.

    Args:
        state: Graph state; only ``title`` and ``abstract`` are read.

    Returns:
        ``{"chosen_topics": [...]}`` with the matched topic names, or
        ``["FALLBACK"]`` when nothing matched.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    structured_llm = llm.with_structured_output(TopicsChoice)

    # The topic list is a template variable rather than f-string text, so a
    # brace in a topic name can never be parsed as a template placeholder.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert science cataloger. Identify ALL relevant major topic areas that apply to this paper. Choose ONLY from: {topic_options}"),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect all relevant Topics as a structured list."),
    ])

    # format_messages keeps the system and user roles as separate messages;
    # format() would flatten both into one plain-text prompt string.
    messages = prompt.format_messages(
        topic_options=", ".join(VALID_TOPICS),
        title=state["title"],
        abstract=state["abstract"],
    )
    result = structured_llm.invoke(messages)

    allowed = set(VALID_TOPICS)
    valid_selected = []
    for raw_topic in result.topics:
        topic = raw_topic.upper().strip()
        if topic in allowed and topic not in valid_selected:
            valid_selected.append(topic)

    # "FALLBACK" routes to the no-op classify_fallback node.
    return {"chosen_topics": valid_selected or ["FALLBACK"]}
|
|
def classify_individual_topic(topic_name: str):
    """Build a graph node that extracts GCMD keyword paths for one topic.

    The returned node asks the LLM (structured output ``KeywordExtraction``)
    to copy matching paths out of the topic's ``SUB_TREE_LOOKUP`` block. It
    then validates every answer against that block: exact matches become
    ``predicted_keywords`` and anything else becomes ``invalid_keywords``.

    Args:
        topic_name: Upper-cased topic key into ``SUB_TREE_LOOKUP``.

    Returns:
        ``node_runner(state)``, which returns
        ``{"predicted_keywords": [...], "invalid_keywords": [...]}``.
    """
    target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")

    # The whitelist depends only on topic_name, so build it once per node
    # instead of on every invocation. Only the leading "- " bullet is removed:
    # str.replace("- ", "") would also delete any "- " inside a path name,
    # making that path impossible to validate.
    valid_set = {
        line.removeprefix("- ").strip()
        for line in target_sub_tree.split("\n")
        if line.strip()
    }

    prompt = ChatPromptTemplate.from_messages([
        ("system", f"You are a specialist in {topic_name} data mapping. Extract all exact matching keyword pathways present in the provided valid path list. Do not modify or truncate the pathway strings."),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nExtract matching pathways as a structured array list."),
    ])

    def node_runner(state: MultiTopicState):
        llm = ChatOpenAI(model="gpt-4o", temperature=0)
        structured_llm = llm.with_structured_output(KeywordExtraction)

        # format_messages keeps the system and user roles as separate
        # messages; format() would flatten both into one prompt string.
        messages = prompt.format_messages(
            title=state["title"],
            abstract=state["abstract"],
            sub_tree=target_sub_tree,
        )
        result = structured_llm.invoke(messages)

        # The model may copy the "- " bullet shown in the prompt along with
        # the path; drop it so an otherwise exact copy still validates.
        cleaned = (kw.strip().removeprefix("- ").strip() for kw in result.keywords)
        raw_keywords = [kw for kw in cleaned if kw]

        valid_kws = [kw for kw in raw_keywords if kw in valid_set]
        invalid_kws = [kw for kw in raw_keywords if kw not in valid_set]

        return {"predicted_keywords": valid_kws, "invalid_keywords": invalid_kws}

    return node_runner
|
|
def parallel_router(state: MultiTopicState) -> List[str]:
    """Map the router's chosen topics to classifier node ids for fan-out.

    Each topic becomes ``"classify_<topic lower-cased>"``, the same naming
    used when the nodes are registered. The pseudo-topic ``"FALLBACK"`` maps
    to ``"classify_fallback"``. Repeated ids are collapsed in first-seen
    order, so no node is targeted twice in the same step.

    Args:
        state: Graph state; only ``chosen_topics`` is read.

    Returns:
        The node ids to run next.
    """
    node_ids = (f"classify_{topic.lower()}" for topic in state["chosen_topics"])
    return list(dict.fromkeys(node_ids))
|
|
| |
# --- Graph assembly (runs at import time) ---
# Shape: START -> top_router -> one or more classify_* nodes -> END.
workflow = StateGraph(MultiTopicState)
workflow.add_node("top_router", route_multi_topic)

# One classifier node per known topic. Ids must match parallel_router's
# "classify_<topic.lower()>" naming.
for topic in VALID_TOPICS:
    node_id = f"classify_{topic.lower()}"
    workflow.add_node(node_id, classify_individual_topic(topic))

# Target for the "FALLBACK" pseudo-topic; it contributes no keywords.
workflow.add_node("classify_fallback", lambda state: {"predicted_keywords": []})
workflow.add_edge(START, "top_router")

# parallel_router returns a list of node ids. This path map declares every id
# it may return: all topic classifiers plus the fallback. (Dict union `|`
# requires Python 3.9+.)
workflow.add_conditional_edges(
    "top_router",
    parallel_router,
    {f"classify_{t.lower()}": f"classify_{t.lower()}" for t in VALID_TOPICS} | {"classify_fallback": "classify_fallback"}
)

# Every classifier, including the fallback, ends the run.
for topic in VALID_TOPICS:
    workflow.add_edge(f"classify_{topic.lower()}", END)
workflow.add_edge("classify_fallback", END)

# Compiled graph invoked by run_agent_classifier.
app = workflow.compile()
|
|
| |
| |
| |
def _sub_tree_paths(topic):
    """Return the set of bare paths listed in ``SUB_TREE_LOOKUP`` for ``topic``."""
    return {
        line.removeprefix("- ").strip()
        for line in SUB_TREE_LOOKUP.get(topic, "").split("\n")
        if line.strip()
    }


def run_agent_classifier(title, abstract):
    """Gradio callback: classify an article and format the three output boxes.

    Args:
        title: Article title text.
        abstract: Abstract / body text.

    Returns:
        A ``(topics_str, keywords_str, invalid_str)`` tuple of display
        strings. If either input is empty, ``None`` or whitespace-only, the
        first element asks the user to fill both fields and the other two
        are ``"N/A"``.
    """
    # Whitespace-only input is as unusable as empty input; reject it before
    # any LLM calls are made.
    if not (title and title.strip()) or not (abstract and abstract.strip()):
        return "Please fill out both Title and Abstract fields.", "N/A", "N/A"

    inputs = {"title": title, "abstract": abstract}
    output = app.invoke(inputs)

    chosen_topics = output.get("chosen_topics", [])
    predicted_kws = output.get("predicted_keywords", [])
    invalid_kws = output.get("invalid_keywords", [])

    topics_str = ", ".join(chosen_topics)

    # Exact path sets per chosen topic, built once. A keyword is credited to
    # the first chosen topic that lists it verbatim. A substring test on the
    # raw text block could credit a topic that merely contains the keyword
    # inside a longer path.
    topic_paths = [(topic, _sub_tree_paths(topic)) for topic in chosen_topics]

    formatted_keywords = []
    for kw in predicted_kws:
        matched_topic = next(
            (topic for topic, paths in topic_paths if kw in paths),
            "UNKNOWN",
        )
        formatted_keywords.append(f"• {matched_topic} > {kw}")

    if not formatted_keywords:
        keywords_str = "No explicit keywords mapped."
    else:
        keywords_str = "\n".join(formatted_keywords)

    if not invalid_kws:
        invalid_str = "None! The agent validation pass achieved 100% data integrity."
    else:
        invalid_str = "\n".join([f"⚠ Caught & Removed: {ikw}" for ikw in invalid_kws])

    return topics_str, keywords_str, invalid_str
|
|
# Gradio front end: the two text inputs are passed positionally to
# run_agent_classifier (title, abstract), and its three-string return value
# fills the three output boxes in order.
demo = gr.Interface(
    fn=run_agent_classifier,
    inputs=[
        gr.Textbox(label="Journal Article Title", placeholder="Enter article title here...", lines=1),
        gr.Textbox(label="Abstract / Body Text", placeholder="Paste abstract description here...", lines=5)
    ],
    # Order must match run_agent_classifier's (topics, keywords, invalid) tuple.
    outputs=[
        gr.Textbox(label="Routed Multi-Topic Domains"),
        gr.Textbox(label="Verified GCMD Keywords Extracted (Formatted with Topics)"),
        gr.Textbox(label="Hallucinated/Invalid Keywords Caught and Removed")
    ],
    title="GCMD Science Keyword Classifier Agent",
    description="Proof of Concept using LangGraph and LangChain. Routes articles concurrently across science domains and runs isolated self-validation routines.",
    # Each example row supplies [title, abstract], matching the inputs list.
    examples=[
        ["El Niño Southern Oscillation Driving Anomalous Atmospheric Evaporation", "We observe how rising sea surface temperatures across equatorial waters directly trigger increased low-level cloud formation and accelerated surface winds."]
    ]
)
|
|
# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|
|
|