JacobLinCool Codex commited on
Commit
73b4c3f
·
verified ·
1 Parent(s): 8fb1ae9

feat: add model tool-call contracts

Browse files

Co-authored-by: Codex <noreply@openai.com>

README.md CHANGED
@@ -63,6 +63,12 @@ source, project order, and digest before the app starts.
63
  The app exposes a `trace_artifact` Gradio API endpoint and a `JSONL` button in the UI. Both emit the same JSONL schema:
64
  a manifest row followed by one row per agent turn. `data/sample_trace.jsonl` is a checked-in, Hub-published sample trace.
65
 
 
 
 
 
 
 
66
  ## Test
67
 
68
  ```bash
 
63
  The app exposes a `trace_artifact` Gradio API endpoint and a `JSONL` button in the UI. Both emit the same JSONL schema:
64
  a manifest row followed by one row per agent turn. `data/sample_trace.jsonl` is a checked-in, Hub-published sample trace.
65
 
66
+ ## Tool-Call Contract
67
+
68
+ `/api/tool-contracts` exposes the JSON schemas intended for MiniCPM-style tool calling. `tool_contract_check` accepts a
69
+ MiniCPM XML call such as `<function name="search_projects">{"query":"lullaby audio"}</function>`, validates it against
70
+ the schemas, and returns either the valid call or a safe default call for the UI watchdog path.
71
+
72
  ## Test
73
 
74
  ```bash
app.py CHANGED
@@ -10,6 +10,7 @@ from gradio import Server
10
 
11
  from hackathon_advisor.agent import AdvisorEngine
12
  from hackathon_advisor.data import ProjectIndex
 
13
  from hackathon_advisor.trace_export import build_trace_jsonl, trace_metadata
14
 
15
 
@@ -59,6 +60,19 @@ def bootstrap() -> dict:
59
  }
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  @app.api(name="trace_artifact", concurrency_limit=8)
63
  def trace_artifact(session_json: str = "{}") -> str:
64
  try:
 
10
 
11
  from hackathon_advisor.agent import AdvisorEngine
12
  from hackathon_advisor.data import ProjectIndex
13
+ from hackathon_advisor.tool_contracts import resolve_tool_call, tool_schemas
14
  from hackathon_advisor.trace_export import build_trace_jsonl, trace_metadata
15
 
16
 
 
60
  }
61
 
62
 
63
+ @app.get("/api/tool-contracts")
64
+ def tool_contracts() -> dict:
65
+ return {
66
+ "tool_count": len(tool_schemas()),
67
+ "tools": tool_schemas(),
68
+ }
69
+
70
+
71
+ @app.api(name="tool_contract_check", concurrency_limit=8)
72
+ def tool_contract_check(model_output: str, fallback_query: str = "") -> dict:
73
+ return resolve_tool_call(model_output, fallback_query=fallback_query).to_dict()
74
+
75
+
76
  @app.api(name="trace_artifact", concurrency_limit=8)
77
  def trace_artifact(session_json: str = "{}") -> str:
78
  try:
