import os
import json
import re
import gradio as gr
from typing import List, TypedDict, Annotated
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# ==========================================================
# 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
# ==========================================================
with open("gcmd_hierarchy.json", "r", encoding="utf-8") as f:
    gcmd_data = json.load(f)


def _collect_variable_paths(node, current_path, out):
    """Append the " > "-joined path of every selectable node under *node* to *out*.

    A node is selectable when its ``level`` contains "Variable" or when it has
    no children (a leaf). Selectable nodes that do have children are still
    descended into, so deeper variable levels are collected as well.
    """
    node_name = node.get("name") or ""
    node_path = f"{current_path} > {node_name}" if current_path else node_name
    # `or` guards against explicit JSON nulls for these optional fields.
    children = node.get("children") or []
    if "Variable" in (node.get("level") or "") or not children:
        out.append(node_path)
    for child in children:
        _collect_variable_paths(child, node_path, out)


def build_gcmd_indices(gcmd_json):
    """Build the topic list and per-topic keyword-path listings from GCMD JSON.

    Args:
        gcmd_json: Root of the GCMD hierarchy; its ``children`` are topic
            nodes, each with a ``name`` and optionally ``children``.

    Returns:
        A tuple ``(topic_list, sub_tree_indices)``:
        ``topic_list`` holds the upper-cased, non-empty topic names in source
        order without duplicates; ``sub_tree_indices`` maps each topic name to
        a newline-separated listing of "- PATH" lines. Paths are relative to
        the topic (they do not include the topic name itself).
    """
    topic_list = []
    paths_by_topic = {}
    for topic_node in gcmd_json.get("children") or []:
        topic_name = (topic_node.get("name") or "").upper()
        if not topic_name:
            # A nameless topic cannot be selected meaningfully and would
            # produce a "classify_" graph node.
            continue
        if topic_name not in paths_by_topic:
            # Repeated topic names are merged: listing them twice would make
            # the graph builder register the same node id twice.
            topic_list.append(topic_name)
            paths_by_topic[topic_name] = []
        for term_node in topic_node.get("children") or []:
            _collect_variable_paths(term_node, "", paths_by_topic[topic_name])

    sub_tree_indices = {
        name: "\n".join(f"- {path}" for path in paths)
        for name, paths in paths_by_topic.items()
    }
    return topic_list, sub_tree_indices


VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)
# ==========================================================
# 2. LANGGRAPH WORKFLOW WITH DUAL STRUCTURED OUTPUTS
# ==========================================================
def merge_lists(left: list, right: list) -> list:
    """State reducer: concatenate two lists and drop duplicates.

    First occurrence wins and order is preserved, so results are reproducible
    (a ``set`` round-trip would reorder items arbitrarily between runs).
    ``None`` on either side is treated as an empty list.
    """
    return list(dict.fromkeys((left or []) + (right or [])))


class MultiTopicState(TypedDict):
    """Graph state shared by the topic router and the per-topic classifiers."""

    title: str
    abstract: str
    chosen_topics: List[str]
    # Written by several parallel classifier branches, so these lists are
    # combined with merge_lists instead of last-write-wins.
    predicted_keywords: Annotated[List[str], merge_lists]
    invalid_keywords: Annotated[List[str], merge_lists]


# Pydantic schema for Step 1 (no class docstring on purpose: it would become
# part of the schema description sent to the model).
class TopicsChoice(BaseModel):
    topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")


# Pydantic schema for Step 2: structured output keeps free-form prose out of
# the reply while still allowing any number of keyword paths.
class KeywordExtraction(BaseModel):
    keywords: List[str] = Field(description="List of exact matching keyword pathways from the provided text block. Return an empty list if nothing matches.")


def route_multi_topic(state: MultiTopicState):
    """Step 1: identify ALL relevant high-level topics for the paper.

    Returns a partial state update ``{"chosen_topics": [...]}`` holding only
    names present in VALID_TOPICS (upper-cased, de-duplicated, in the model's
    order). If none survive validation, ``["FALLBACK"]`` is returned so the
    router sends the run to the no-op fallback node.
    """
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    structured_llm = llm.with_structured_output(TopicsChoice)
    # Topics are a template variable rather than f-string text so that a brace
    # in a topic name cannot be misparsed as a template placeholder.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert science cataloger. Identify ALL relevant major topic areas that apply to this paper. Choose ONLY from: {allowed_topics}"),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect all relevant Topics as a structured list."),
    ])
    # Piping the template keeps distinct system/user messages; prompt.format()
    # would flatten both into a single plain-string human message.
    result = (prompt | structured_llm).invoke({
        "allowed_topics": ", ".join(VALID_TOPICS),
        "title": state["title"],
        "abstract": state["abstract"],
    })

    allowed = set(VALID_TOPICS)
    candidates = (t.upper().strip() for t in (result.topics if result else []))
    valid_selected = list(dict.fromkeys(t for t in candidates if t in allowed))
    return {"chosen_topics": valid_selected or ["FALLBACK"]}


def _parse_sub_tree_paths(sub_tree: str) -> set:
    """Return the bare paths from a "- PATH" listing held in SUB_TREE_LOOKUP."""
    # removeprefix strips only the leading bullet; str.replace("- ", "") would
    # also delete "- " appearing inside a keyword name and corrupt the path.
    return {
        line.strip().removeprefix("- ").strip()
        for line in sub_tree.splitlines()
        if line.strip()
    }


