Commit ·
1885bb3
1
Parent(s): 32c52f2
fix: enforce 'NONE' responses to stop conversational pollution in invalid keywords
Browse files
app.py
CHANGED
|
@@ -9,59 +9,42 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
| 9 |
from langchain_openai import ChatOpenAI
|
| 10 |
|
| 11 |
# ==========================================================
|
| 12 |
-
# 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
|
| 13 |
# ==========================================================
|
| 14 |
with open("gcmd_hierarchy.json", "r") as f:
|
| 15 |
gcmd_data = json.load(f)
|
| 16 |
|
| 17 |
def build_gcmd_indices(gcmd_json):
|
| 18 |
-
"""
|
| 19 |
-
Parses the GCMD hierarchy exactly like the notebook but strips
|
| 20 |
-
heavy definitions to prevent 429 Token Rate Limit errors.
|
| 21 |
-
"""
|
| 22 |
topic_list = []
|
| 23 |
sub_tree_indices = {}
|
| 24 |
-
|
| 25 |
-
# The root node level is "Category" (EARTH SCIENCE)
|
| 26 |
topics = gcmd_json.get("children", [])
|
| 27 |
|
| 28 |
for topic_node in topics:
|
| 29 |
topic_name = topic_node.get("name", "").upper()
|
| 30 |
topic_list.append(topic_name)
|
| 31 |
-
|
| 32 |
-
# Collect paths inside this specific topic
|
| 33 |
collected_paths = []
|
| 34 |
|
| 35 |
def recurse_sub_tree(node, current_path=""):
|
| 36 |
node_name = node.get("name", "")
|
| 37 |
-
# Append current node name to path
|
| 38 |
node_path = f"{current_path} > {node_name}" if current_path else node_name
|
| 39 |
-
|
| 40 |
-
# If it's a Variable or leaf term, save the path only
|
| 41 |
if "Variable" in node.get("level", "") or not node.get("children"):
|
| 42 |
collected_paths.append(node_path)
|
| 43 |
-
|
| 44 |
-
# Dig deeper into children
|
| 45 |
for child in node.get("children", []):
|
| 46 |
recurse_sub_tree(child, node_path)
|
| 47 |
|
| 48 |
-
# Start the recursion from the children of this Topic (the 'Terms')
|
| 49 |
for term_node in topic_node.get("children", []):
|
| 50 |
recurse_sub_tree(term_node, current_path="")
|
| 51 |
|
| 52 |
-
# Format the collected paths into a clean lookup layout for the prompt
|
| 53 |
sub_tree_indices[topic_name] = "\n".join([f"- {path}" for path in collected_paths])
|
| 54 |
-
|
| 55 |
return topic_list, sub_tree_indices
|
| 56 |
|
| 57 |
VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)
|
| 58 |
|
| 59 |
# ==========================================================
|
| 60 |
-
# 2.
|
| 61 |
# ==========================================================
|
| 62 |
|
| 63 |
def merge_lists(left: list, right: list) -> list:
|
| 64 |
-
"""A reducer function that merges list contents across parallel branches."""
|
| 65 |
return list(set((left or []) + (right or [])))
|
| 66 |
|
| 67 |
class MultiTopicState(TypedDict):
|
|
@@ -75,13 +58,17 @@ class TopicsChoice(BaseModel):
|
|
| 75 |
topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
|
| 76 |
|
| 77 |
def route_multi_topic(state: MultiTopicState):
|
| 78 |
-
"""Step 1: Identify
|
| 79 |
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
| 80 |
structured_llm = llm.with_structured_output(TopicsChoice)
|
| 81 |
|
| 82 |
prompt = ChatPromptTemplate.from_messages([
|
| 83 |
-
("system",
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
])
|
| 86 |
|
| 87 |
result = structured_llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"]))
|
|
@@ -93,21 +80,30 @@ def route_multi_topic(state: MultiTopicState):
|
|
| 93 |
return {"chosen_topics": valid_selected}
|
| 94 |
|
| 95 |
def classify_individual_topic(topic_name: str):
|
| 96 |
-
"""A dynamic factory function
|
| 97 |
|
| 98 |
def node_runner(state: MultiTopicState):
|
| 99 |
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
| 100 |
target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
|
| 101 |
|
| 102 |
prompt = ChatPromptTemplate.from_messages([
|
| 103 |
-
("system",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nReturn exact matching entries as a comma-separated list.")
|
| 105 |
])
|
| 106 |
|
| 107 |
response = llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"], sub_tree=target_sub_tree))
|
| 108 |
-
raw_keywords = [k.strip() for k in response.content.split(",") if k.strip()]
|
| 109 |
|
| 110 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
valid_set = set()
|
| 112 |
for line in target_sub_tree.split("\n"):
|
| 113 |
if line.strip():
|
|
@@ -122,7 +118,6 @@ def classify_individual_topic(topic_name: str):
|
|
| 122 |
return node_runner
|
| 123 |
|
| 124 |
def parallel_router(state: MultiTopicState) -> List[str]:
|
| 125 |
-
"""Tells LangGraph to trigger multiple sub-nodes simultaneously based on state."""
|
| 126 |
return [f"classify_{topic.lower()}" for topic in state["chosen_topics"]]
|
| 127 |
|
| 128 |
# Assemble the Graph
|
|
@@ -149,15 +144,13 @@ workflow.add_edge("classify_fallback", END)
|
|
| 149 |
app = workflow.compile()
|
| 150 |
|
| 151 |
# ==========================================================
|
| 152 |
-
# 3. GRADIO USER INTERFACE
|
| 153 |
# ==========================================================
|
| 154 |
def run_agent_classifier(title, abstract):
|
| 155 |
if not title or not abstract:
|
| 156 |
return "Please fill out both Title and Abstract fields.", "N/A", "N/A"
|
| 157 |
|
| 158 |
inputs = {"title": title, "abstract": abstract}
|
| 159 |
-
|
| 160 |
-
# Execute graph synchronously
|
| 161 |
output = app.invoke(inputs)
|
| 162 |
|
| 163 |
chosen_topics = output.get("chosen_topics", [])
|
|
@@ -166,7 +159,6 @@ def run_agent_classifier(title, abstract):
|
|
| 166 |
|
| 167 |
topics_str = ", ".join(chosen_topics)
|
| 168 |
|
| 169 |
-
# Format output keywords to prepend their corresponding parent Topic
|
| 170 |
formatted_keywords = []
|
| 171 |
for kw in predicted_kws:
|
| 172 |
matched_topic = "UNKNOWN"
|
|
@@ -182,7 +174,6 @@ def run_agent_classifier(title, abstract):
|
|
| 182 |
else:
|
| 183 |
keywords_str = "\n".join(formatted_keywords)
|
| 184 |
|
| 185 |
-
# Format separate display for caught and filtered hallucinations
|
| 186 |
if not invalid_kws:
|
| 187 |
invalid_str = "None! The agent validation pass achieved 100% data integrity."
|
| 188 |
else:
|
|
|
|
| 9 |
from langchain_openai import ChatOpenAI
|
| 10 |
|
| 11 |
# ==========================================================
|
| 12 |
+
# 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
|
| 13 |
# ==========================================================
|
| 14 |
with open("gcmd_hierarchy.json", "r") as f:
|
| 15 |
gcmd_data = json.load(f)
|
| 16 |
|
| 17 |
def build_gcmd_indices(gcmd_json):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
topic_list = []
|
| 19 |
sub_tree_indices = {}
|
|
|
|
|
|
|
| 20 |
topics = gcmd_json.get("children", [])
|
| 21 |
|
| 22 |
for topic_node in topics:
|
| 23 |
topic_name = topic_node.get("name", "").upper()
|
| 24 |
topic_list.append(topic_name)
|
|
|
|
|
|
|
| 25 |
collected_paths = []
|
| 26 |
|
| 27 |
def recurse_sub_tree(node, current_path=""):
|
| 28 |
node_name = node.get("name", "")
|
|
|
|
| 29 |
node_path = f"{current_path} > {node_name}" if current_path else node_name
|
|
|
|
|
|
|
| 30 |
if "Variable" in node.get("level", "") or not node.get("children"):
|
| 31 |
collected_paths.append(node_path)
|
|
|
|
|
|
|
| 32 |
for child in node.get("children", []):
|
| 33 |
recurse_sub_tree(child, node_path)
|
| 34 |
|
|
|
|
| 35 |
for term_node in topic_node.get("children", []):
|
| 36 |
recurse_sub_tree(term_node, current_path="")
|
| 37 |
|
|
|
|
| 38 |
sub_tree_indices[topic_name] = "\n".join([f"- {path}" for path in collected_paths])
|
|
|
|
| 39 |
return topic_list, sub_tree_indices
|
| 40 |
|
| 41 |
VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)
|
| 42 |
|
| 43 |
# ==========================================================
|
| 44 |
+
# 2. UPDATED LANGGRAPH WORKFLOW WITH NULL-ENFORCEMENT
|
| 45 |
# ==========================================================
|
| 46 |
|
| 47 |
def merge_lists(left: list, right: list) -> list:
|
|
|
|
| 48 |
return list(set((left or []) + (right or [])))
|
| 49 |
|
| 50 |
class MultiTopicState(TypedDict):
|
|
|
|
| 58 |
topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
|
| 59 |
|
| 60 |
def route_multi_topic(state: MultiTopicState):
|
| 61 |
+
"""Step 1: Identify strictly relevant high-level topics (Tightened to prevent over-routing)."""
|
| 62 |
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
| 63 |
structured_llm = llm.with_structured_output(TopicsChoice)
|
| 64 |
|
| 65 |
prompt = ChatPromptTemplate.from_messages([
|
| 66 |
+
("system", (
|
| 67 |
+
f"You are an expert science cataloger. Identify only the core major topic areas that "
|
| 68 |
+
f"DIRECTLY and primary apply to this paper. Do NOT select topics that are only tangentially "
|
| 69 |
+
f"referenced or inferred. Choose ONLY from: {', '.join(VALID_TOPICS)}"
|
| 70 |
+
)),
|
| 71 |
+
("user", "Title: {title}\nAbstract: {abstract}\n\nSelect the highly relevant Topics as a structured list.")
|
| 72 |
])
|
| 73 |
|
| 74 |
result = structured_llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"]))
|
|
|
|
| 80 |
return {"chosen_topics": valid_selected}
|
| 81 |
|
| 82 |
def classify_individual_topic(topic_name: str):
|
| 83 |
+
"""A dynamic factory function with strict enforcement against conversational prose."""
|
| 84 |
|
| 85 |
def node_runner(state: MultiTopicState):
|
| 86 |
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
| 87 |
target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
|
| 88 |
|
| 89 |
prompt = ChatPromptTemplate.from_messages([
|
| 90 |
+
("system", (
|
| 91 |
+
f"You are a specialist in {topic_name} data mapping. Extract exact keyword pathways present in the provided list.\n"
|
| 92 |
+
f"CRITICAL RULES:\n"
|
| 93 |
+
f"1. If absolutely no keywords from the valid path list match this paper, reply with exactly the word 'NONE'.\n"
|
| 94 |
+
f"2. Do NOT write sentences, do NOT explain your reasoning, and do NOT say 'there are no matching entries'. Your output must only be a comma-separated list of keywords, or the single token 'NONE'."
|
| 95 |
+
)),
|
| 96 |
("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nReturn exact matching entries as a comma-separated list.")
|
| 97 |
])
|
| 98 |
|
| 99 |
response = llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"], sub_tree=target_sub_tree))
|
|
|
|
| 100 |
|
| 101 |
+
# Split and clean tokens, explicitly ignoring any 'NONE' or empty responses
|
| 102 |
+
raw_keywords = [
|
| 103 |
+
k.strip() for k in response.content.split(",")
|
| 104 |
+
if k.strip() and k.strip().upper() != "NONE"
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
valid_set = set()
|
| 108 |
for line in target_sub_tree.split("\n"):
|
| 109 |
if line.strip():
|
|
|
|
| 118 |
return node_runner
|
| 119 |
|
| 120 |
def parallel_router(state: MultiTopicState) -> List[str]:
|
|
|
|
| 121 |
return [f"classify_{topic.lower()}" for topic in state["chosen_topics"]]
|
| 122 |
|
| 123 |
# Assemble the Graph
|
|
|
|
| 144 |
app = workflow.compile()
|
| 145 |
|
| 146 |
# ==========================================================
|
| 147 |
+
# 3. GRADIO USER INTERFACE
|
| 148 |
# ==========================================================
|
| 149 |
def run_agent_classifier(title, abstract):
|
| 150 |
if not title or not abstract:
|
| 151 |
return "Please fill out both Title and Abstract fields.", "N/A", "N/A"
|
| 152 |
|
| 153 |
inputs = {"title": title, "abstract": abstract}
|
|
|
|
|
|
|
| 154 |
output = app.invoke(inputs)
|
| 155 |
|
| 156 |
chosen_topics = output.get("chosen_topics", [])
|
|
|
|
| 159 |
|
| 160 |
topics_str = ", ".join(chosen_topics)
|
| 161 |
|
|
|
|
| 162 |
formatted_keywords = []
|
| 163 |
for kw in predicted_kws:
|
| 164 |
matched_topic = "UNKNOWN"
|
|
|
|
| 174 |
else:
|
| 175 |
keywords_str = "\n".join(formatted_keywords)
|
| 176 |
|
|
|
|
| 177 |
if not invalid_kws:
|
| 178 |
invalid_str = "None! The agent validation pass achieved 100% data integrity."
|
| 179 |
else:
|