hackathon_advisor/tool_contracts.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import json
5
+ from typing import Any, Literal
6
+ from xml.etree import ElementTree
7
+
8
+
9
+ JsonType = Literal["string", "integer", "number", "boolean", "array", "object"]
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class ToolField:
14
+ type: JsonType
15
+ description: str
16
+ required: bool = False
17
+ enum: tuple[str, ...] = ()
18
+ items_type: JsonType | None = None
19
+
20
+ def to_schema(self) -> dict[str, Any]:
21
+ schema: dict[str, Any] = {
22
+ "type": self.type,
23
+ "description": self.description,
24
+ }
25
+ if self.enum:
26
+ schema["enum"] = list(self.enum)
27
+ if self.items_type:
28
+ schema["items"] = {"type": self.items_type}
29
+ return schema
30
+
31
+
32
+ @dataclass(frozen=True)
33
+ class ToolSpec:
34
+ name: str
35
+ description: str
36
+ fields: dict[str, ToolField]
37
+
38
+ def to_schema(self) -> dict[str, Any]:
39
+ return {
40
+ "type": "function",
41
+ "function": {
42
+ "name": self.name,
43
+ "description": self.description,
44
+ "parameters": {
45
+ "type": "object",
46
+ "additionalProperties": False,
47
+ "properties": {
48
+ name: field.to_schema() for name, field in self.fields.items()
49
+ },
50
+ "required": [name for name, field in self.fields.items() if field.required],
51
+ },
52
+ },
53
+ }
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class ToolCall:
58
+ name: str
59
+ arguments: dict[str, Any]
60
+
61
+ def to_dict(self) -> dict[str, Any]:
62
+ return {"name": self.name, "arguments": self.arguments}
63
+
64
+
65
+ @dataclass(frozen=True)
66
+ class ToolResolution:
67
+ status: Literal["valid", "defaulted"]
68
+ call: ToolCall
69
+ errors: tuple[str, ...]
70
+
71
+ def to_dict(self) -> dict[str, Any]:
72
+ return {
73
+ "status": self.status,
74
+ "call": self.call.to_dict(),
75
+ "errors": list(self.errors),
76
+ }
77
+
78
+
79
+ class ToolContractError(ValueError):
80
+ pass
81
+
82
+
83
+ TOOL_SPECS: dict[str, ToolSpec] = {
84
+ "list_projects": ToolSpec(
85
+ name="list_projects",
86
+ description="Read prominent Build Small project cards from the offline snapshot.",
87
+ fields={
88
+ "track": ToolField("string", "Optional prize, badge, model, or topic filter."),
89
+ "sort": ToolField("string", "Sort key.", enum=("likes", "recent", "title")),
90
+ },
91
+ ),
92
+ "search_projects": ToolSpec(
93
+ name="search_projects",
94
+ description="Find existing Spaces that echo the user's project idea.",
95
+ fields={"query": ToolField("string", "The user idea or topic to search.", required=True)},
96
+ ),
97
+ "get_project": ToolSpec(
98
+ name="get_project",
99
+ description="Read one project card by full Space id or slug.",
100
+ fields={"id": ToolField("string", "Project id such as build-small-hackathon/lolaby.", required=True)},
101
+ ),
102
+ "find_whitespace": ToolSpec(
103
+ name="find_whitespace",
104
+ description="Return under-explored project regions from the offline index.",
105
+ fields={},
106
+ ),
107
+ "save_idea": ToolSpec(
108
+ name="save_idea",
109
+ description="Write or update the current idea page.",
110
+ fields={
111
+ "title": ToolField("string", "Short idea title.", required=True),
112
+ "pitch": ToolField("string", "One-sentence idea pitch.", required=True),
113
+ "track": ToolField("string", "Primary target track or award."),
114
+ "models": ToolField("array", "Model ids the idea may use.", items_type="string"),
115
+ "side_quests": ToolField("array", "Badge or side quest targets.", items_type="string"),
116
+ },
117
+ ),
118
+ "score_idea": ToolSpec(
119
+ name="score_idea",
120
+ description="Score the current idea against the fixed hackathon rubric.",
121
+ fields={"id": ToolField("string", "Idea id; omit to score the current idea.")},
122
+ ),
123
+ "compare_ideas": ToolSpec(
124
+ name="compare_ideas",
125
+ description="Rank the current idea board and explain tradeoffs.",
126
+ fields={},
127
+ ),
128
+ "make_plan": ToolSpec(
129
+ name="make_plan",
130
+ description="Draft the next build steps for the current idea.",
131
+ fields={"id": ToolField("string", "Idea id; omit to plan the current idea.")},
132
+ ),
133
+ "update_profile": ToolSpec(
134
+ name="update_profile",
135
+ description="Remember a user skill, constraint, preference, or available time.",
136
+ fields={
137
+ "field": ToolField(
138
+ "string",
139
+ "Profile field to update.",
140
+ required=True,
141
+ enum=("skills", "time", "preferences", "constraints"),
142
+ ),
143
+ "value": ToolField("string", "Profile value to remember.", required=True),
144
+ },
145
+ ),
146
+ "set_target": ToolSpec(
147
+ name="set_target",
148
+ description="Change the badge, model, or award targets used to bias ideation.",
149
+ fields={"side_quests": ToolField("array", "Targets to prioritize.", required=True, items_type="string")},
150
+ ),
151
+ }
152
+
153
+
154
+ def tool_schemas() -> list[dict[str, Any]]:
155
+ return [spec.to_schema() for spec in TOOL_SPECS.values()]
156
+
157
+
158
+ def parse_xml_tool_call(text: str) -> ToolCall:
159
+ wrapped = f"<root>{text.strip()}</root>"
160
+ try:
161
+ root = ElementTree.fromstring(wrapped)
162
+ except ElementTree.ParseError as error:
163
+ raise ToolContractError(f"invalid XML tool call: {error}") from error
164
+
165
+ functions = [node for node in root if node.tag == "function"]
166
+ if len(functions) != 1:
167
+ raise ToolContractError(f"expected exactly one function call, got {len(functions)}")
168
+ node = functions[0]
169
+ name = str(node.attrib.get("name") or "").strip()
170
+ if not name:
171
+ raise ToolContractError("function call is missing a name")
172
+ raw_arguments = (node.text or "").strip() or "{}"
173
+ try:
174
+ arguments = json.loads(raw_arguments)
175
+ except json.JSONDecodeError as error:
176
+ raise ToolContractError(f"function arguments are not valid JSON: {error.msg}") from error
177
+ if not isinstance(arguments, dict):
178
+ raise ToolContractError("function arguments must be a JSON object")
179
+ return ToolCall(name=name, arguments=arguments)
180
+
181
+
182
+ def validate_tool_call(call: ToolCall, specs: dict[str, ToolSpec] = TOOL_SPECS) -> ToolCall:
183
+ spec = specs.get(call.name)
184
+ if spec is None:
185
+ raise ToolContractError(f"unknown tool: {call.name}")
186
+ allowed = set(spec.fields)
187
+ extra = sorted(set(call.arguments) - allowed)
188
+ if extra:
189
+ raise ToolContractError(f"unexpected arguments for {call.name}: {', '.join(extra)}")
190
+ missing = sorted(name for name, field in spec.fields.items() if field.required and name not in call.arguments)
191
+ if missing:
192
+ raise ToolContractError(f"missing required arguments for {call.name}: {', '.join(missing)}")
193
+ for name, value in call.arguments.items():
194
+ field = spec.fields[name]
195
+ _validate_value(call.name, name, value, field)
196
+ return call
197
+
198
+
199
+ def resolve_tool_call(model_output: str, fallback_query: str = "") -> ToolResolution:
200
+ errors: list[str] = []
201
+ try:
202
+ call = validate_tool_call(parse_xml_tool_call(model_output))
203
+ return ToolResolution(status="valid", call=call, errors=())
204
+ except ToolContractError as error:
205
+ errors.append(str(error))
206
+
207
+ query = fallback_query.strip()
208
+ if query:
209
+ call = ToolCall("search_projects", {"query": query})
210
+ else:
211
+ call = ToolCall("find_whitespace", {})
212
+ return ToolResolution(status="defaulted", call=call, errors=tuple(errors))
213
+
214
+
215
+ def _validate_value(tool_name: str, field_name: str, value: Any, field: ToolField) -> None:
216
+ if field.type == "string":
217
+ valid = isinstance(value, str)
218
+ elif field.type == "integer":
219
+ valid = isinstance(value, int) and not isinstance(value, bool)
220
+ elif field.type == "number":
221
+ valid = (isinstance(value, int | float)) and not isinstance(value, bool)
222
+ elif field.type == "boolean":
223
+ valid = isinstance(value, bool)
224
+ elif field.type == "array":
225
+ valid = isinstance(value, list)
226
+ elif field.type == "object":
227
+ valid = isinstance(value, dict)
228
+ else:
229
+ valid = False
230
+ if not valid:
231
+ raise ToolContractError(f"{tool_name}.{field_name} must be {field.type}")
232
+ if field.enum and value not in field.enum:
233
+ raise ToolContractError(f"{tool_name}.{field_name} must be one of: {', '.join(field.enum)}")
234
+ if field.items_type and isinstance(value, list):
235
+ for index, item in enumerate(value):
236
+ _validate_value(tool_name, f"{field_name}[{index}]", item, ToolField(field.items_type, "array item"))
tests/test_app.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
 