def classify_individual_topic(topic_name: str):
    """Build the LangGraph node that extracts keyword paths for one topic.

    The returned ``node_runner(state)`` asks the model to pick exact paths from
    this topic's listing, then splits the answer into paths that exist in the
    listing (``predicted_keywords``) and ones that do not
    (``invalid_keywords``).
    """
    target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
    # Built once per topic at graph-construction time, not on every call.
    valid_set = _parse_sub_tree_paths(target_sub_tree)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a specialist in {topic_name} data mapping. Extract all exact matching keyword pathways present in the provided valid path list. Do not modify or truncate the pathway strings."),
        ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nExtract matching pathways as a structured array list."),
    ])

    def node_runner(state: MultiTopicState):
        if not valid_set:
            # Nothing could ever validate for this topic, so skip the model call.
            return {"predicted_keywords": [], "invalid_keywords": []}

        llm = ChatOpenAI(model="gpt-4o", temperature=0)
        structured_llm = llm.with_structured_output(KeywordExtraction)
        result = (prompt | structured_llm).invoke({
            "topic_name": topic_name,
            "title": state["title"],
            "abstract": state["abstract"],
            "sub_tree": target_sub_tree,
        })

        raw_keywords = []
        for kw in (result.keywords if result else []):
            # The listing shows each path as "- PATH"; tolerate a copied bullet.
            cleaned = kw.strip().removeprefix("- ").strip()
            if cleaned:
                raw_keywords.append(cleaned)

        # Validation pass: only paths literally present in the listing survive.
        valid_kws = [kw for kw in raw_keywords if kw in valid_set]
        invalid_kws = [kw for kw in raw_keywords if kw not in valid_set]
        return {"predicted_keywords": valid_kws, "invalid_keywords": invalid_kws}

    return node_runner


def parallel_router(state: MultiTopicState) -> List[str]:
    """Map each chosen topic to its classifier node id, without duplicates."""
    return list(dict.fromkeys(f"classify_{topic.lower()}" for topic in state["chosen_topics"]))


# Assemble the Graph
workflow = StateGraph(MultiTopicState)
workflow.add_node("top_router", route_multi_topic)
for topic in VALID_TOPICS:
    node_id = f"classify_{topic.lower()}"
    workflow.add_node(node_id, classify_individual_topic(topic))
workflow.add_node("classify_fallback", lambda state: {"predicted_keywords": []})

workflow.add_edge(START, "top_router")
workflow.add_conditional_edges(
    "top_router",
    parallel_router,
    {f"classify_{t.lower()}": f"classify_{t.lower()}" for t in VALID_TOPICS} | {"classify_fallback": "classify_fallback"},
)
for topic in VALID_TOPICS:
    workflow.add_edge(f"classify_{topic.lower()}", END)
workflow.add_edge("classify_fallback", END)

app = workflow.compile()
# ==========================================================
# 3. GRADIO USER INTERFACE
# ==========================================================
def _topic_path_set(topic):
    """Return the exact keyword paths listed under *topic* in SUB_TREE_LOOKUP."""
    return {
        line.strip().removeprefix("- ").strip()
        for line in SUB_TREE_LOOKUP.get(topic, "").splitlines()
        if line.strip()
    }


def run_agent_classifier(title, abstract):
    """Gradio callback: run the classifier graph and format its results.

    Returns three strings, one per output box: the chosen topics, the verified
    keywords (each prefixed with the topic whose listing contains it, or
    "UNKNOWN"), and the keywords rejected by the validation pass.
    """
    # Whitespace-only input is treated the same as an empty field.
    if not (title or "").strip() or not (abstract or "").strip():
        return "Please fill out both Title and Abstract fields.", "N/A", "N/A"

    output = app.invoke({"title": title, "abstract": abstract})

    chosen_topics = output.get("chosen_topics", [])
    predicted_kws = output.get("predicted_keywords", [])
    invalid_kws = output.get("invalid_keywords", [])

    topics_str = ", ".join(chosen_topics)

    # Exact path membership rather than substring search over the raw listing:
    # a short path like "A > B" is a substring of "A > B > C", which could
    # credit a keyword to the wrong topic.
    paths_by_topic = {topic: _topic_path_set(topic) for topic in chosen_topics}
    formatted_keywords = []
    for kw in predicted_kws:
        matched_topic = next(
            (topic for topic, paths in paths_by_topic.items() if kw in paths),
            "UNKNOWN",
        )
        formatted_keywords.append(f"• {matched_topic} > {kw}")

    if not formatted_keywords:
        keywords_str = "No explicit keywords mapped."
    else:
        # Parallel branches finish in arbitrary order; sort for stable output.
        keywords_str = "\n".join(sorted(formatted_keywords))

    if not invalid_kws:
        invalid_str = "None! The agent validation pass achieved 100% data integrity."
    else:
        invalid_str = "\n".join([f"⚠ Caught & Removed: {ikw}" for ikw in invalid_kws])

    return topics_str, keywords_str, invalid_str


demo = gr.Interface(
    fn=run_agent_classifier,
    inputs=[
        gr.Textbox(label="Journal Article Title", placeholder="Enter article title here...", lines=1),
        gr.Textbox(label="Abstract / Body Text", placeholder="Paste abstract description here...", lines=5),
    ],
    outputs=[
        gr.Textbox(label="Routed Multi-Topic Domains"),
        gr.Textbox(label="Verified GCMD Keywords Extracted (Formatted with Topics)"),
        gr.Textbox(label="Hallucinated/Invalid Keywords Caught and Removed"),
    ],
    title="GCMD Science Keyword Classifier Agent",
    description="Proof of Concept using LangGraph and LangChain. Routes articles concurrently across science domains and runs isolated self-validation routines.",
    examples=[
        ["El Niño Southern Oscillation Driving Anomalous Atmospheric Evaporation", "We observe how rising sea surface temperatures across equatorial waters directly trigger increased low-level cloud formation and accelerated surface winds."]
    ],
)

if __name__ == "__main__":
    demo.launch()