---
base_model: unsloth/functiongemma-270m-it
license: apache-2.0
language: en
library_name: transformers
tags:
- ndp
- tool-calling
- function-calling
- mcp
- unsloth
- lora
- functiongemma
---

# FunctionGemma-ndp

A 270M-parameter FunctionGemma fine-tune for tool calling against the
[National Data Platform (NDP)](https://nationaldataplatform.org/) MCP
server.

The model supports three tools: `list_organizations`, `search_datasets`,
and `get_dataset_details`.

## Usage

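```python
from transformers import AutoModelForCausalLM, AutoTokenizer

mid = "shazzadulimun/FunctionGemma-ndp"
tok = AutoTokenizer.from_pretrained(mid, subfolder="merged_16bit")
mdl = AutoModelForCausalLM.from_pretrained(
    mid, subfolder="merged_16bit", device_map="auto",
)

messages = [
    {"role": "developer", "content":
     "You are a model that can do function calling with the following functions"},
    {"role": "user", "content": "List all organizations on the NDP global server"},
]
# tools=[...]: the OpenAI-style function specs for the NDP tools (elided here).
prompt = tok.apply_chat_template(
    messages, tools=[...], add_generation_prompt=True, tokenize=False,
)

# Generate, keeping special tokens so the function-call markers survive decoding.
inputs = tok(prompt, return_tensors="pt").to(mdl.device)
out = mdl.generate(**inputs, max_new_tokens=300)
print(tok.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
```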

The model emits tool calls in FunctionGemma's native format:

```
<start_function_call>call:list_organizations{server:<escape>global<escape>}<end_function_call>
```
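For reuse outside the live test, the call syntax above can be parsed (and, for string-only arguments, re-emitted) with two regular expressions. `parse_call` and `format_call` are hypothetical helper names for a minimal sketch, not part of FunctionGemma or Transformers; the sketch assumes string values are wrapped in a pair of `<escape>` tokens and that bare values (numbers, booleans, `None`) contain no commas or braces.

```python
from __future__ import annotations

import re

# Hypothetical helpers (not from any library) for the FunctionGemma call syntax:
# <start_function_call>call:NAME{key:value,...}<end_function_call>
CALL_RE = re.compile(
    r"<start_function_call>\s*call:(\w+)\s*\{(.*?)\}\s*<end_function_call>", re.DOTALL
)
# A value is either an <escape>-delimited string or a bare token (number, bool, None).
ARG_RE = re.compile(r"(\w+)\s*:\s*(<escape>.*?<escape>|[^,{}]+)", re.DOTALL)


def format_call(name: str, args: dict[str, str]) -> str:
    """Render a call, wrapping every argument value as an <escape>-delimited string."""
    body = ",".join(f"{k}:<escape>{v}<escape>" for k, v in args.items())
    return f"<start_function_call>call:{name}{{{body}}}<end_function_call>"


def parse_call(raw: str) -> tuple[str, dict[str, str]] | None:
    """Return (name, args) for the first call in `raw`, or None if there is none."""
    m = CALL_RE.search(raw)
    if m is None:
        return None
    args: dict[str, str] = {}
    for key, value in ARG_RE.findall(m.group(2)):
        value = value.strip()
        if value == "None":  # drop null placeholders, as the live test does
            continue
        args[key] = value.replace("<escape>", "")
    return m.group(1), args


example = (
    "<start_function_call>call:list_organizations"
    "{server:<escape>global<escape>}<end_function_call>"
)
print(parse_call(example))  # ('list_organizations', {'server': 'global'})
print(format_call("list_organizations", {"server": "global"}) == example)  # True
```

Note that bare values come back as strings (`"10"`, not `10`); convert them against the tool's JSON schema if the server is strict about types.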

## Live test against the upstream NDP MCP

End-to-end: model → tool call → upstream `clio-kit` NDP MCP → real NDP response.

```python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "transformers>=4.45", "torch>=2.4", "accelerate>=0.34",
#   "sentencepiece>=0.2", "protobuf>=4", "mcp>=1.0",
# ]
# ///
import asyncio, json, re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

MID = "shazzadulimun/FunctionGemma-ndp"
PROMPT = "List all organizations on the NDP global server"

# 14-tool NDP catalog reshaped as OpenAI function specs (truncated here).
tools = [{"type": "function", "function": {
    "name": "list_organizations",
    "description": "List organizations available in the National Data Platform.",
    "parameters": {"type": "object", "properties": {
        "name_filter": {"type": "string"}, "server": {"type": "string"},
    }, "required": []},
}}]

tok = AutoTokenizer.from_pretrained(MID, subfolder="merged_16bit")
mdl = AutoModelForCausalLM.from_pretrained(
    MID, subfolder="merged_16bit", dtype=torch.bfloat16, device_map="auto",
)
text = tok.apply_chat_template(
    [{"role": "user", "content": PROMPT}],
    tools=tools, add_generation_prompt=True, tokenize=False,
)
inp = tok(text, return_tensors="pt").to(mdl.device)
out = mdl.generate(**inp, max_new_tokens=300)
raw = tok.decode(out[0][inp.input_ids.shape[-1]:], skip_special_tokens=False)

# Parse FunctionGemma format: <start_function_call>call:NAME{k:v,...}<end_function_call>
m = re.search(r"<start_function_call>\s*call:(\w+)\s*\{(.*?)\}\s*<end_function_call>",
              raw, re.DOTALL)
if m is None:
    raise SystemExit(f"No function call found in model output:\n{raw}")
name = m.group(1)
args = {}
for k, v in re.findall(r"(\w+)\s*:\s*(<escape>.*?<escape>|None|\w+)", m.group(2)):
    if v == "None":
        continue                                      # strip phantom nulls
    args[k] = v.replace("<escape>", "")               # unwrap <escape>-delimited strings

# Spawn the upstream clio-kit NDP MCP and call the parsed tool against it.
async def call():
    params = StdioServerParameters(command="uvx", args=[
        "--from",
        "git+https://github.com/iowarp/clio-kit.git#subdirectory=clio-kit-mcp-servers/ndp",
        "ndp-mcp",
    ])
    async with stdio_client(params) as (r, w):
        async with ClientSession(r, w) as s:
            await s.initialize()
            result = await s.call_tool(name, args)
            print("".join(c.text for c in result.content if hasattr(c, "text")))

asyncio.run(call())
```

Save as `test.py` and run:

```bash
uv run --isolated test.py
```
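The live test hard-codes a single spec, while its comment notes that the full 14-tool NDP catalog was reshaped into OpenAI function specs. Below is a minimal sketch of one way to do that reshaping, plus a guard that rejects an unknown tool name or unexpected arguments before dispatch, which is useful with a small model that can hallucinate either. `to_function_specs` and `check_call` are hypothetical helpers; the sketch assumes MCP tool descriptors expose `name`, `description`, and `inputSchema`, as the `Tool` objects in `(await session.list_tools()).tools` do in the MCP Python SDK.

```python
from types import SimpleNamespace
from typing import Any


def to_function_specs(mcp_tools) -> list[dict[str, Any]]:
    """Reshape MCP tool descriptors (name, description, inputSchema) into
    the OpenAI-style specs passed to apply_chat_template(tools=...)."""
    specs = []
    for t in mcp_tools:
        # Copy so the caller's schema is not mutated by setdefault below.
        schema = dict(t.inputSchema or {"type": "object", "properties": {}})
        schema.setdefault("required", [])
        specs.append({"type": "function", "function": {
            "name": t.name,
            "description": t.description or "",
            "parameters": schema,
        }})
    return specs


def check_call(name: str, args: dict[str, Any], specs: list[dict[str, Any]]) -> list[str]:
    """Return the problems found in a parsed call; an empty list means it looks valid."""
    by_name = {s["function"]["name"]: s["function"]["parameters"] for s in specs}
    if name not in by_name:
        return [f"unknown tool: {name}"]
    params = by_name[name]
    props = params.get("properties", {})
    problems = [f"unexpected argument: {k}" for k in args if k not in props]
    problems += [f"missing required argument: {k}"
                 for k in params.get("required", []) if k not in args]
    return problems


# Stand-in for one entry of (await session.list_tools()).tools
tool = SimpleNamespace(
    name="list_organizations",
    description="List organizations available in the National Data Platform.",
    inputSchema={"type": "object", "properties": {
        "name_filter": {"type": "string"}, "server": {"type": "string"}}},
)
specs = to_function_specs([tool])
print(check_call("list_organizations", {"server": "global"}, specs))  # []
print(check_call("list_orgs", {}, specs))  # ['unknown tool: list_orgs']
```

In the live test, `check_call(name, args, specs)` could run just before `s.call_tool(...)`, with `specs` built from the server's own tool list rather than a hand-written one.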

## Files

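- `merged_16bit/`: full safetensors checkpoint with the LoRA weights merged into the base model (16-bit)
- `lora/`: LoRA adapter only (apply on top of `unsloth/functiongemma-270m-it`)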

## Training

- Base model: `unsloth/functiongemma-270m-it`
- Method: LoRA (r=64, alpha=128)
- Hyperparameters: 3 epochs, batch size 16, learning rate 2e-4
- Train loss: 0.25
- Hardware: 1× H200
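With r=64 and alpha=128, and assuming the standard LoRA convention (PEFT's default when rsLoRA is off) of scaling the low-rank update by alpha / r, each adapted weight receives its update scaled by 2. The pure-Python sketch below spells out that arithmetic and the forward pass y = Wx + (alpha / r)·B(Ax); the 640 × 2048 layer size in the example is hypothetical and not read from the model config.

```python
# Illustrative LoRA arithmetic for the settings above (r=64, alpha=128).
# Layer sizes used in the example are hypothetical, not taken from the model.


def lora_scaling(r: int, alpha: int) -> float:
    """Standard LoRA scaling: the low-rank update B @ A is multiplied by alpha / r."""
    return alpha / r


def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable params LoRA adds to one d_out x d_in weight (A: r x d_in, B: d_out x r)."""
    return r * d_in + d_out * r


def _matvec(M: list[list[float]], v: list[float]) -> list[float]:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def lora_forward(W, A, B, x, r: int, alpha: int) -> list[float]:
    """y = W x + (alpha / r) * B (A x); W stays frozen, only A and B are trained."""
    base = _matvec(W, x)
    delta = _matvec(B, _matvec(A, x))
    s = lora_scaling(r, alpha)
    return [b + s * d for b, d in zip(base, delta)]


print(lora_scaling(64, 128))            # 2.0
print(lora_param_count(640, 2048, 64))  # 172032
# With B initialised to zeros, the adapted layer starts out identical to the base layer.
print(lora_forward([[1, 0], [0, 1]], [[1, 1]], [[0], [0]], [3, 4], r=1, alpha=2))  # [3.0, 4.0]
```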

## Citation

Built with [Phagocyte](https://github.com/grc-iit/Phagocyte).