igerasimov commited on
Commit
1885bb3
·
1 Parent(s): 32c52f2

fix: enforce 'NONE' responses to stop conversational pollution in invalid keywords

Browse files
Files changed (1) hide show
  1. app.py +23 -32
app.py CHANGED
@@ -9,59 +9,42 @@ from langchain_core.prompts import ChatPromptTemplate
9
  from langchain_openai import ChatOpenAI
10
 
11
  # ==========================================================
12
- # 1. LOAD DATA & OPTIMIZED INDEX GENERATOR (REDUCES TOKENS)
13
  # ==========================================================
14
  with open("gcmd_hierarchy.json", "r") as f:
15
  gcmd_data = json.load(f)
16
 
17
  def build_gcmd_indices(gcmd_json):
18
- """
19
- Parses the GCMD hierarchy exactly like the notebook but strips
20
- heavy definitions to prevent 429 Token Rate Limit errors.
21
- """
22
  topic_list = []
23
  sub_tree_indices = {}
24
-
25
- # The root node level is "Category" (EARTH SCIENCE)
26
  topics = gcmd_json.get("children", [])
27
 
28
  for topic_node in topics:
29
  topic_name = topic_node.get("name", "").upper()
30
  topic_list.append(topic_name)
31
-
32
- # Collect paths inside this specific topic
33
  collected_paths = []
34
 
35
  def recurse_sub_tree(node, current_path=""):
36
  node_name = node.get("name", "")
37
- # Append current node name to path
38
  node_path = f"{current_path} > {node_name}" if current_path else node_name
39
-
40
- # If it's a Variable or leaf term, save the path only
41
  if "Variable" in node.get("level", "") or not node.get("children"):
42
  collected_paths.append(node_path)
43
-
44
- # Dig deeper into children
45
  for child in node.get("children", []):
46
  recurse_sub_tree(child, node_path)
47
 
48
- # Start the recursion from the children of this Topic (the 'Terms')
49
  for term_node in topic_node.get("children", []):
50
  recurse_sub_tree(term_node, current_path="")
51
 
52
- # Format the collected paths into a clean lookup layout for the prompt
53
  sub_tree_indices[topic_name] = "\n".join([f"- {path}" for path in collected_paths])
54
-
55
  return topic_list, sub_tree_indices
56
 
57
  VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)
58
 
59
  # ==========================================================
60
- # 2. VERBATIM WORKFLOW FROM YOUR NOTEBOOK
61
  # ==========================================================
62
 
63
  def merge_lists(left: list, right: list) -> list:
64
- """A reducer function that merges list contents across parallel branches."""
65
  return list(set((left or []) + (right or [])))
66
 
67
  class MultiTopicState(TypedDict):
@@ -75,13 +58,17 @@ class TopicsChoice(BaseModel):
75
  topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
76
 
77
  def route_multi_topic(state: MultiTopicState):
78
- """Step 1: Identify ALL relevant high-level topics for the paper."""
79
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
80
  structured_llm = llm.with_structured_output(TopicsChoice)
81
 
82
  prompt = ChatPromptTemplate.from_messages([
83
- ("system", f"You are an expert science cataloger. Identify ALL relevant major topic areas that apply to this paper. Choose ONLY from: {', '.join(VALID_TOPICS)}"),
84
- ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect all relevant Topics as a structured list.")
 
 
 
 
85
  ])
86
 
87
  result = structured_llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"]))
@@ -93,21 +80,30 @@ def route_multi_topic(state: MultiTopicState):
93
  return {"chosen_topics": valid_selected}
94
 
95
  def classify_individual_topic(topic_name: str):
96
- """A dynamic factory function that returns a custom node runner for a specific topic."""
97
 
98
  def node_runner(state: MultiTopicState):
99
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
100
  target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
101
 
102
  prompt = ChatPromptTemplate.from_messages([
103
- ("system", f"You are a specialist in {topic_name} data mapping. Extract exact keyword pathways present in the provided list."),
 
 
 
 
 
104
  ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nReturn exact matching entries as a comma-separated list.")
105
  ])
