shazzadulimun commited on
Commit
55e313c
·
verified ·
1 Parent(s): 057dde7

Add live-MCP test code block

Browse files
Files changed (1) hide show
  1. README.md +74 -0
README.md CHANGED
@@ -49,6 +49,80 @@ Output format is FunctionGemma native:
49
  <start_function_call>call:list_organizations{server:<escape>global<escape>}<end_function_call>
50
  ```
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  ## Files
53
 
54
  - `merged_16bit/` — full safetensors checkpoint
 
49
  <start_function_call>call:list_organizations{server:<escape>global<escape>}<end_function_call>
50
  ```
51
 
52
+ ## Live test against the upstream NDP MCP
53
+
54
+ End-to-end: model → tool call → upstream `clio-kit` NDP MCP → real NDP response.
55
+
56
+ ```python
57
+ # /// script
58
+ # requires-python = ">=3.11"
59
+ # dependencies = [
60
+ # "transformers>=4.45", "torch>=2.4", "accelerate>=0.34",
61
+ # "sentencepiece>=0.2", "protobuf>=4", "mcp>=1.0",
62
+ # ]
63
+ # ///
64
+ import asyncio, json, re
65
+ import torch
66
+ from transformers import AutoModelForCausalLM, AutoTokenizer
67
+ from mcp import ClientSession, StdioServerParameters
68
+ from mcp.client.stdio import stdio_client
69
+
70
+ MID = "shazzadulimun/FunctionGemma-ndp"
71
+ PROMPT = "List all organizations on the NDP global server"
72
+
73
+ # 14-tool NDP catalog reshaped as OpenAI function specs (truncated here).
74
+ tools = [{"type": "function", "function": {
75
+ "name": "list_organizations",
76
+ "description": "List organizations available in the National Data Platform.",
77
+ "parameters": {"type": "object", "properties": {
78
+ "name_filter": {"type": "string"}, "server": {"type": "string"},
79
+ }, "required": []},
80
+ }}]
81
+
82
+ tok = AutoTokenizer.from_pretrained(MID, subfolder="merged_16bit")
83
+ mdl = AutoModelForCausalLM.from_pretrained(
84
+ MID, subfolder="merged_16bit", dtype=torch.bfloat16, device_map="auto",
85
+ )
86
+ text = tok.apply_chat_template(
87
+ [{"role": "user", "content": PROMPT}],
88
+ tools=tools, add_generation_prompt=True, tokenize=False,
89
+ )
90
+ inp = tok(text, return_tensors="pt").to(mdl.device)
91
+ out = mdl.generate(**inp, max_new_tokens=300)
92
+ raw = tok.decode(out[0][inp.input_ids.shape[-1]:], skip_special_tokens=False)
93
+
94
+ # Parse FunctionGemma format: <start_function_call>call:NAME{k:v,...}<end_function_call>
95
+ m = re.search(r"<start_function_call>\s*call:(\w+)\s*\{(.*?)\}\s*<end_function_call>",
96
+ raw, re.DOTALL)
97
+ name = m.group(1)
98
+ args = {}
99
+ for k, v in re.findall(r"(\w+)\s*:\s*(<escape>.*?<escape>|None|\w+)", m.group(2)):
100
+ if v == "None":
101
+ continue # strip phantom nulls
102
+ args[k] = re.sub(r"<escape>|<escape>", "", v) if "<escape>" in v else v
103
+
104
+ # Spawn the upstream clio-kit NDP MCP and call the parsed tool against it.
105
+ async def call():
106
+ params = StdioServerParameters(command="uvx", args=[
107
+ "--from",
108
+ "git+https://github.com/iowarp/clio-kit.git#subdirectory=clio-kit-mcp-servers/ndp",
109
+ "ndp-mcp",
110
+ ])
111
+ async with stdio_client(params) as (r, w):
112
+ async with ClientSession(r, w) as s:
113
+ await s.initialize()
114
+ out = await s.call_tool(name, args)
115
+ print("".join(c.text for c in out.content if hasattr(c, "text")))
116
+
117
+ asyncio.run(call())
118
+ ```
119
+
120
+ Save as `test.py` and run:
121
+
122
+ ```bash
123
+ uv run --isolated test.py
124
+ ```
125
+
126
  ## Files
127
 
128
  - `merged_16bit/` — full safetensors checkpoint