3
- from app import bootstrap, engine, health, index, trace_artifact
4
 
5
 
6
  def test_health_exposes_index_metadata() -> None:
@@ -29,3 +29,17 @@ def test_trace_artifact_endpoint_exports_jsonl() -> None:
29
  assert lines[0]["type"] == "trace_manifest"
30
  assert lines[0]["turn_count"] == 1
31
  assert lines[1]["type"] == "agent_turn"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
 
3
+ from app import bootstrap, engine, health, index, tool_contract_check, tool_contracts, trace_artifact
4
 
5
 
6
  def test_health_exposes_index_metadata() -> None:
 
29
  assert lines[0]["type"] == "trace_manifest"
30
  assert lines[0]["turn_count"] == 1
31
  assert lines[1]["type"] == "agent_turn"
32
+
33
+
34
+ def test_tool_contracts_endpoint_exposes_schemas() -> None:
35
+ payload = tool_contracts()
36
+
37
+ assert payload["tool_count"] >= 8
38
+ assert any(tool["function"]["name"] == "search_projects" for tool in payload["tools"])
39
+
40
+
41
+ def test_tool_contract_check_endpoint_defaults_safely() -> None:
42
+ payload = tool_contract_check("broken", "family archive")
43
+
44
+ assert payload["status"] == "defaulted"
45
+ assert payload["call"]["name"] == "search_projects"
tests/test_tool_contracts.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from hackathon_advisor.tool_contracts import (
4
+ ToolCall,
5
+ ToolContractError,
6
+ parse_xml_tool_call,
7
+ resolve_tool_call,
8
+ tool_schemas,
9
+ validate_tool_call,
10
+ )
11
+
12
+
13
+ def test_tool_schemas_are_model_ready() -> None:
14
+ schemas = tool_schemas()
15
+
16
+ assert len(schemas) >= 8
17
+ assert schemas[0]["type"] == "function"
18
+ assert {schema["function"]["name"] for schema in schemas} >= {
19
+ "search_projects",
20
+ "find_whitespace",
21
+ "save_idea",
22
+ "score_idea",
23
+ "make_plan",
24
+ }
25
+
26
+
27
+ def test_parse_and_validate_minicpm_xml_tool_call() -> None:
28
+ call = parse_xml_tool_call('<function name="search_projects">{"query":"lullaby audio"}</function>')
29
+
30
+ assert validate_tool_call(call) == ToolCall("search_projects", {"query": "lullaby audio"})
31
+
32
+
33
+ def test_validate_rejects_unknown_tool() -> None:
34
+ with pytest.raises(ToolContractError, match="unknown tool"):
35
+ validate_tool_call(ToolCall("invent_project", {}))
36
+
37
+
38
+ def test_validate_rejects_bad_argument_type() -> None:
39
+ with pytest.raises(ToolContractError, match="search_projects.query must be string"):
40
+ validate_tool_call(ToolCall("search_projects", {"query": 47}))
41
+
42
+
43
+ def test_validate_rejects_extra_arguments() -> None:
44
+ with pytest.raises(ToolContractError, match="unexpected arguments"):
45
+ validate_tool_call(ToolCall("find_whitespace", {"query": "unused"}))
46
+
47
+
48
+ def test_resolve_defaults_to_search_when_output_is_broken() -> None:
49
+ resolution = resolve_tool_call("<function", fallback_query="offline archive")
50
+
51
+ assert resolution.status == "defaulted"
52
+ assert resolution.call == ToolCall("search_projects", {"query": "offline archive"})
53
+ assert resolution.errors
54
+
55
+
56
+ def test_resolve_defaults_to_whitespace_without_query() -> None:
57
+ resolution = resolve_tool_call("no function here")
58
+
59
+ assert resolution.status == "defaulted"
60
+ assert resolution.call == ToolCall("find_whitespace", {})