106
 
107
  response = llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"], sub_tree=target_sub_tree))
108
- raw_keywords = [k.strip() for k in response.content.split(",") if k.strip()]
109
 
110
- # Immediate validation pass inside the node branch
 
 
 
 
 
111
  valid_set = set()
112
  for line in target_sub_tree.split("\n"):
113
  if line.strip():
@@ -122,7 +118,6 @@ def classify_individual_topic(topic_name: str):
122
  return node_runner
123
 
124
  def parallel_router(state: MultiTopicState) -> List[str]:
125
- """Tells LangGraph to trigger multiple sub-nodes simultaneously based on state."""
126
  return [f"classify_{topic.lower()}" for topic in state["chosen_topics"]]
127
 
128
  # Assemble the Graph
@@ -149,15 +144,13 @@ workflow.add_edge("classify_fallback", END)
149
  app = workflow.compile()
150
 
151
  # ==========================================================
152
- # 3. GRADIO USER INTERFACE (WITH PRESENTATION FORMATTING)
153
  # ==========================================================
154
  def run_agent_classifier(title, abstract):
155
  if not title or not abstract:
156
  return "Please fill out both Title and Abstract fields.", "N/A", "N/A"
157
 
158
  inputs = {"title": title, "abstract": abstract}
159
-
160
- # Execute graph synchronously
161
  output = app.invoke(inputs)
162
 
163
  chosen_topics = output.get("chosen_topics", [])
@@ -166,7 +159,6 @@ def run_agent_classifier(title, abstract):
166
 
167
  topics_str = ", ".join(chosen_topics)
168
 
169
- # Format output keywords to prepend their corresponding parent Topic
170
  formatted_keywords = []
171
  for kw in predicted_kws:
172
  matched_topic = "UNKNOWN"
@@ -182,7 +174,6 @@ def run_agent_classifier(title, abstract):
182
  else:
183
  keywords_str = "\n".join(formatted_keywords)
184
 
185
- # Format separate display for caught and filtered hallucinations
186
  if not invalid_kws:
187
  invalid_str = "None! The agent validation pass achieved 100% data integrity."
188
  else:
 
9
  from langchain_openai import ChatOpenAI
10
 
11
  # ==========================================================
12
+ # 1. LOAD DATA & OPTIMIZED INDEX GENERATOR
13
  # ==========================================================
14
  with open("gcmd_hierarchy.json", "r") as f:
15
  gcmd_data = json.load(f)
16
 
17
  def build_gcmd_indices(gcmd_json):
 
 
 
 
18
  topic_list = []
19
  sub_tree_indices = {}
 
 
20
  topics = gcmd_json.get("children", [])
21
 
22
  for topic_node in topics:
23
  topic_name = topic_node.get("name", "").upper()
24
  topic_list.append(topic_name)
 
 
25
  collected_paths = []
26
 
27
  def recurse_sub_tree(node, current_path=""):
28
  node_name = node.get("name", "")
 
29
  node_path = f"{current_path} > {node_name}" if current_path else node_name
 
 
30
  if "Variable" in node.get("level", "") or not node.get("children"):
31
  collected_paths.append(node_path)
 
 
32
  for child in node.get("children", []):
33
  recurse_sub_tree(child, node_path)
34
 
 
35
  for term_node in topic_node.get("children", []):
36
  recurse_sub_tree(term_node, current_path="")
37
 
 
38
  sub_tree_indices[topic_name] = "\n".join([f"- {path}" for path in collected_paths])
 
39
  return topic_list, sub_tree_indices
40
 
41
  VALID_TOPICS, SUB_TREE_LOOKUP = build_gcmd_indices(gcmd_data)
42
 
43
  # ==========================================================
44
+ # 2. UPDATED LANGGRAPH WORKFLOW WITH NULL-ENFORCEMENT
45
  # ==========================================================
46
 
47
  def merge_lists(left: list, right: list) -> list:
 
48
  return list(set((left or []) + (right or [])))
49
 
50
  class MultiTopicState(TypedDict):
 
58
  topics: List[str] = Field(description="List of matching topic areas from the allowed dataset.")
59
 
60
  def route_multi_topic(state: MultiTopicState):
61
+ """Step 1: Identify strictly relevant high-level topics (Tightened to prevent over-routing)."""
62
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
63
  structured_llm = llm.with_structured_output(TopicsChoice)
64
 
65
  prompt = ChatPromptTemplate.from_messages([
66
+ ("system", (
67
+ f"You are an expert science cataloger. Identify only the core major topic areas that "
68
+ f"DIRECTLY and primary apply to this paper. Do NOT select topics that are only tangentially "
69
+ f"referenced or inferred. Choose ONLY from: {', '.join(VALID_TOPICS)}"
70
+ )),
71
+ ("user", "Title: {title}\nAbstract: {abstract}\n\nSelect the highly relevant Topics as a structured list.")
72
  ])
73
 
74
  result = structured_llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"]))
 
80
  return {"chosen_topics": valid_selected}
81
 
82
  def classify_individual_topic(topic_name: str):
83
+ """A dynamic factory function with strict enforcement against conversational prose."""
84
 
85
  def node_runner(state: MultiTopicState):
86
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
87
  target_sub_tree = SUB_TREE_LOOKUP.get(topic_name, "")
88
 
89
  prompt = ChatPromptTemplate.from_messages([
90
+ ("system", (
91
+ f"You are a specialist in {topic_name} data mapping. Extract exact keyword pathways present in the provided list.\n"
92
+ f"CRITICAL RULES:\n"
93
+ f"1. If absolutely no keywords from the valid path list match this paper, reply with exactly the word 'NONE'.\n"
94
+ f"2. Do NOT write sentences, do NOT explain your reasoning, and do NOT say 'there are no matching entries'. Your output must only be a comma-separated list of keywords, or the single token 'NONE'."
95
+ )),
96
  ("user", "Title: {title}\nAbstract: {abstract}\n\nValid Paths:\n{sub_tree}\n\nReturn exact matching entries as a comma-separated list.")
97
  ])
98
 
99
  response = llm.invoke(prompt.format(title=state["title"], abstract=state["abstract"], sub_tree=target_sub_tree))
 
100
 
101
+ # Split and clean tokens, explicitly ignoring any 'NONE' or empty responses
102
+ raw_keywords = [
103
+ k.strip() for k in response.content.split(",")
104
+ if k.strip() and k.strip().upper() != "NONE"
105
+ ]
106
+
107
  valid_set = set()
108
  for line in target_sub_tree.split("\n"):
109
  if line.strip():
 
118
  return node_runner
119
 
120
  def parallel_router(state: MultiTopicState) -> List[str]:
 
121
  return [f"classify_{topic.lower()}" for topic in state["chosen_topics"]]
122
 
123
  # Assemble the Graph
 
144
  app = workflow.compile()
145
 
146
  # ==========================================================
147
+ # 3. GRADIO USER INTERFACE
148
  # ==========================================================
149
  def run_agent_classifier(title, abstract):
150
  if not title or not abstract:
151
  return "Please fill out both Title and Abstract fields.", "N/A", "N/A"
152
 
153
  inputs = {"title": title, "abstract": abstract}
 
 
154
  output = app.invoke(inputs)
155
 
156
  chosen_topics = output.get("chosen_topics", [])
 
159
 
160
  topics_str = ", ".join(chosen_topics)
161
 
 
162
  formatted_keywords = []
163
  for kw in predicted_kws:
164
  matched_topic = "UNKNOWN"
 
174
  else:
175
  keywords_str = "\n".join(formatted_keywords)
176
 
 
177
  if not invalid_kws:
178
  invalid_str = "None! The agent validation pass achieved 100% data integrity."
179
  else: