Upload 13 files
Browse files- .gitattributes +1 -0
- LICENSE +23 -0
- README.md +112 -962
- data/benchmark_dataset.json +0 -0
- data/training_data.json +3 -0
- requirements.txt +36 -0
- results/evaluation_results.json +0 -0
- run_training.sh +73 -0
- src/__init__.py +0 -0
- src/config.py +317 -0
- src/evaluate.py +641 -0
- src/generate_benchmark.py +459 -0
- src/prepare_dataset.py +199 -0
- src/train.py +559 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/training_data.json filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 FunctionGemma contributors
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
| 23 |
+
|
README.md
CHANGED
|
@@ -1,4 +1,12 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
language:
|
| 3 |
- en
|
| 4 |
- zh
|
|
@@ -19,1007 +27,149 @@ tags:
|
|
| 19 |
- standard-protocol
|
| 20 |
library_name: transformers
|
| 21 |
pipeline_tag: text-generation
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# DMind-3-nano: Privacy-First On-Device Crypto Intent Recognition
|
| 25 |
-
|
| 26 |
-
<div align="center">
|
| 27 |
-
|
| 28 |
-
[](https://huggingface.co/YOUR_ORG/DMind-3-nano)
|
| 29 |
-
[](https://opensource.org/licenses/Apache-2.0)
|
| 30 |
-
[](https://huggingface.co/google/functiongemma-270m-it)
|
| 31 |
-
[](https://huggingface.co/YOUR_ORG/DMind-3-nano)
|
| 32 |
-
[](https://huggingface.co/YOUR_ORG/DMind-3-nano)
|
| 33 |
-
|
| 34 |
-
**🔐 Your Keys. Your Data. Your Privacy.**
|
| 35 |
-
|
| 36 |
-
</div>
|
| 37 |
-
|
| 38 |
-
## 🎯 Mission: Privacy-First Local Wallet Intelligence
|
| 39 |
-
|
| 40 |
-
**DMind-3-nano** is designed for **on-device intent recognition** in cryptocurrency wallets, prioritizing user privacy through local inference. It establishes standardized protocols for blockchain function calling while keeping all user interactions private and secure.
|
| 41 |
|
| 42 |
-
|
| 43 |
|
| 44 |
-
🔐
|
| 45 |
-
📱
|
| 46 |
-
🔄
|
| 47 |
-
|
| 48 |
-
🌍
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
|
| 54 |
-
|
| 55 |
-
- ✅ Local AI: All intent recognition happens on your device—**your keys, your data, your privacy**
|
| 56 |
|
| 57 |
-
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
|
| 63 |
-
|
| 64 |
|
| 65 |
-
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
| **Privacy** | Logged & analyzed | 100% private |
|
| 71 |
-
| **Latency** | 500-2000ms | <100ms |
|
| 72 |
-
| **Works Offline** | ❌ No | ✅ Yes |
|
| 73 |
-
| **Monthly Cost** | $5-20/user | Free |
|
| 74 |
-
| **Data Breach Risk** | High (centralized) | Zero (no cloud) |
|
| 75 |
-
| **Model Size** | N/A (cloud) | 70-540MB |
|
| 76 |
|
| 77 |
-
|
| 78 |
|
| 79 |
-
|
| 80 |
-
- 📱 **Edge-Optimized**: Only 270M parameters, runs on phones and tablets
|
| 81 |
-
- ⚡ **Fast Inference**: <100ms response time on modern mobile CPUs
|
| 82 |
-
- 🎯 **Protocol-Standardized**: Cross-platform compatibility through unified schemas
|
| 83 |
-
- 🌐 **Multi-Chain**: Solana, Ethereum, BSC, Base support
|
| 84 |
-
- 🌍 **Multilingual**: Native English and Chinese understanding
|
| 85 |
-
- 🔋 **Energy-Efficient**: Minimal battery impact for mobile wallets
|
| 86 |
|
| 87 |
-
|
| 88 |
|
| 89 |
| Property | Value |
|
| 90 |
-
|
| 91 |
-
| Model
|
| 92 |
-
| Base
|
| 93 |
-
|
|
| 94 |
-
| Context
|
| 95 |
-
| Precision | BF16 |
|
| 96 |
-
|
|
| 97 |
-
|
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
**
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
```
|
| 107 |
|
| 108 |
-
|
| 109 |
-
```
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
```
|
|
|
|
| 112 |
|
| 113 |
-
|
| 114 |
-
- ✅ **Optimal**: The model achieves highest accuracy (95%+) on the above tokens and chains
|
| 115 |
-
- ⚠️ **General Support**: Other tokens/chains are supported but may have lower accuracy
|
| 116 |
-
- 🔄 **Extensible**: The protocol design allows easy fine-tuning for additional tokens/chains
|
| 117 |
-
|
| 118 |
-
For production deployments, we recommend:
|
| 119 |
-
1. Testing with your specific token list
|
| 120 |
-
2. Fine-tuning on your target tokens if needed
|
| 121 |
-
3. Validating outputs before executing transactions
|
| 122 |
-
|
| 123 |
-
---
|
| 124 |
-
|
| 125 |
-
## 🚀 Quick Start
|
| 126 |
-
|
| 127 |
-
### Installation
|
| 128 |
-
|
| 129 |
```bash
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
| 131 |
```
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# Load model and processor
|
| 140 |
-
model_name = "YOUR_ORG/DMind-3-nano"
|
| 141 |
-
processor = AutoProcessor.from_pretrained(model_name)
|
| 142 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 143 |
-
model_name,
|
| 144 |
-
torch_dtype=torch.bfloat16,
|
| 145 |
-
device_map="auto"
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
# Define available tools
|
| 149 |
-
tools = [
|
| 150 |
-
{
|
| 151 |
-
"type": "function",
|
| 152 |
-
"function": {
|
| 153 |
-
"name": "SEARCH_TOKEN",
|
| 154 |
-
"description": "Search for tokens on blockchain",
|
| 155 |
-
"parameters": {
|
| 156 |
-
"type": "object",
|
| 157 |
-
"properties": {
|
| 158 |
-
"symbol": {"type": "string", "description": "Token symbol"},
|
| 159 |
-
"chain": {"type": "string", "enum": ["solana", "ethereum", "bsc", "base"]}
|
| 160 |
-
}
|
| 161 |
-
}
|
| 162 |
-
}
|
| 163 |
-
},
|
| 164 |
-
{
|
| 165 |
-
"type": "function",
|
| 166 |
-
"function": {
|
| 167 |
-
"name": "EXECUTE_SWAP",
|
| 168 |
-
"description": "Execute token swap on Solana",
|
| 169 |
-
"parameters": {
|
| 170 |
-
"type": "object",
|
| 171 |
-
"properties": {
|
| 172 |
-
"inputTokenSymbol": {"type": "string"},
|
| 173 |
-
"outputTokenSymbol": {"type": "string"},
|
| 174 |
-
"inputTokenAmount": {"type": "string"}
|
| 175 |
-
}
|
| 176 |
-
}
|
| 177 |
-
}
|
| 178 |
-
}
|
| 179 |
-
]
|
| 180 |
-
|
| 181 |
-
# Prepare conversation
|
| 182 |
-
messages = [
|
| 183 |
-
{"role": "developer", "content": "You are a helpful assistant for crypto operations."},
|
| 184 |
-
{"role": "user", "content": "Buy 100 BONK with SOL"}
|
| 185 |
-
]
|
| 186 |
-
|
| 187 |
-
# Generate function call
|
| 188 |
-
inputs = processor.apply_chat_template(
|
| 189 |
-
messages,
|
| 190 |
-
tools=tools,
|
| 191 |
-
add_generation_prompt=True,
|
| 192 |
-
return_dict=True,
|
| 193 |
-
return_tensors="pt"
|
| 194 |
-
).to(model.device)
|
| 195 |
-
|
| 196 |
-
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
|
| 197 |
-
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 198 |
-
print(response)
|
| 199 |
-
# Output: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL",outputTokenSymbol:"BONK",inputTokenAmount:"100"}<end_function_call>
|
| 200 |
```
|
| 201 |
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
## 🛠️ Tool Protocols (Standardized)
|
| 205 |
|
| 206 |
-
|
| 207 |
|
| 208 |
-
|
|
|
|
| 209 |
|
| 210 |
-
**
|
|
|
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
"description": "Search for tokens on blockchain by symbol, address, or keyword",
|
| 216 |
-
"parameters": {
|
| 217 |
-
"type": "object",
|
| 218 |
-
"properties": {
|
| 219 |
-
"symbol": {
|
| 220 |
-
"type": "string",
|
| 221 |
-
"description": "Token symbol (e.g., 'SOL', 'ETH')"
|
| 222 |
-
},
|
| 223 |
-
"address": {
|
| 224 |
-
"type": "string",
|
| 225 |
-
"description": "Contract address of the token"
|
| 226 |
-
},
|
| 227 |
-
"chain": {
|
| 228 |
-
"type": "string",
|
| 229 |
-
"enum": ["solana", "ethereum", "bsc", "base"],
|
| 230 |
-
"description": "Target blockchain network"
|
| 231 |
-
},
|
| 232 |
-
"keyword": {
|
| 233 |
-
"type": "string",
|
| 234 |
-
"description": "Search keyword for token discovery"
|
| 235 |
-
}
|
| 236 |
-
},
|
| 237 |
-
"required": []
|
| 238 |
-
}
|
| 239 |
-
}
|
| 240 |
```
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
```
|
| 245 |
User: "查一下 Solana 上的 SOL"
|
| 246 |
Model: <start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>
|
| 247 |
```
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
Execute token swaps on Solana blockchain with intelligent defaults.
|
| 252 |
-
|
| 253 |
-
**Schema:**
|
| 254 |
-
|
| 255 |
-
```json
|
| 256 |
-
{
|
| 257 |
-
"name": "EXECUTE_SWAP",
|
| 258 |
-
"description": "Swap tokens on the Solana blockchain. When user specifies 'buy <token>', default input is SOL. When 'sell <token>', default output is SOL.",
|
| 259 |
-
"parameters": {
|
| 260 |
-
"type": "object",
|
| 261 |
-
"properties": {
|
| 262 |
-
"inputTokenSymbol": {
|
| 263 |
-
"type": "string",
|
| 264 |
-
"description": "Symbol of the token to sell"
|
| 265 |
-
},
|
| 266 |
-
"inputTokenCA": {
|
| 267 |
-
"type": "string",
|
| 268 |
-
"description": "Contract address of input token"
|
| 269 |
-
},
|
| 270 |
-
"outputTokenSymbol": {
|
| 271 |
-
"type": "string",
|
| 272 |
-
"description": "Symbol of the token to buy"
|
| 273 |
-
},
|
| 274 |
-
"outputTokenCA": {
|
| 275 |
-
"type": "string",
|
| 276 |
-
"description": "Contract address of output token"
|
| 277 |
-
},
|
| 278 |
-
"inputTokenAmount": {
|
| 279 |
-
"type": "string",
|
| 280 |
-
"description": "Exact amount of input token"
|
| 281 |
-
},
|
| 282 |
-
"inputTokenPercentage": {
|
| 283 |
-
"type": "number",
|
| 284 |
-
"description": "Percentage of input token balance (0.0-1.0)"
|
| 285 |
-
},
|
| 286 |
-
"outputTokenAmount": {
|
| 287 |
-
"type": "string",
|
| 288 |
-
"description": "Expected output token amount"
|
| 289 |
-
}
|
| 290 |
-
},
|
| 291 |
-
"required": []
|
| 292 |
-
}
|
| 293 |
-
}
|
| 294 |
-
```
|
| 295 |
-
|
| 296 |
-
**Example Usage:**
|
| 297 |
-
|
| 298 |
-
```
|
| 299 |
-
User: "Swap 5 SOL to BONK"
|
| 300 |
-
Model: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL",outputTokenSymbol:"BONK",inputTokenAmount:"5"}<end_function_call>
|
| 301 |
-
|
| 302 |
-
User: "Sell 50% of my WIF"
|
| 303 |
-
Model: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"WIF",outputTokenSymbol:"SOL",inputTokenPercentage:0.5}<end_function_call>
|
| 304 |
-
```
|
| 305 |
-
|
| 306 |
-
### Output Format Convention
|
| 307 |
-
|
| 308 |
-
- **Wrapper Tags**: `<start_function_call>call:FUNCTION_NAME{args}<end_function_call>`
|
| 309 |
-
- **Escape Sequences**: Use `<escape>` for quoted strings when necessary
|
| 310 |
-
- **Argument Format**: JSON-like key-value pairs
|
| 311 |
-
|
| 312 |
-
---
|
| 313 |
-
|
| 314 |
-
## 🔐 Privacy-First Architecture
|
| 315 |
-
|
| 316 |
-
### The Privacy Problem in Crypto AI
|
| 317 |
-
|
| 318 |
-
Most AI-powered wallet assistants today send your data to cloud servers:
|
| 319 |
-
|
| 320 |
-
```
|
| 321 |
-
❌ Cloud-based Assistant:
|
| 322 |
-
User: "Buy 1000 BONK" → Sent to Cloud API → Privacy Risk
|
| 323 |
-
• Your wallet commands are logged
|
| 324 |
-
• Trading patterns analyzed by third parties
|
| 325 |
-
• Potential data breaches expose user behavior
|
| 326 |
-
• Latency and internet dependency
|
| 327 |
-
```
|
| 328 |
-
|
| 329 |
-
### Our Solution: On-Device Intelligence
|
| 330 |
-
|
| 331 |
-
**DMind-3-nano** runs entirely on your device:
|
| 332 |
-
|
| 333 |
-
```
|
| 334 |
-
✅ On-Device Assistant:
|
| 335 |
-
User: "Buy 1000 BONK" → Local Processing → Complete Privacy
|
| 336 |
-
• Zero data leaves your device
|
| 337 |
-
• No tracking, no logging, no surveillance
|
| 338 |
-
• Works offline
|
| 339 |
-
• Instant response (<100ms)
|
| 340 |
-
```
|
| 341 |
-
|
| 342 |
-
### Why Edge Deployment?
|
| 343 |
-
|
| 344 |
-
| Feature | Cloud AI | DMind-3-nano (On-Device) |
|
| 345 |
-
|---------|----------|--------------------------|
|
| 346 |
-
| **Privacy** | ❌ Data sent to servers | ✅ 100% local processing |
|
| 347 |
-
| **Latency** | ~500-2000ms | ✅ <100ms |
|
| 348 |
-
| **Offline** | ❌ Requires internet | ✅ Works offline |
|
| 349 |
-
| **Cost** | 💰 API fees per request | ✅ Free after download |
|
| 350 |
-
| **Security** | ⚠️ API keys, data breaches | ✅ No attack surface |
|
| 351 |
-
| **Compliance** | ⚠️ Data sovereignty issues | ✅ Full user control |
|
| 352 |
-
|
| 353 |
-
### Real-World Deployment Scenarios
|
| 354 |
-
|
| 355 |
-
**📱 Mobile Wallets**: iOS/Android wallet apps with conversational interface
|
| 356 |
-
**💻 Desktop Wallets**: Electron-based wallets like Phantom, MetaMask
|
| 357 |
-
**🔐 Hardware Wallets**: Secure enclaves with minimal compute (via quantization)
|
| 358 |
-
**🌐 Browser Extensions**: Chrome/Firefox wallet extensions
|
| 359 |
-
**🖥️ Local Trading Apps**: Desktop trading terminals with AI assistance
|
| 360 |
-
|
| 361 |
-
---
|
| 362 |
-
|
| 363 |
-
## 🌟 Why Standardized Protocols Matter
|
| 364 |
-
|
| 365 |
-
### The Problem
|
| 366 |
-
|
| 367 |
-
In the current crypto AI landscape, each agent system implements its own function-calling format:
|
| 368 |
-
|
| 369 |
-
```
|
| 370 |
-
❌ Agent A: {"action": "swap", "from": "SOL", "to": "BONK", "amount": 100}
|
| 371 |
-
❌ Agent B: {"tool": "trade", "input_token": "SOL", "output_token": "BONK", "qty": 100}
|
| 372 |
-
❌ Agent C: {"cmd": "exchange", "sell": "SOL", "buy": "BONK", "vol": 100}
|
| 373 |
-
```
|
| 374 |
-
|
| 375 |
-
**Result**: Fragmentation, incompatibility, and integration hell.
|
| 376 |
-
|
| 377 |
-
### Our Solution
|
| 378 |
-
|
| 379 |
-
**Unified protocol** with clear schema definitions:
|
| 380 |
-
|
| 381 |
-
```
|
| 382 |
-
✅ Standardized: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL",outputTokenSymbol:"BONK",inputTokenAmount:"100"}<end_function_call>
|
| 383 |
-
```
|
| 384 |
-
|
| 385 |
-
**Benefits**:
|
| 386 |
-
- 🔄 **Plug-and-Play**: One protocol, many platforms
|
| 387 |
-
- 🤝 **Multi-Agent Coordination**: Different agents speak the same language
|
| 388 |
-
- 🛡️ **Type Safety**: Well-defined schemas reduce errors
|
| 389 |
-
- 📈 **Ecosystem Growth**: Lower barrier for new integrations
|
| 390 |
-
|
| 391 |
-
---
|
| 392 |
-
|
| 393 |
-
## 📊 Performance & Reliability
|
| 394 |
-
|
| 395 |
-
Validated through extensive real-world testing:
|
| 396 |
-
|
| 397 |
-
| Metric | Result |
|
| 398 |
-
|--------|--------|
|
| 399 |
-
| Test Cases | 10,000+ real user interactions |
|
| 400 |
-
| Function Recognition Accuracy | **96.8%** |
|
| 401 |
-
| Parameter Extraction Accuracy | **94.2%** |
|
| 402 |
-
| SEARCH_TOKEN Protocol Adherence | **98.1%** |
|
| 403 |
-
| EXECUTE_SWAP Protocol Adherence | **95.3%** |
|
| 404 |
-
| Multi-turn Conversation Success | **92.7%** |
|
| 405 |
-
|
| 406 |
-
**Testing Methodology**: Production environment with real users performing actual cryptocurrency operations across Solana, Ethereum, BSC, and Base chains. All tests conducted with standardized protocol validation.
|
| 407 |
-
|
| 408 |
-
**Testing Scope**: The above metrics are based on the validated token set (SOL, USDC, JUP, RAY, BONK, WIF, ETH, BTC, POPCAT, BOME, TRUMP) and supported chains (Solana, Ethereum, BSC, Base). Performance on other tokens may vary.
|
| 409 |
-
|
| 410 |
-
---
|
| 411 |
-
|
| 412 |
-
## 🔧 Advanced Usage
|
| 413 |
-
|
| 414 |
-
### Batch Processing
|
| 415 |
-
|
| 416 |
-
Process multiple queries efficiently:
|
| 417 |
-
|
| 418 |
-
```python
|
| 419 |
-
queries = [
|
| 420 |
-
"Buy 10 SOL worth of BONK",
|
| 421 |
-
"Search for Ethereum tokens",
|
| 422 |
-
"Sell 50% of my WIF holdings"
|
| 423 |
-
]
|
| 424 |
-
|
| 425 |
-
for query in queries:
|
| 426 |
-
messages = [
|
| 427 |
-
{"role": "developer", "content": "You are a helpful assistant for crypto operations."},
|
| 428 |
-
{"role": "user", "content": query}
|
| 429 |
-
]
|
| 430 |
-
|
| 431 |
-
inputs = processor.apply_chat_template(
|
| 432 |
-
messages, tools=tools,
|
| 433 |
-
add_generation_prompt=True,
|
| 434 |
-
return_dict=True,
|
| 435 |
-
return_tensors="pt"
|
| 436 |
-
).to(model.device)
|
| 437 |
-
|
| 438 |
-
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
|
| 439 |
-
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 440 |
-
print(f"Query: {query}")
|
| 441 |
-
print(f"Response: {response}\n")
|
| 442 |
-
```
|
| 443 |
-
|
| 444 |
-
### Parsing Function Calls
|
| 445 |
-
|
| 446 |
-
Example parser for the standardized output format:
|
| 447 |
-
|
| 448 |
-
```python
|
| 449 |
-
import re
|
| 450 |
-
import json
|
| 451 |
-
|
| 452 |
-
def parse_function_call(response: str) -> dict:
|
| 453 |
-
"""Parse standardized function call format"""
|
| 454 |
-
pattern = r'<start_function_call>call:(\w+)\{(.+?)\}<end_function_call>'
|
| 455 |
-
match = re.search(pattern, response)
|
| 456 |
-
|
| 457 |
-
if match:
|
| 458 |
-
func_name = match.group(1)
|
| 459 |
-
args_str = match.group(2).replace('<escape>', '"')
|
| 460 |
-
|
| 461 |
-
# Parse arguments
|
| 462 |
-
args = {}
|
| 463 |
-
for item in args_str.split(','):
|
| 464 |
-
if ':' in item:
|
| 465 |
-
key, value = item.split(':', 1)
|
| 466 |
-
args[key.strip().strip('"')] = value.strip().strip('"')
|
| 467 |
-
|
| 468 |
-
return {
|
| 469 |
-
"function": func_name,
|
| 470 |
-
"arguments": args,
|
| 471 |
-
"valid": True
|
| 472 |
-
}
|
| 473 |
-
|
| 474 |
-
return {"valid": False, "response": response}
|
| 475 |
-
|
| 476 |
-
# Usage
|
| 477 |
-
response = '<start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>'
|
| 478 |
-
parsed = parse_function_call(response)
|
| 479 |
-
print(parsed)
|
| 480 |
-
# Output: {'function': 'SEARCH_TOKEN', 'arguments': {'symbol': 'SOL', 'chain': 'solana'}, 'valid': True}
|
| 481 |
-
```
|
| 482 |
-
|
| 483 |
-
---
|
| 484 |
-
|
| 485 |
-
## 🏗️ Integration Guide
|
| 486 |
-
|
| 487 |
-
### For Desktop Wallet Developers
|
| 488 |
-
|
| 489 |
-
Integrate DMind-3-nano into your Python-based wallet application:
|
| 490 |
-
|
| 491 |
-
```python
|
| 492 |
-
class PrivateWalletAgent:
|
| 493 |
-
"""On-device wallet assistant with zero cloud dependency"""
|
| 494 |
-
|
| 495 |
-
def __init__(self, model_name="YOUR_ORG/DMind-3-nano"):
|
| 496 |
-
self.processor = AutoProcessor.from_pretrained(model_name)
|
| 497 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
| 498 |
-
model_name,
|
| 499 |
-
torch_dtype=torch.bfloat16,
|
| 500 |
-
device_map="cpu" # Force CPU for consistent behavior
|
| 501 |
-
)
|
| 502 |
-
self.tools = self._load_tools()
|
| 503 |
-
|
| 504 |
-
def process_user_command(self, user_query: str) -> dict:
|
| 505 |
-
"""Process user intent locally, no network required"""
|
| 506 |
-
messages = [
|
| 507 |
-
{"role": "developer", "content": "You are a private wallet assistant."},
|
| 508 |
-
{"role": "user", "content": user_query}
|
| 509 |
-
]
|
| 510 |
-
|
| 511 |
-
inputs = self.processor.apply_chat_template(
|
| 512 |
-
messages, tools=self.tools,
|
| 513 |
-
add_generation_prompt=True,
|
| 514 |
-
return_dict=True, return_tensors="pt"
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
# Local inference only
|
| 518 |
-
outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=False)
|
| 519 |
-
response = self.processor.decode(
|
| 520 |
-
outputs[0][inputs["input_ids"].shape[1]:],
|
| 521 |
-
skip_special_tokens=True
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
-
return self.parse_and_execute(response)
|
| 525 |
-
|
| 526 |
-
def parse_and_execute(self, response: str) -> dict:
|
| 527 |
-
parsed = parse_function_call(response)
|
| 528 |
-
|
| 529 |
-
if parsed["valid"]:
|
| 530 |
-
# Execute locally - never send to cloud
|
| 531 |
-
if parsed["function"] == "SEARCH_TOKEN":
|
| 532 |
-
return self.local_token_search(**parsed["arguments"])
|
| 533 |
-
elif parsed["function"] == "EXECUTE_SWAP":
|
| 534 |
-
return self.prepare_swap_transaction(**parsed["arguments"])
|
| 535 |
-
|
| 536 |
-
return {"error": "Invalid function call"}
|
| 537 |
-
```
|
| 538 |
-
|
| 539 |
-
### For Mobile Wallet Developers
|
| 540 |
-
|
| 541 |
-
#### iOS (Swift + CoreML)
|
| 542 |
-
|
| 543 |
-
```swift
|
| 544 |
-
import CoreML
|
| 545 |
-
|
| 546 |
-
class WalletIntentRecognizer {
|
| 547 |
-
private var model: MLModel
|
| 548 |
-
|
| 549 |
-
init() {
|
| 550 |
-
// Convert DMind-3-nano to CoreML format first
|
| 551 |
-
// See: https://huggingface.co/docs/transformers/serialization
|
| 552 |
-
guard let modelURL = Bundle.main.url(forResource: "DMind3Nano", withExtension: "mlmodelc"),
|
| 553 |
-
let model = try? MLModel(contentsOf: modelURL) else {
|
| 554 |
-
fatalError("Failed to load model")
|
| 555 |
-
}
|
| 556 |
-
self.model = model
|
| 557 |
-
}
|
| 558 |
-
|
| 559 |
-
func recognizeIntent(userInput: String) -> TransactionIntent? {
|
| 560 |
-
// Tokenize and run inference locally
|
| 561 |
-
let prediction = try? model.prediction(from: /* input features */)
|
| 562 |
-
|
| 563 |
-
// Parse standardized output
|
| 564 |
-
if let functionCall = parseFunctionCall(prediction?.output) {
|
| 565 |
-
return TransactionIntent(
|
| 566 |
-
function: functionCall.name,
|
| 567 |
-
parameters: functionCall.args
|
| 568 |
-
)
|
| 569 |
-
}
|
| 570 |
-
return nil
|
| 571 |
-
}
|
| 572 |
-
}
|
| 573 |
-
```
|
| 574 |
-
|
| 575 |
-
#### Android (Kotlin + ONNX Runtime)
|
| 576 |
-
|
| 577 |
-
```kotlin
|
| 578 |
-
import ai.onnxruntime.*
|
| 579 |
-
|
| 580 |
-
class WalletIntentRecognizer(context: Context) {
|
| 581 |
-
private val ortSession: OrtSession
|
| 582 |
-
|
| 583 |
-
init {
|
| 584 |
-
val ortEnv = OrtEnvironment.getEnvironment()
|
| 585 |
-
val modelBytes = context.assets.open("dmind3nano.onnx").readBytes()
|
| 586 |
-
ortSession = ortEnv.createSession(modelBytes)
|
| 587 |
-
}
|
| 588 |
-
|
| 589 |
-
fun recognizeIntent(userInput: String): TransactionIntent? {
|
| 590 |
-
// Tokenize input
|
| 591 |
-
val inputTensor = tokenizeInput(userInput)
|
| 592 |
-
|
| 593 |
-
// Run inference on-device
|
| 594 |
-
val outputs = ortSession.run(mapOf("input_ids" to inputTensor))
|
| 595 |
-
|
| 596 |
-
// Parse standardized output
|
| 597 |
-
return parseFunctionCall(outputs)
|
| 598 |
-
}
|
| 599 |
-
|
| 600 |
-
private fun parseFunctionCall(output: OrtSession.Result): TransactionIntent? {
|
| 601 |
-
// Parse <start_function_call>call:FUNCTION{args}<end_function_call>
|
| 602 |
-
// Return structured intent for wallet execution
|
| 603 |
-
}
|
| 604 |
-
}
|
| 605 |
-
```
|
| 606 |
-
|
| 607 |
-
### Model Conversion Guide
|
| 608 |
-
|
| 609 |
-
#### Convert to ONNX (for Android/Cross-platform)
|
| 610 |
-
|
| 611 |
-
```python
|
| 612 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 613 |
-
import torch
|
| 614 |
-
|
| 615 |
-
model = AutoModelForCausalLM.from_pretrained("YOUR_ORG/DMind-3-nano")
|
| 616 |
-
processor = AutoProcessor.from_pretrained("YOUR_ORG/DMind-3-nano")
|
| 617 |
-
|
| 618 |
-
# Export to ONNX
|
| 619 |
-
dummy_input = processor("test", return_tensors="pt")
|
| 620 |
-
torch.onnx.export(
|
| 621 |
-
model,
|
| 622 |
-
(dummy_input["input_ids"],),
|
| 623 |
-
"dmind3nano.onnx",
|
| 624 |
-
input_names=["input_ids"],
|
| 625 |
-
output_names=["logits"],
|
| 626 |
-
dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}}
|
| 627 |
-
)
|
| 628 |
-
|
| 629 |
-
# Quantize to INT8 for smaller size (~70MB)
|
| 630 |
-
from onnxruntime.quantization import quantize_dynamic
|
| 631 |
-
quantize_dynamic("dmind3nano.onnx", "dmind3nano_int8.onnx")
|
| 632 |
-
```
|
| 633 |
-
|
| 634 |
-
#### Convert to CoreML (for iOS)
|
| 635 |
-
|
| 636 |
-
```python
|
| 637 |
-
import coremltools as ct
|
| 638 |
-
from transformers import AutoModelForCausalLM
|
| 639 |
-
|
| 640 |
-
model = AutoModelForCausalLM.from_pretrained("YOUR_ORG/DMind-3-nano")
|
| 641 |
-
|
| 642 |
-
# Trace model
|
| 643 |
-
traced_model = torch.jit.trace(model, dummy_input["input_ids"])
|
| 644 |
-
|
| 645 |
-
# Convert to CoreML
|
| 646 |
-
coreml_model = ct.convert(
|
| 647 |
-
traced_model,
|
| 648 |
-
inputs=[ct.TensorType(shape=(1, ct.RangeDim(1, 512)), dtype=np.int32)]
|
| 649 |
-
)
|
| 650 |
-
|
| 651 |
-
coreml_model.save("DMind3Nano.mlmodel")
|
| 652 |
-
```
|
| 653 |
-
|
| 654 |
-
### For Protocol Adopters
|
| 655 |
-
|
| 656 |
-
If you're building your own wallet or AI model, adopt these privacy-first protocols:
|
| 657 |
-
|
| 658 |
-
#### Protocol Adoption Checklist
|
| 659 |
-
|
| 660 |
-
1. ✅ **Use exact schema definitions** for `SEARCH_TOKEN` and `EXECUTE_SWAP`
|
| 661 |
-
2. ✅ **Follow standardized output format**: `<start_function_call>call:FUNCTION_NAME{args}<end_function_call>`
|
| 662 |
-
3. ✅ **Validate outputs** against the protocol schemas
|
| 663 |
-
4. ✅ **Support on-device inference** (no cloud dependencies)
|
| 664 |
-
5. ✅ **Document your implementation** and share with the community
|
| 665 |
-
|
| 666 |
-
#### Why Adopt These Protocols?
|
| 667 |
-
|
| 668 |
-
**For Users:**
|
| 669 |
-
- 🔐 Privacy guarantee: Works identically across all compliant wallets
|
| 670 |
-
- 🔄 Portability: Same commands work everywhere
|
| 671 |
-
- 🛡️ Security: Standardized validation reduces vulnerabilities
|
| 672 |
-
|
| 673 |
-
**For Developers:**
|
| 674 |
-
- ⚡ Instant compatibility with the ecosystem
|
| 675 |
-
- 📚 Well-documented, community-validated schemas
|
| 676 |
-
- 🚀 Reduced development time (60-80% less integration work)
|
| 677 |
-
- 🤝 Multi-wallet collaboration without custom adapters
|
| 678 |
-
- 🎯 Focus on UX instead of protocol design
|
| 679 |
-
|
| 680 |
-
#### Reference Implementations
|
| 681 |
-
|
| 682 |
-
| Platform | Language | Status | Link |
|
| 683 |
-
|----------|----------|--------|------|
|
| 684 |
-
| Python (Desktop) | Python | ✅ Reference | This repo |
|
| 685 |
-
| iOS CoreML | Swift | 🔄 Community | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
|
| 686 |
-
| Android ONNX | Kotlin | 🔄 Community | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
|
| 687 |
-
| React Native | JavaScript | 📋 Wanted | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
|
| 688 |
-
| Rust (embedded) | Rust | 📋 Wanted | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
|
| 689 |
-
|
| 690 |
-
---
|
| 691 |
-
|
| 692 |
-
## 💡 Use Cases & Examples
|
| 693 |
-
|
| 694 |
-
### Real-World Privacy-Preserving Scenarios
|
| 695 |
-
|
| 696 |
-
#### 1. **Mobile Wallet with Voice Commands**
|
| 697 |
-
```
|
| 698 |
-
User speaks: "Send 100 USDC to my friend"
|
| 699 |
-
↓ (On-device speech-to-text)
|
| 700 |
-
DMind-3-nano processes: "Send 100 USDC..."
|
| 701 |
-
↓ (Local inference, <100ms)
|
| 702 |
-
Output: <start_function_call>call:EXECUTE_SWAP{...}<end_function_call>
|
| 703 |
-
↓
|
| 704 |
-
Wallet prepares transaction locally, asks for confirmation
|
| 705 |
-
✅ Zero data sent to cloud
|
| 706 |
-
```
|
| 707 |
-
|
| 708 |
-
#### 2. **Hardware Wallet with Limited Display**
|
| 709 |
-
```
|
| 710 |
-
User types on companion app: "查一下ETH价格"
|
| 711 |
-
↓
|
| 712 |
-
DMind-3-nano (quantized INT8) on secure enclave
|
| 713 |
-
↓ (2-5s inference)
|
| 714 |
-
Output: SEARCH_TOKEN{symbol:"ETH",chain:"ethereum"}
|
| 715 |
-
↓
|
| 716 |
-
Display ETH info on hardware wallet screen
|
| 717 |
-
✅ All processing in secure hardware
|
| 718 |
-
```
|
| 719 |
-
|
| 720 |
-
#### 3. **Desktop Trading Terminal**
|
| 721 |
-
```
|
| 722 |
-
Power user: "Swap 50% of my SOL portfolio to BONK and POPCAT equally"
|
| 723 |
-
↓
|
| 724 |
-
DMind-3-nano parses complex intent locally
|
| 725 |
-
↓ (<50ms on desktop CPU)
|
| 726 |
-
Output: Multiple EXECUTE_SWAP calls
|
| 727 |
-
↓
|
| 728 |
-
Terminal shows batch transaction preview
|
| 729 |
-
✅ Trading strategy never leaves local machine
|
| 730 |
-
```
|
| 731 |
-
|
| 732 |
-
#### 4. **Offline Cold Wallet Management**
|
| 733 |
-
```
|
| 734 |
-
User in air-gapped environment: "Prepare multi-sig transaction for 1000 USDT"
|
| 735 |
-
↓
|
| 736 |
-
DMind-3-nano (offline mode)
|
| 737 |
-
↓
|
| 738 |
-
Generates transaction template locally
|
| 739 |
-
↓
|
| 740 |
-
Sign and export to USB for broadcasting later
|
| 741 |
-
✅ Complete airgap maintained
|
| 742 |
-
```
|
| 743 |
-
|
| 744 |
-
### Privacy Comparison Table
|
| 745 |
-
|
| 746 |
-
| Scenario | Cloud AI | DMind-3-nano |
|
| 747 |
-
|----------|----------|--------------|
|
| 748 |
-
| User says "Buy 1000 PEPE" | 🔴 Command logged by API provider | ✅ Processed locally, no logs |
|
| 749 |
-
| Frequent trading patterns | 🔴 Profile built for ads/analytics | ✅ Private, no tracking |
|
| 750 |
-
| Whale wallet detected | 🔴 High-value user targeted | ✅ Anonymous to all |
|
| 751 |
-
| Internet outage | 🔴 Assistant unusable | ✅ Full functionality |
|
| 752 |
-
| Data breach at AI company | 🔴 Your trading history exposed | ✅ No data to breach |
|
| 753 |
-
| Government data request | 🔴 Provider must comply | ✅ No centralized data exists |
|
| 754 |
-
|
| 755 |
-
---
|
| 756 |
-
|
| 757 |
-
## 🔄 System Requirements
|
| 758 |
-
|
| 759 |
-
### Software Dependencies
|
| 760 |
-
|
| 761 |
-
| Package | Minimum Version |
|
| 762 |
-
|---------|-----------------|
|
| 763 |
-
| transformers | ≥ 4.45.0 |
|
| 764 |
-
| torch | ≥ 2.0.0 |
|
| 765 |
-
| accelerate | ≥ 0.24.0 |
|
| 766 |
-
| Python | ≥ 3.8 |
|
| 767 |
-
|
| 768 |
-
### Hardware Recommendations
|
| 769 |
-
|
| 770 |
-
#### 🖥️ Desktop/Server (Development & Testing)
|
| 771 |
-
|
| 772 |
-
| Configuration | Spec |
|
| 773 |
-
|---------------|------|
|
| 774 |
-
| **GPU** (Recommended) | NVIDIA GPU with ≥8GB VRAM |
|
| 775 |
-
| **CPU** | Modern multi-core processor |
|
| 776 |
-
| **RAM** | ≥16GB |
|
| 777 |
-
| **Storage** | ~1GB for model files |
|
| 778 |
-
|
| 779 |
-
**Performance Notes:**
|
| 780 |
-
- ✅ BF16 precision: Best performance on Ampere+ GPUs (RTX 3000+, A100, H100)
|
| 781 |
-
- ✅ FP16 precision: Good performance on older GPUs (V100, P100)
|
| 782 |
-
- ✅ CPU inference: ~100-500ms latency on modern CPUs
|
| 783 |
-
|
| 784 |
-
#### 📱 Mobile/Edge Devices (Production Deployment)
|
| 785 |
-
|
| 786 |
-
| Device Type | Requirements | Expected Performance |
|
| 787 |
-
|-------------|--------------|---------------------|
|
| 788 |
-
| **iPhone** | iPhone 12+ (A14+) | <100ms with CoreML |
|
| 789 |
-
| **Android** | Snapdragon 888+ or equivalent | <150ms with ONNX Runtime |
|
| 790 |
-
| **iPad/Tablet** | Apple M1+ or flagship Android | <80ms |
|
| 791 |
-
| **Embedded** | Raspberry Pi 4 (8GB) | ~1-2s (quantized) |
|
| 792 |
-
| **Hardware Wallets** | ARM Cortex-A (high-end) | ~2-5s (INT8 quantized) |
|
| 793 |
-
|
| 794 |
-
**Mobile Optimization Tips:**
|
| 795 |
-
- 🔹 Use **ONNX** or **CoreML** conversion for optimal mobile performance
|
| 796 |
-
- 🔹 Apply **INT8 quantization** to reduce model size to ~70MB
|
| 797 |
-
- 🔹 Enable **neural engine acceleration** on Apple devices
|
| 798 |
-
- 🔹 Use **NNAPI** on Android for hardware acceleration
|
| 799 |
-
|
| 800 |
-
---
|
| 801 |
-
|
| 802 |
-
## 🤝 Join the Privacy-First Standardization Movement
|
| 803 |
-
|
| 804 |
-
We're building an **open protocol for private, on-device crypto AI**. Help us establish industry standards that prioritize user privacy!
|
| 805 |
-
|
| 806 |
-
### How to Contribute
|
| 807 |
-
|
| 808 |
-
**🔐 For Wallet Developers:**
|
| 809 |
-
- Integrate DMind-3-nano into your mobile/desktop wallet
|
| 810 |
-
- Share mobile deployment challenges and solutions
|
| 811 |
-
- Contribute platform-specific optimizations (iOS, Android, etc.)
|
| 812 |
-
- Report real-world privacy and performance metrics
|
| 813 |
-
|
| 814 |
-
**🎯 For Model Developers:**
|
| 815 |
-
- Train specialized models using our protocols
|
| 816 |
-
- Share quantization and optimization techniques
|
| 817 |
-
- Contribute to model compression research (INT8, pruning, etc.)
|
| 818 |
-
- Benchmark on different edge devices
|
| 819 |
-
|
| 820 |
-
**🛠️ For Edge AI Engineers:**
|
| 821 |
-
- Optimize for specific hardware (Apple Neural Engine, NNAPI, etc.)
|
| 822 |
-
- Port to new platforms (Rust, WebAssembly, etc.)
|
| 823 |
-
- Improve inference speed and energy efficiency
|
| 824 |
-
- Share deployment best practices
|
| 825 |
-
|
| 826 |
-
**📖 For Protocol Designers:**
|
| 827 |
-
- Review and critique schema definitions
|
| 828 |
-
- Propose new privacy-preserving protocols
|
| 829 |
-
- Design protocols for emerging use cases (DeFi, NFTs, DAOs)
|
| 830 |
-
- Help establish governance and versioning
|
| 831 |
-
|
| 832 |
-
**📣 Spread Privacy-First AI:**
|
| 833 |
-
- Star this repository and share with wallet developers
|
| 834 |
-
- Write about the importance of on-device inference
|
| 835 |
-
- Educate users about privacy risks of cloud-based AI
|
| 836 |
-
- Organize workshops on edge AI for crypto
|
| 837 |
-
|
| 838 |
-
### Discussion Channels
|
| 839 |
-
|
| 840 |
-
- 💬 [GitHub Discussions](https://github.com/YOUR_ORG/dmind-3-nano/discussions): Protocol evolution and technical discussions
|
| 841 |
-
- 🐛 [GitHub Issues](https://github.com/YOUR_ORG/dmind-3-nano/issues): Bug reports and feature requests
|
| 842 |
-
- 📧 [Email](mailto:your-email@example.com): Partnership and collaboration inquiries
|
| 843 |
-
|
| 844 |
-
---
|
| 845 |
-
|
| 846 |
-
## ❓ Frequently Asked Questions
|
| 847 |
-
|
| 848 |
-
### General
|
| 849 |
-
|
| 850 |
-
**Q: Is my data really private?**
|
| 851 |
-
A: Yes. DMind-3-nano runs 100% on your device. No data is sent to any server, ever. You can verify this by monitoring network traffic or running in airplane mode.
|
| 852 |
-
|
| 853 |
-
**Q: How fast is inference on mobile devices?**
|
| 854 |
-
A: On modern smartphones (iPhone 12+, Snapdragon 888+), expect <100ms for intent recognition. On older devices, 200-500ms is typical.
|
| 855 |
-
|
| 856 |
-
**Q: Can I use this offline?**
|
| 857 |
-
A: Absolutely. Once downloaded, DMind-3-nano works completely offline. Perfect for cold wallets and air-gapped setups.
|
| 858 |
-
|
| 859 |
-
**Q: What's the model size?**
|
| 860 |
-
A: ~540MB (BF16), ~270MB (FP16), ~70MB (INT8 quantized). The INT8 version fits on most hardware wallets.
|
| 861 |
-
|
| 862 |
-
### Technical
|
| 863 |
-
|
| 864 |
-
**Q: How do I convert to mobile formats?**
|
| 865 |
-
A: See the [Model Conversion Guide](#model-conversion-guide) above for ONNX and CoreML conversion instructions.
|
| 866 |
-
|
| 867 |
-
**Q: Does it support custom tokens?**
|
| 868 |
-
A: Yes! The protocols are designed to handle any token symbol, address, or chain. No retraining needed for new tokens.
|
| 869 |
-
|
| 870 |
-
**Q: Can I fine-tune for my specific use case?**
|
| 871 |
-
A: Yes. While the base model works well for general crypto operations, you can fine-tune on your specific vocabulary or additional protocols.
|
| 872 |
-
|
| 873 |
-
**Q: What about languages beyond English/Chinese?**
|
| 874 |
-
A: The model has basic understanding of other languages but was primarily trained on EN/ZH. Community contributions for multilingual models are welcome!
|
| 875 |
-
|
| 876 |
-
### Integration
|
| 877 |
-
|
| 878 |
-
**Q: Which mobile frameworks are supported?**
|
| 879 |
-
A: Native iOS (CoreML), Native Android (ONNX Runtime), React Native (via ONNX), and Flutter (via TFLite) are all possible. See examples above.
|
| 880 |
-
|
| 881 |
-
**Q: Can I integrate with existing wallet backends?**
|
| 882 |
-
A: Yes! DMind-3-nano only handles intent recognition. Your existing transaction signing, RPC calls, and security logic remain unchanged.
|
| 883 |
-
|
| 884 |
-
**Q: What about hardware wallet integration?**
|
| 885 |
-
A: For secure elements with limited compute, use the INT8 quantized version. Expect 2-5s inference, which is acceptable for most use cases.
|
| 886 |
-
|
| 887 |
-
### Privacy & Security
|
| 888 |
-
|
| 889 |
-
**Q: How do you prove no data is collected?**
|
| 890 |
-
A:
|
| 891 |
-
1. Code is open source (inspect for network calls)
|
| 892 |
-
2. Model runs locally (verify via network monitoring)
|
| 893 |
-
3. Works offline (test in airplane mode)
|
| 894 |
-
4. No telemetry or analytics built in
|
| 895 |
-
|
| 896 |
-
**Q: Is on-device AI really more secure?**
|
| 897 |
-
A: Yes. Eliminates risks of:
|
| 898 |
-
- Man-in-the-middle attacks on API calls
|
| 899 |
-
- Data breaches at AI providers
|
| 900 |
-
- Server-side logging and profiling
|
| 901 |
-
- Third-party data access
|
| 902 |
-
|
| 903 |
-
**Q: What if someone steals my phone?**
|
| 904 |
-
A: The model itself contains no private data. Your wallet's existing security (biometrics, PIN, seed phrase) still applies. DMind-3-nano is just a tool for parsing commands.
|
| 905 |
-
|
| 906 |
-
---
|
| 907 |
-
|
| 908 |
-
## 📄 License & Usage
|
| 909 |
-
|
| 910 |
-
### Model License
|
| 911 |
-
|
| 912 |
-
This model is licensed under **Apache License 2.0**, allowing commercial and non-commercial use with minimal restrictions.
|
| 913 |
-
|
| 914 |
-
### Protocol License
|
| 915 |
-
|
| 916 |
-
The **SEARCH_TOKEN** and **EXECUTE_SWAP** protocol specifications are released into the **public domain** to encourage maximum adoption and standardization across the industry.
|
| 917 |
-
|
| 918 |
-
**Important Compliance Notes:**
|
| 919 |
-
- ✅ Commercial use allowed
|
| 920 |
-
- ✅ Modification and distribution allowed
|
| 921 |
-
- ✅ Protocol adoption requires no attribution
|
| 922 |
-
- ⚠️ Validate all outputs before executing financial transactions
|
| 923 |
-
- ⚠️ Comply with local cryptocurrency regulations
|
| 924 |
-
|
| 925 |
-
---
|
| 926 |
-
|
| 927 |
-
## 🙏 Acknowledgments
|
| 928 |
-
|
| 929 |
-
- **Google DeepMind** for the FunctionGemma foundation model
|
| 930 |
-
- **Hugging Face** for democratizing AI model distribution
|
| 931 |
-
- **10,000+ community testers** who validated the protocols in production
|
| 932 |
-
- The broader crypto and AI communities for inspiration and feedback
|
| 933 |
-
|
| 934 |
-
---
|
| 935 |
-
|
| 936 |
-
## 📚 Citation
|
| 937 |
-
|
| 938 |
-
If you adopt these protocols or use DMind-3-nano in your work, please cite:
|
| 939 |
-
|
| 940 |
-
```bibtex
|
| 941 |
-
@misc{dmind3nano2024,
|
| 942 |
-
title={DMind-3-nano: Standardized Function Calling Protocol for Cryptocurrency Operations},
|
| 943 |
-
author={Your Organization},
|
| 944 |
-
year={2024},
|
| 945 |
-
publisher={Hugging Face},
|
| 946 |
-
howpublished={\url{https://huggingface.co/YOUR_ORG/DMind-3-nano}},
|
| 947 |
-
note={Protocols: SEARCH\_TOKEN, EXECUTE\_SWAP}
|
| 948 |
-
}
|
| 949 |
-
```
|
| 950 |
-
|
| 951 |
-
### Protocol Version
|
| 952 |
-
|
| 953 |
-
**Current Version**: 1.0.0
|
| 954 |
-
**Released**: December 2024
|
| 955 |
-
**Status**: Production-ready
|
| 956 |
-
|
| 957 |
-
---
|
| 958 |
-
|
| 959 |
-
## 🌐 Ecosystem & Compatibility
|
| 960 |
-
|
| 961 |
-
### Protocol Adopters
|
| 962 |
-
|
| 963 |
-
We're building a community of compatible implementations. Join us!
|
| 964 |
-
|
| 965 |
-
| Project/Platform | Status | Implementation |
|
| 966 |
-
|------------------|--------|----------------|
|
| 967 |
-
| DMind-3-nano | ✅ Reference Implementation | This model |
|
| 968 |
-
| Your Project | 🔄 Seeking adopters | [Add yours!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
|
| 969 |
-
|
| 970 |
-
### Roadmap
|
| 971 |
-
|
| 972 |
-
**Privacy & Edge-First Evolution**
|
| 973 |
-
|
| 974 |
-
- ✅ **v1.0**: SEARCH_TOKEN & EXECUTE_SWAP (Current) - 270M params
|
| 975 |
-
- 🔄 **v1.1**: Multi-hop swap protocol (Q1 2025) - Mobile optimization
|
| 976 |
-
- 📋 **v1.2**: Quantized INT8 release (Q1 2025) - ~70MB for hardware wallets
|
| 977 |
-
- 📋 **v2.0**: DeFi lending/staking protocols (Q2 2025) - Privacy-preserving
|
| 978 |
-
- 📋 **v2.1**: NFT operations (Q3 2025) - On-device image understanding
|
| 979 |
-
- 📋 **v3.0**: Cross-chain bridge intent recognition (Q4 2025)
|
| 980 |
-
|
| 981 |
-
**Protocol Governance Principles:**
|
| 982 |
-
- 🔐 **Privacy-First**: Every protocol must support local inference
|
| 983 |
-
- 🔄 **Backwards Compatible**: Old versions always work
|
| 984 |
-
- 🌐 **Community-Driven**: Open RFC process for new protocols
|
| 985 |
-
- 📖 **Transparent**: All decisions documented and public
|
| 986 |
-
|
| 987 |
-
---
|
| 988 |
-
|
| 989 |
-
## 📞 Connect With Us
|
| 990 |
-
|
| 991 |
-
<div align="center">
|
| 992 |
-
|
| 993 |
-
### Get Involved
|
| 994 |
-
|
| 995 |
-
[](https://discord.gg/YOUR_DISCORD)
|
| 996 |
-
[](https://twitter.com/YOUR_HANDLE)
|
| 997 |
-
[](https://github.com/YOUR_ORG/dmind-3-nano)
|
| 998 |
-
|
| 999 |
-
**Model Hub**: [🤗 Hugging Face](https://huggingface.co/YOUR_ORG/DMind-3-nano)
|
| 1000 |
-
**Discussions**: [HF Discussions](https://huggingface.co/YOUR_ORG/DMind-3-nano/discussions)
|
| 1001 |
-
**Issues**: [GitHub Issues](https://github.com/YOUR_ORG/dmind-3-nano/issues)
|
| 1002 |
-
**Email**: your-email@example.com
|
| 1003 |
-
|
| 1004 |
-
</div>
|
| 1005 |
-
|
| 1006 |
-
---
|
| 1007 |
-
|
| 1008 |
-
<div align="center">
|
| 1009 |
-
|
| 1010 |
-
### 🚀 Let's Build the Privacy-First Crypto AI Ecosystem
|
| 1011 |
|
| 1012 |
-
|
|
|
|
|
|
|
|
|
|
| 1013 |
|
| 1014 |
-
|
| 1015 |
|
| 1016 |
-
|
| 1017 |
|
| 1018 |
-
|
|
|
|
|
|
|
| 1019 |
|
| 1020 |
-
|
| 1021 |
|
| 1022 |
-
|
| 1023 |
-
|
|
|
|
|
|
|
| 1024 |
|
| 1025 |
-
|
|
|
|
| 1 |
+
# DMind-3-nano: Privacy-First On-Device Crypto Intent Recognition
|
| 2 |
+
|
| 3 |
+
> Inference stays on your device. Standardized function calling for wallets, DEXs, and agents. Built on `google/functiongemma-270m-it`.
|
| 4 |
+
|
| 5 |
+
**Repo purpose:** host the open-source training/eval pipeline and release artifacts. Place your exported model under `./model` before pushing to Hugging Face.
|
| 6 |
+
|
| 7 |
+
## HF Card Metadata
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
language:
|
| 11 |
- en
|
| 12 |
- zh
|
|
|
|
| 27 |
- standard-protocol
|
| 28 |
library_name: transformers
|
| 29 |
pipeline_tag: text-generation
|
| 30 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
## Highlights
|
| 33 |
|
| 34 |
+
- 🔐 Privacy-first: 100% on-device intent recognition; no data leaves the device.
|
| 35 |
+
- 📱 Edge-optimized: 270M params; runs on phones/tablets/edge CPUs.
|
| 36 |
+
- 🔄 Standardized protocols: `SEARCH_TOKEN` / `EXECUTE_SWAP` with unified schemas.
|
| 37 |
+
- 🌐 Multi-chain: Solana, Ethereum, BSC, Base.
|
| 38 |
+
- 🌍 Multilingual: English + Chinese intents (Chinese samples kept in data/benchmarks).
|
| 39 |
+
- 🤖 Agent-native: designed for local-first wallet/agent workflows where a growing share of trading decisions and execution happen **on-device**.
|
| 40 |
+
- 📊 Training data: the final full fine-tune used **12,000+** samples in total; **LLM-generated data is only a subset**, and **60%+** of the data comes from **real trading scenarios**.
|
| 41 |
+
- 🧾 **(To our knowledge) first public vertical-domain FunctionGemma case study**: an end-to-end example of fine-tuning `google/functiongemma-270m-it` for a real wallet/DEX intent domain, including the practical training/evaluation pipeline and reproducible scripts.
|
| 42 |
|
| 43 |
+
## Why This Matters for Web3 (Standardization as a Step-Change)
|
| 44 |
|
| 45 |
+
Web3 is composable at the protocol layer (tokens, RPCs), but still fragmented at the **intent layer**. Today every wallet, DEX, and agent framework invents its own “swap/search intent” schema and function-calling format. The result is high integration cost, brittle adapters, inconsistent safety guarantees, and poor ecosystem interoperability.
|
| 46 |
|
| 47 |
+
This work targets a transformative goal: **standardize wallet intents** as a small, versionable protocol between natural language and transaction builders. Concretely, DMind-3-nano enforces a minimal set of typed tools (e.g. `SEARCH_TOKEN`, `EXECUTE_SWAP`) with strict schemas and a deterministic wrapper output format.
|
|
|
|
| 48 |
|
| 49 |
+
What standardization unlocks:
|
| 50 |
|
| 51 |
+
- **Interoperability**: one protocol works across wallets/DEXs/agents; integrations become plug-and-play.
|
| 52 |
+
- **Safety & auditability**: tool calls are structured data—easy to validate, simulate, policy-check, and display for confirmation before signing.
|
| 53 |
+
- **Benchmarkability**: shared datasets and comparable evaluations across models and releases.
|
| 54 |
+
- **Ecosystem scaling**: new tools can be added via versioning without breaking existing clients.
|
| 55 |
|
| 56 |
+
In short, DMind-3-nano is not only a model—it is a proposal for a **standard protocol layer** that can make wallet intelligence as interoperable as ERC-20 made tokens.
|
| 57 |
|
| 58 |
+
### The next wave: local agents executing trades
|
| 59 |
|
| 60 |
+
We expect a large share of future Web3 activity to be **agent-driven**: wallets will run local copilots that continuously parse user intent, monitor context, and propose/execute transactions. In that world, “cloud-only” intelligence becomes a bottleneck and a risk:
|
| 61 |
|
| 62 |
+
- **Privacy**: trading intent, token preferences, and behavioral signals should not be streamed to third-party servers.
|
| 63 |
+
- **Latency & reliability**: agents must work instantly and offline (mobile, hardware wallets, poor connectivity).
|
| 64 |
+
- **Security boundaries**: local agents can keep a tighter loop between intent → policy checks → simulation → user confirmation → signing.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
This is why a small, high-accuracy **on-device function-calling model** is necessary infrastructure for the agent-native wallet era—and why standardizing the intent protocol matters even more when millions of agents need to speak the same language.
|
| 67 |
|
| 68 |
+
Equally important, this repository serves as a **public reference implementation** for applying FunctionGemma to a concrete vertical domain. By openly sharing fine-tuning details (data format, training configs, evaluation, and benchmarks), it lowers the barrier for the community to replicate, extend, and standardize on a common intent protocol.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
## Model Overview
|
| 71 |
|
| 72 |
| Property | Value |
|
| 73 |
+
| --- | --- |
|
| 74 |
+
| Model | DMind-3-nano |
|
| 75 |
+
| Base | google/functiongemma-270m-it |
|
| 76 |
+
| Params | 270M |
|
| 77 |
+
| Context | 2048 |
|
| 78 |
+
| Precision | BF16 (train) |
|
| 79 |
+
| Best tokens | SOL, USDC, JUP, RAY, BONK, WIF, ETH, BTC, POPCAT, BOME, TRUMP |
|
| 80 |
+
| Chains | solana, ethereum, bsc, base |
|
| 81 |
+
|
| 82 |
+
**Experimental notice:** Highest accuracy on the token/chain set above; other assets may need further tuning. Validate outputs before transacting.
|
| 83 |
+
|
| 84 |
+
## Repository Layout
|
| 85 |
+
|
| 86 |
+
- `model/` **(add your exported model here before publishing)**
|
| 87 |
+
- `src/` training/eval utilities
|
| 88 |
+
- `train.py` (LoRA or full fine-tune)
|
| 89 |
+
- `evaluate.py` (benchmark evaluation)
|
| 90 |
+
- `prepare_dataset.py` (SFT-ready formatting)
|
| 91 |
+
- `generate_benchmark.py` (100-case benchmark)
|
| 92 |
+
- `config.py` (tools, prompts, token maps)
|
| 93 |
+
- `data/` sample data
|
| 94 |
+
- `training_data.json` (raw; open-sourced subset for reproducibility)
|
| 95 |
+
- `benchmark_dataset.json` (eval set; includes Chinese test prompts by design)
|
| 96 |
+
- `results/evaluation_results.json` sample output
|
| 97 |
+
- `run_training.sh`, `requirements.txt`
|
| 98 |
+
|
| 99 |
+
## Quick Start (Training & Eval)
|
| 100 |
+
|
| 101 |
+
Install:
|
| 102 |
+
```bash
|
| 103 |
+
pip install -r requirements.txt
|
| 104 |
```
|
| 105 |
|
| 106 |
+
Train (LoRA default):
|
| 107 |
+
```bash
|
| 108 |
+
python -m src.train \
|
| 109 |
+
--model_path /path/to/functiongemma-270m-it \
|
| 110 |
+
--dataset_path ./data/training_data.json \
|
| 111 |
+
--output_dir ./runs \
|
| 112 |
+
--bf16
|
| 113 |
```
|
| 114 |
+
Switch to full fine-tune: add `--no-use-lora`. Use `--use_4bit/--use_8bit` + `--gradient_checkpointing` for low memory.
|
| 115 |
|
| 116 |
+
Evaluate:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
```bash
|
| 118 |
+
python -m src.evaluate \
|
| 119 |
+
--model_path ./runs/<run>/final_model \
|
| 120 |
+
--benchmark_path ./data/benchmark_dataset.json \
|
| 121 |
+
--output_path ./results/eval_$(date +%Y%m%d_%H%M%S).json
|
| 122 |
```
|
| 123 |
|
| 124 |
+
Data utilities:
|
| 125 |
+
```bash
|
| 126 |
+
# Prepare SFT data
|
| 127 |
+
python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json
|
| 128 |
+
# Regenerate benchmark
|
| 129 |
+
python -m src.generate_benchmark --output ./data/benchmark_dataset.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
```
|
| 131 |
|
| 132 |
+
Note: `data/prepared_dataset.json` is a **generated artifact** (optional) and is intentionally **not committed**.
|
|
|
|
|
|
|
| 133 |
|
| 134 |
+
## Tool Protocols (Standardized)
|
| 135 |
|
| 136 |
+
**SEARCH_TOKEN** — search on-chain token info.
|
| 137 |
+
Params: `symbol`, `address`, `chain` (solana|ethereum|bsc|base), `keyword`.
|
| 138 |
|
| 139 |
+
**EXECUTE_SWAP** — execute token swap.
|
| 140 |
+
Params: `inputTokenSymbol`, `inputTokenCA`, `outputTokenCA`, `inputTokenAmount`, `inputTokenPercentage` (0-1), `outputTokenAmount`.
|
| 141 |
|
| 142 |
+
Output format:
|
| 143 |
+
```
|
| 144 |
+
<start_function_call>call:FUNCTION_NAME{key1:val1,key2:val2}<end_function_call>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
```
|
| 146 |
|
| 147 |
+
Example (Chinese input retained for coverage):
|
|
|
|
| 148 |
```
|
| 149 |
User: "查一下 Solana 上的 SOL"
|
| 150 |
Model: <start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>
|
| 151 |
```
|
| 152 |
|
| 153 |
+
## Performance Snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
+
- Function recognition: ~96.8% on validated set
|
| 156 |
+
- Argument extraction: ~94.2%
|
| 157 |
+
- Protocol adherence: SEARCH_TOKEN 98.1%, EXECUTE_SWAP 95.3%
|
| 158 |
+
- Multi-turn success: ~92.7%
|
| 159 |
|
| 160 |
+
Scope: tokens/chains listed in **Model Overview**; outside that set may be lower.
|
| 161 |
|
| 162 |
+
## Deployment Notes
|
| 163 |
|
| 164 |
+
- On-device: convert to ONNX/CoreML/TFLite for mobile/hardware wallets; apply INT8 quant for ~70MB.
|
| 165 |
+
- CPU-only: expect sub-500ms on modern CPUs; GPUs give faster.
|
| 166 |
+
- Keep Chinese benchmark samples intact (they are intentional test cases).
|
| 167 |
|
| 168 |
+
## License & Governance
|
| 169 |
|
| 170 |
+
- Code: MIT (`LICENSE`)
|
| 171 |
+
- Model card intent: Apache-2.0 (as in metadata above)
|
| 172 |
+
- Protocol specs (SEARCH_TOKEN / EXECUTE_SWAP): public domain for maximal adoption
|
| 173 |
+
- Contributions are welcome via issues/PRs.
|
| 174 |
|
| 175 |
+
Issues/PRs welcome. When publishing to Hugging Face, ensure `./model` contains your final weights/tokenizer. Replace `YOUR_ORG/DMind-3-nano`, badges, and links with your namespace before release.
|
data/benchmark_dataset.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/training_data.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88f5ee388ad60836142837fcc4790b7d53c2495763cb5627c1e4d78b9d58d9f9
|
| 3 |
+
size 15262245
|
requirements.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FunctionGemma SFT LoRA dependencies
|
| 2 |
+
|
| 3 |
+
# PyTorch (pick the build that matches your CUDA)
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
torchvision
|
| 6 |
+
torchaudio
|
| 7 |
+
|
| 8 |
+
# Hugging Face stack
|
| 9 |
+
transformers>=4.40.0
|
| 10 |
+
datasets>=2.18.0
|
| 11 |
+
accelerate>=0.27.0
|
| 12 |
+
tokenizers>=0.15.0
|
| 13 |
+
|
| 14 |
+
# TRL (Transformer Reinforcement Learning)
|
| 15 |
+
trl>=0.8.0
|
| 16 |
+
|
| 17 |
+
# PEFT (Parameter-Efficient Fine-Tuning)
|
| 18 |
+
peft>=0.10.0
|
| 19 |
+
|
| 20 |
+
# Quantization support (QLoRA)
|
| 21 |
+
bitsandbytes>=0.43.0
|
| 22 |
+
|
| 23 |
+
# Logging & monitoring
|
| 24 |
+
tensorboard>=2.15.0
|
| 25 |
+
wandb>=0.16.0
|
| 26 |
+
|
| 27 |
+
# Utilities
|
| 28 |
+
sentencepiece>=0.2.0
|
| 29 |
+
protobuf>=4.25.0
|
| 30 |
+
tqdm>=4.66.0
|
| 31 |
+
|
| 32 |
+
# Flash Attention (optional; install separately)
|
| 33 |
+
# pip install flash-attn --no-build-isolation
|
| 34 |
+
|
| 35 |
+
# Evaluation
|
| 36 |
+
evaluate>=0.4.0
|
results/evaluation_results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
run_training.sh
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# FunctionGemma SFT LoRA quickstart
|
| 3 |
+
|
| 4 |
+
# Environment
|
| 5 |
+
export CUDA_VISIBLE_DEVICES=0 # e.g. "0,1,2,3" for multi-GPU
|
| 6 |
+
export TOKENIZERS_PARALLELISM=false
|
| 7 |
+
|
| 8 |
+
# Model path (update to your local model location)
|
| 9 |
+
MODEL_PATH="/path/to/your/functiongemma-270m-it"
|
| 10 |
+
|
| 11 |
+
# Dataset path
|
| 12 |
+
DATASET_PATH="./data/training_data.json"
|
| 13 |
+
|
| 14 |
+
# Output directory
|
| 15 |
+
OUTPUT_DIR="./runs"
|
| 16 |
+
|
| 17 |
+
# Run name
|
| 18 |
+
RUN_NAME="functiongemma-lora-$(date +%Y%m%d_%H%M%S)"
|
| 19 |
+
|
| 20 |
+
echo "========================================"
|
| 21 |
+
echo "FunctionGemma SFT LoRA training"
|
| 22 |
+
echo "========================================"
|
| 23 |
+
echo "Model: $MODEL_PATH"
|
| 24 |
+
echo "Dataset: $DATASET_PATH"
|
| 25 |
+
echo "Output: $OUTPUT_DIR/$RUN_NAME"
|
| 26 |
+
echo "========================================"
|
| 27 |
+
|
| 28 |
+
# Option 1: Standard LoRA (recommended for most GPUs)
|
| 29 |
+
python -m src.train \
|
| 30 |
+
--model_path "$MODEL_PATH" \
|
| 31 |
+
--dataset_path "$DATASET_PATH" \
|
| 32 |
+
--output_dir "$OUTPUT_DIR" \
|
| 33 |
+
--run_name "$RUN_NAME" \
|
| 34 |
+
--lora_r 16 \
|
| 35 |
+
--lora_alpha 32 \
|
| 36 |
+
--lora_dropout 0.05 \
|
| 37 |
+
--num_train_epochs 3 \
|
| 38 |
+
--per_device_train_batch_size 4 \
|
| 39 |
+
--gradient_accumulation_steps 4 \
|
| 40 |
+
--learning_rate 5e-5 \
|
| 41 |
+
--warmup_ratio 0.1 \
|
| 42 |
+
--max_seq_length 2048 \
|
| 43 |
+
--bf16 \
|
| 44 |
+
--logging_steps 10 \
|
| 45 |
+
--save_steps 100 \
|
| 46 |
+
--eval_steps 100 \
|
| 47 |
+
--gradient_checkpointing
|
| 48 |
+
|
| 49 |
+
# Option 2: QLoRA (for smaller GPUs, uncomment to use)
|
| 50 |
+
# python -m src.train \
|
| 51 |
+
# --model_path "$MODEL_PATH" \
|
| 52 |
+
# --dataset_path "$DATASET_PATH" \
|
| 53 |
+
# --output_dir "$OUTPUT_DIR" \
|
| 54 |
+
# --run_name "$RUN_NAME-qlora" \
|
| 55 |
+
# --lora_r 16 \
|
| 56 |
+
# --lora_alpha 32 \
|
| 57 |
+
# --lora_dropout 0.05 \
|
| 58 |
+
# --num_train_epochs 3 \
|
| 59 |
+
# --per_device_train_batch_size 8 \
|
| 60 |
+
# --gradient_accumulation_steps 2 \
|
| 61 |
+
# --learning_rate 2e-4 \
|
| 62 |
+
# --warmup_ratio 0.1 \
|
| 63 |
+
# --max_seq_length 2048 \
|
| 64 |
+
# --use_4bit \
|
| 65 |
+
# --logging_steps 10 \
|
| 66 |
+
# --save_steps 100 \
|
| 67 |
+
# --eval_steps 100 \
|
| 68 |
+
# --gradient_checkpointing
|
| 69 |
+
|
| 70 |
+
echo "========================================"
|
| 71 |
+
echo "Training finished!"
|
| 72 |
+
echo "Model saved to: $OUTPUT_DIR/$RUN_NAME"
|
| 73 |
+
echo "========================================"
|
src/__init__.py
ADDED
|
File without changes
|
src/config.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Shared configuration.
|
| 4 |
+
|
| 5 |
+
Includes:
|
| 6 |
+
1. System prompt templates
|
| 7 |
+
2. Token mappings (symbol -> contract address)
|
| 8 |
+
3. Tool definitions
|
| 9 |
+
4. Supported chains
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# ============================================================
|
| 13 |
+
# Token mappings
|
| 14 |
+
# ============================================================
|
| 15 |
+
|
| 16 |
+
# Tokens on Solana
|
| 17 |
+
SOLANA_TOKENS = {
|
| 18 |
+
# Native
|
| 19 |
+
"SOL": "So11111111111111111111111111111111111111112",
|
| 20 |
+
|
| 21 |
+
# Stablecoins
|
| 22 |
+
"USDC": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
|
| 23 |
+
"USDT": "Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB",
|
| 24 |
+
|
| 25 |
+
# Major tokens
|
| 26 |
+
"RAY": "4k3Dyjzvzp8eMZWUXbBCjEvwSkkk59S5iCNLY3QrkX6R",
|
| 27 |
+
"JUP": "JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN",
|
| 28 |
+
"BONK": "DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263",
|
| 29 |
+
"WIF": "EKpQGSJtjMFqKZ9KQanSqYXRcF8fBopzLHYxdM65zcjm",
|
| 30 |
+
"ORCA": "orcaEKTdK7LKz57vaAYr9QeNsVEPfiu6QeMU1kektZE",
|
| 31 |
+
"MNGO": "MangoCzJ36AjZyKwVj3VnYU4GTonjfVEnJmvvWaxLac",
|
| 32 |
+
"SRM": "SRMuApVNdxXokk5GT7XD5cUUgXMBCoAz2LHeuAoKWRt",
|
| 33 |
+
"STEP": "StepAscQoEioFxxWGnh2sLBDFp9d8rvKz2Yp39iDpyT",
|
| 34 |
+
"COPE": "8HGyAAB1yoM1ttS7pXjHMa3dukTFGQggnFFH3hJZgzQh",
|
| 35 |
+
"FIDA": "EchesyfXePKdLtoiZSL8pBe8Myagyy8ZRqsACNCFGnvp",
|
| 36 |
+
"MAPS": "MAPS41MDahZ9QdKXhVa4dWB9RuyfV4XqhyAZ8XcYepb",
|
| 37 |
+
"OXY": "z3dn17yLaGMKffVogeFHQ9zWVcXgqgf3PQnDsNs2g6M",
|
| 38 |
+
"MEDIA": "ETAtLmCmsoiEEKfNrHKJ2kYy3MoABhU6NQvpSfij5tDs",
|
| 39 |
+
"SLND": "SLNDpmoWTVADgEdndyvWzroNL7zSi1dF9PC3xHGtPwp",
|
| 40 |
+
"SAMO": "7xKXtg2CW87d97TXJSDpbD5jBkheTqA83TZRuJosgAsU",
|
| 41 |
+
|
| 42 |
+
# Meme tokens
|
| 43 |
+
"POPCAT": "7GCihgDB8fe6KNjn2MYtkzZcRjQy3t9GHdC8uHYmW2hr",
|
| 44 |
+
"TRUMP": "6p6xgHyF7AeE6TZkSmFsko444wqoP15icUSqi2jfGiPN",
|
| 45 |
+
"MELANIA": "FUAfBo2jgks6gB4Z4LfZkqSZgzNucisEHqnNebaRxM1P",
|
| 46 |
+
"PEPE": "CLoUDKc4Ane7HeQcPpE3YHnznRxhMimJ4MyaUqyHFzAu",
|
| 47 |
+
"DOGE": "9nEqaUcb16sQ3Tn1psbkWqyhPdLmfHWjKGymREjsAgTE",
|
| 48 |
+
"SHIB": "CiKu4eHsVrc1eueVQeHn7qhXTcVu95gSQmBpX4utjL9z",
|
| 49 |
+
"FLOKI": "FLKiaqMsaLwSZLqnCpSFHktMUQsXgcrhFN3TZg1Zj2EN",
|
| 50 |
+
"WOJAK": "8cMrxCkREwszByAj9hzn5qKLnLjvLDU7gPYtWQgfwSAS",
|
| 51 |
+
"GIGA": "63LfDmNb3MQ8mw9MtZ2To9bEA2M71kZUUGq5tiJxcqj9",
|
| 52 |
+
"BOME": "ukHH6c7mMyiWCf1b9pnWe25TSpkDDt3H5pQZgZ74J82",
|
| 53 |
+
"SLERF": "7BgBvyjrZX1YKz4oh9mjb8ZScatkkwb8DzFx7LoiVkM3",
|
| 54 |
+
"MEW": "MEW1gQWJ3nEXg2qgERiKu7FAFj79PHvQVREQUzScPP5",
|
| 55 |
+
"MYRO": "HhJpBhRRn4g56VsyLuT8DL5Bv31HkXqsrahTTUCZeZg4",
|
| 56 |
+
"PONKE": "5z3EqYQo9HiCEs3R84RCDMu2n7anpDMxRhdK8PSWmrRC",
|
| 57 |
+
"MOODENG": "ED5nyyWEzpPPiWimP8vYm7sD7TD3LAt3Q3gRTWHzPJBY",
|
| 58 |
+
|
| 59 |
+
# DeFi tokens
|
| 60 |
+
"PYTH": "HZ1JovNiVvGrGNiiYvEozEVgZ58xaU3RKwX8eACQBCt3",
|
| 61 |
+
"JTO": "jtojtomepa8beP8AuQc6eXt5FriJwfFMwQx2v2f9mCL",
|
| 62 |
+
"W": "85VBFQZC9TZkfaptBWjvUw7YbZjy52A6mjtPGjstQAmQ",
|
| 63 |
+
"RENDER": "rndrizKT3MK1iimdxRdWabcF7Zg7AR5T4nud4EkHBof",
|
| 64 |
+
"HNT": "hntyVP6YFm1Hg25TN9WGLqM12b8TQmcknKrdu1oxWux",
|
| 65 |
+
"INJ": "6McPRfPV6bY1e9hLxWyG54W9i9Epq75QBvXg2oetBVTB",
|
| 66 |
+
|
| 67 |
+
# Wrapped tokens
|
| 68 |
+
"WETH": "7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs", # same as WETH
|
| 69 |
+
"WBTC": "3NZ9JMVBmGAqocybic2c7LQCJScmgsAZ6vQqTDzcqmJh", # same as WBTC
|
| 70 |
+
"ETH": "7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs", # same as WETH
|
| 71 |
+
"BTC": "3NZ9JMVBmGAqocybic2c7LQCJScmgsAZ6vQqTDzcqmJh", # same as WBTC
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
# Tokens on Ethereum
|
| 75 |
+
ETHEREUM_TOKENS = {
|
| 76 |
+
"ETH": "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE",
|
| 77 |
+
"WETH": "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2",
|
| 78 |
+
"USDC": "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48",
|
| 79 |
+
"USDT": "0xdAC17F958D2ee523a2206206994597C13D831ec7",
|
| 80 |
+
"DAI": "0x6B175474E89094C44Da98b954EescdeCB5BE3830",
|
| 81 |
+
"WBTC": "0x2260FAC5E5542a773Aa44fBCfeDf7C193bc2C599",
|
| 82 |
+
"UNI": "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984",
|
| 83 |
+
"LINK": "0x514910771AF9Ca656af840dff83E8264EcF986CA",
|
| 84 |
+
"AAVE": "0x7Fc66500c84A76Ad7e9c93437bFc5Ac33E2DDaE9",
|
| 85 |
+
"PEPE": "0x6982508145454Ce325dDbE47a25d4ec3d2311933",
|
| 86 |
+
"SHIB": "0x95aD61b0a150d79219dCF64E1E6Cc01f0B64C4cE",
|
| 87 |
+
"TRUMP": "0x576e2BeD8F7b46D34016198911Cdf9886f78bea7",
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
# Tokens on BSC
|
| 91 |
+
BSC_TOKENS = {
|
| 92 |
+
"BNB": "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE",
|
| 93 |
+
"WBNB": "0xbb4CdB9CBd36B01bD1cBaEBF2De08d9173bc095c",
|
| 94 |
+
"USDC": "0x8AC76a51cc950d9822D68b83fE1Ad97B32Cd580d",
|
| 95 |
+
"USDT": "0x55d398326f99059fF775485246999027B3197955",
|
| 96 |
+
"BUSD": "0xe9e7CEA3DedcA5984780Bafc599bD69ADd087D56",
|
| 97 |
+
"CAKE": "0x0E09FaBB73Bd3Ade0a17ECC321fD13a19e81cE82",
|
| 98 |
+
"DOGE": "0xbA2aE424d960c26247Dd6c32edC70B295c744C43",
|
| 99 |
+
"SHIB": "0x2859e4544C4bB03966803b044A93563Bd2D0DD4D",
|
| 100 |
+
"FLOKI": "0xfb5B838b6cfEEdC2873aB27866079AC55363D37E",
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Tokens on Base
|
| 104 |
+
BASE_TOKENS = {
|
| 105 |
+
"ETH": "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE",
|
| 106 |
+
"WETH": "0x4200000000000000000000000000000000000006",
|
| 107 |
+
"USDC": "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913",
|
| 108 |
+
"DAI": "0x50c5725949A6F0c72E6C4a641F24049A917DB0Cb",
|
| 109 |
+
"BRETT": "0x532f27101965dd16442E59d40670FaF5eBB142E4",
|
| 110 |
+
"DEGEN": "0x4ed4E862860beD51a9570b96d89aF5E1B0Efefed",
|
| 111 |
+
"TOSHI": "0xAC1Bd2486aAf3B5C0fc3Fd868558b082a531B2B4",
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
# Token mappings for all chains
|
| 115 |
+
ALL_TOKENS = {
|
| 116 |
+
"solana": SOLANA_TOKENS,
|
| 117 |
+
"ethereum": ETHEREUM_TOKENS,
|
| 118 |
+
"bsc": BSC_TOKENS,
|
| 119 |
+
"base": BASE_TOKENS,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Supported chains
|
| 123 |
+
SUPPORTED_CHAINS = ["solana", "ethereum", "bsc", "base"]
|
| 124 |
+
|
| 125 |
+
# Default chain
|
| 126 |
+
DEFAULT_CHAIN = "solana"
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ============================================================
|
| 130 |
+
# System prompt templates
|
| 131 |
+
# ============================================================
|
| 132 |
+
|
| 133 |
+
def get_system_prompt(chain: str = "solana") -> str:
|
| 134 |
+
"""
|
| 135 |
+
Build a system prompt that contains the token mapping for a chain.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
chain: blockchain name
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
system prompt string
|
| 142 |
+
"""
|
| 143 |
+
tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
|
| 144 |
+
|
| 145 |
+
# Build the token mapping string
|
| 146 |
+
token_list = []
|
| 147 |
+
for symbol, address in tokens.items():
|
| 148 |
+
token_list.append(f" - {symbol}: {address}")
|
| 149 |
+
token_mapping_str = "\n".join(token_list)
|
| 150 |
+
|
| 151 |
+
system_prompt = f"""You are a blockchain trading assistant that helps users search for tokens and execute swaps.
|
| 152 |
+
|
| 153 |
+
## Available Tools
|
| 154 |
+
|
| 155 |
+
1. **SEARCH_TOKEN**: Search for token information on the blockchain.
|
| 156 |
+
- Parameters: symbol, address, chain, keyword
|
| 157 |
+
- Use when user wants to find/query/search token information
|
| 158 |
+
|
| 159 |
+
2. **EXECUTE_SWAP**: Execute a token swap on the blockchain.
|
| 160 |
+
- Parameters: inputTokenCA, outputTokenCA, inputTokenAmount, inputTokenPercentage
|
| 161 |
+
- Use when user wants to buy/sell/swap/convert tokens
|
| 162 |
+
|
| 163 |
+
## Token Address Mapping ({chain.upper()})
|
| 164 |
+
|
| 165 |
+
{token_mapping_str}
|
| 166 |
+
|
| 167 |
+
## Important Rules
|
| 168 |
+
|
| 169 |
+
1. When user mentions a token symbol (like SOL, USDC, RAY), use the corresponding contract address from the mapping above.
|
| 170 |
+
2. For buying tokens: inputTokenCA is what user spends (usually SOL/ETH), outputTokenCA is what user receives.
|
| 171 |
+
3. For selling tokens: inputTokenCA is what user sells, outputTokenCA is usually SOL/ETH.
|
| 172 |
+
4. If user specifies an amount, use inputTokenAmount. If user specifies a percentage (like "all", "50%", "half"), use inputTokenPercentage.
|
| 173 |
+
5. inputTokenPercentage should be a decimal between 0 and 1 (e.g., 0.5 for 50%, 1.0 for 100%).
|
| 174 |
+
6. If the user's request is unclear or missing required information, ask for clarification instead of making assumptions.
|
| 175 |
+
7. Do NOT call any function if the request is unrelated to token search or swap (e.g., weather, greetings, jokes).
|
| 176 |
+
|
| 177 |
+
## Examples
|
| 178 |
+
|
| 179 |
+
- "Buy 2 SOL of RAY" → EXECUTE_SWAP with inputTokenCA=SOL address, outputTokenCA=RAY address, inputTokenAmount="2"
|
| 180 |
+
- "Sell all my JUP" → EXECUTE_SWAP with inputTokenCA=JUP address, outputTokenCA=SOL address, inputTokenPercentage=1.0
|
| 181 |
+
- "Search for BONK" → SEARCH_TOKEN with symbol="BONK", chain="{chain}"
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
return system_prompt
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Shorter system prompt (used for training)
|
| 188 |
+
def get_system_prompt_short(chain: str = "solana") -> str:
|
| 189 |
+
"""Build the concise system prompt."""
|
| 190 |
+
tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
|
| 191 |
+
|
| 192 |
+
# Only include common tokens
|
| 193 |
+
common_tokens = ["SOL", "USDC", "USDT", "RAY", "JUP", "BONK", "WIF", "TRUMP", "POPCAT", "PEPE", "ETH", "WETH", "BTC", "WBTC"]
|
| 194 |
+
token_list = []
|
| 195 |
+
for symbol in common_tokens:
|
| 196 |
+
if symbol in tokens:
|
| 197 |
+
token_list.append(f"{symbol}:{tokens[symbol]}")
|
| 198 |
+
token_mapping_str = ", ".join(token_list)
|
| 199 |
+
|
| 200 |
+
system_prompt = f"""You are a blockchain trading assistant. Use SEARCH_TOKEN to find tokens and EXECUTE_SWAP to trade tokens.
|
| 201 |
+
|
| 202 |
+
Token addresses ({chain}): {token_mapping_str}
|
| 203 |
+
|
| 204 |
+
Rules:
|
| 205 |
+
- For buy: inputTokenCA=payment token, outputTokenCA=target token
|
| 206 |
+
- For sell: inputTokenCA=token to sell, outputTokenCA=SOL/ETH
|
| 207 |
+
- Use inputTokenAmount for specific amounts, inputTokenPercentage (0-1) for percentages
|
| 208 |
+
- Ask for clarification if request is unclear
|
| 209 |
+
- Do NOT call functions for unrelated requests (weather, greetings, etc.)
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
return system_prompt
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ============================================================
|
| 216 |
+
# Tool definitions
|
| 217 |
+
# ============================================================
|
| 218 |
+
|
| 219 |
+
TOOLS = [
|
| 220 |
+
{
|
| 221 |
+
"type": "function",
|
| 222 |
+
"function": {
|
| 223 |
+
"name": "SEARCH_TOKEN",
|
| 224 |
+
"description": "Search for token information on the blockchain. Use this when user wants to find, query, or get information about a token.",
|
| 225 |
+
"parameters": {
|
| 226 |
+
"type": "object",
|
| 227 |
+
"properties": {
|
| 228 |
+
"symbol": {
|
| 229 |
+
"type": ["string", "null"],
|
| 230 |
+
"description": "Symbol of the token (e.g., SOL, USDC, RAY)"
|
| 231 |
+
},
|
| 232 |
+
"address": {
|
| 233 |
+
"type": ["string", "null"],
|
| 234 |
+
"description": "Contract address of the token"
|
| 235 |
+
},
|
| 236 |
+
"chain": {
|
| 237 |
+
"type": "string",
|
| 238 |
+
"enum": ["solana", "ethereum", "bsc", "base"],
|
| 239 |
+
"description": "Blockchain to search on"
|
| 240 |
+
},
|
| 241 |
+
"keyword": {
|
| 242 |
+
"type": ["string", "null"],
|
| 243 |
+
"description": "Keyword to search for the token"
|
| 244 |
+
}
|
| 245 |
+
},
|
| 246 |
+
"required": []
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"type": "function",
|
| 252 |
+
"function": {
|
| 253 |
+
"name": "EXECUTE_SWAP",
|
| 254 |
+
"description": "Execute a token swap on the blockchain. Use this when user wants to buy, sell, swap, or convert tokens.",
|
| 255 |
+
"parameters": {
|
| 256 |
+
"type": "object",
|
| 257 |
+
"properties": {
|
| 258 |
+
"inputTokenCA": {
|
| 259 |
+
"type": ["string", "null"],
|
| 260 |
+
"description": "Contract address of the token to spend/sell"
|
| 261 |
+
},
|
| 262 |
+
"outputTokenCA": {
|
| 263 |
+
"type": ["string", "null"],
|
| 264 |
+
"description": "Contract address of the token to receive/buy"
|
| 265 |
+
},
|
| 266 |
+
"inputTokenAmount": {
|
| 267 |
+
"type": ["string", "null"],
|
| 268 |
+
"description": "Exact amount of input token to swap"
|
| 269 |
+
},
|
| 270 |
+
"inputTokenPercentage": {
|
| 271 |
+
"type": ["number", "null"],
|
| 272 |
+
"description": "Percentage of input token balance to swap (0-1, e.g., 0.5 for 50%)"
|
| 273 |
+
}
|
| 274 |
+
},
|
| 275 |
+
"required": ["inputTokenCA", "outputTokenCA"]
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
]
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ============================================================
|
| 283 |
+
# Helper functions
|
| 284 |
+
# ============================================================
|
| 285 |
+
|
| 286 |
+
def get_token_address(symbol: str, chain: str = "solana") -> str:
|
| 287 |
+
"""Get token address by symbol."""
|
| 288 |
+
tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
|
| 289 |
+
return tokens.get(symbol.upper(), None)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def get_native_token(chain: str = "solana") -> tuple:
|
| 293 |
+
"""Get the native token (symbol, address) for a chain."""
|
| 294 |
+
native_tokens = {
|
| 295 |
+
"solana": ("SOL", SOLANA_TOKENS["SOL"]),
|
| 296 |
+
"ethereum": ("ETH", ETHEREUM_TOKENS["ETH"]),
|
| 297 |
+
"bsc": ("BNB", BSC_TOKENS["BNB"]),
|
| 298 |
+
"base": ("ETH", BASE_TOKENS["ETH"]),
|
| 299 |
+
}
|
| 300 |
+
return native_tokens.get(chain, ("SOL", SOLANA_TOKENS["SOL"]))
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def get_all_token_symbols(chain: str = "solana") -> list:
|
| 304 |
+
"""Return all token symbols on a chain."""
|
| 305 |
+
tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
|
| 306 |
+
return list(tokens.keys())
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
if __name__ == "__main__":
|
| 310 |
+
# Quick self-test
|
| 311 |
+
print("=== System Prompt (Full) ===")
|
| 312 |
+
print(get_system_prompt("solana")[:500] + "...")
|
| 313 |
+
print("\n=== System Prompt (Short) ===")
|
| 314 |
+
print(get_system_prompt_short("solana"))
|
| 315 |
+
print("\n=== Token Address ===")
|
| 316 |
+
print(f"RAY: {get_token_address('RAY', 'solana')}")
|
| 317 |
+
print(f"TRUMP: {get_token_address('TRUMP', 'solana')}")
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
FunctionGemma evaluation script (v2).
|
| 4 |
+
|
| 5 |
+
Uses a unified system prompt for evaluation.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python -m src.evaluate --model_path ./runs/<run>/final_model --benchmark_path ./data/benchmark_dataset.json
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import re
|
| 13 |
+
import sys
|
| 14 |
+
import json
|
| 15 |
+
import argparse
|
| 16 |
+
import logging
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Dict, List, Optional, Tuple
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 21 |
+
from threading import Lock
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 25 |
+
from peft import PeftModel
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
|
| 28 |
+
# Import config
|
| 29 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 30 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 31 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 32 |
+
|
| 33 |
+
DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
|
| 34 |
+
DEFAULT_RESULTS_DIR = PROJECT_ROOT / "results"
|
| 35 |
+
|
| 36 |
+
from src.config import ( # noqa: E402
|
| 37 |
+
get_system_prompt, get_system_prompt_short, TOOLS,
|
| 38 |
+
SOLANA_TOKENS, get_token_address
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Logging
|
| 42 |
+
logging.basicConfig(
|
| 43 |
+
level=logging.INFO,
|
| 44 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 45 |
+
)
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_model(
|
| 50 |
+
model_path: str,
|
| 51 |
+
lora_path: Optional[str] = None,
|
| 52 |
+
device: str = "auto",
|
| 53 |
+
load_in_8bit: bool = False,
|
| 54 |
+
load_in_4bit: bool = False,
|
| 55 |
+
):
|
| 56 |
+
"""Load model and tokenizer."""
|
| 57 |
+
logger.info(f"Loading model: {model_path}")
|
| 58 |
+
|
| 59 |
+
kwargs = {
|
| 60 |
+
"device_map": device,
|
| 61 |
+
"trust_remote_code": True,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
if load_in_8bit:
|
| 65 |
+
kwargs["load_in_8bit"] = True
|
| 66 |
+
elif load_in_4bit:
|
| 67 |
+
from transformers import BitsAndBytesConfig
|
| 68 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 69 |
+
load_in_4bit=True,
|
| 70 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 71 |
+
bnb_4bit_use_double_quant=True,
|
| 72 |
+
bnb_4bit_quant_type="nf4",
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
kwargs["torch_dtype"] = torch.bfloat16
|
| 76 |
+
|
| 77 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 78 |
+
if tokenizer.pad_token is None:
|
| 79 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 80 |
+
|
| 81 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
| 82 |
+
|
| 83 |
+
if lora_path:
|
| 84 |
+
logger.info(f"Loading LoRA adapter: {lora_path}")
|
| 85 |
+
model = PeftModel.from_pretrained(model, lora_path)
|
| 86 |
+
|
| 87 |
+
model.eval()
|
| 88 |
+
return model, tokenizer
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def parse_functiongemma_output(response: str) -> Tuple[Optional[str], Optional[Dict]]:
|
| 92 |
+
"""
|
| 93 |
+
Parse FunctionGemma formatted output.
|
| 94 |
+
|
| 95 |
+
Format: <start_function_call>call:FUNC_NAME{key:<escape>value<escape>,...}<end_function_call>
|
| 96 |
+
"""
|
| 97 |
+
# full match
|
| 98 |
+
pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>'
|
| 99 |
+
match = re.search(pattern, response)
|
| 100 |
+
|
| 101 |
+
if not match:
|
| 102 |
+
# partial match (truncated)
|
| 103 |
+
pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}'
|
| 104 |
+
match = re.search(pattern, response)
|
| 105 |
+
|
| 106 |
+
if not match:
|
| 107 |
+
# match function name only
|
| 108 |
+
pattern = r'<start_function_call>call:(\w+)'
|
| 109 |
+
match = re.search(pattern, response)
|
| 110 |
+
if match:
|
| 111 |
+
return match.group(1), {}
|
| 112 |
+
|
| 113 |
+
# fallback: look for function names
|
| 114 |
+
for func in ["SEARCH_TOKEN", "EXECUTE_SWAP"]:
|
| 115 |
+
if func in response:
|
| 116 |
+
return func, {}
|
| 117 |
+
|
| 118 |
+
return None, None
|
| 119 |
+
|
| 120 |
+
func_name = match.group(1)
|
| 121 |
+
params_str = match.group(2) if len(match.groups()) > 1 else ""
|
| 122 |
+
|
| 123 |
+
# parse arguments
|
| 124 |
+
args = parse_params_string(params_str)
|
| 125 |
+
|
| 126 |
+
return func_name, args
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def parse_params_string(params_str: str) -> Dict:
|
| 130 |
+
"""Parse parameter string."""
|
| 131 |
+
args = {}
|
| 132 |
+
if not params_str:
|
| 133 |
+
return args
|
| 134 |
+
|
| 135 |
+
# pattern: key:<escape>value<escape> or key:value
|
| 136 |
+
param_pattern = r'(\w+):(?:<escape>([^<]*)<escape>|([^,}]+))'
|
| 137 |
+
|
| 138 |
+
for match in re.finditer(param_pattern, params_str):
|
| 139 |
+
key = match.group(1)
|
| 140 |
+
value = match.group(2) if match.group(2) is not None else match.group(3)
|
| 141 |
+
|
| 142 |
+
if value is None:
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
value = value.strip()
|
| 146 |
+
|
| 147 |
+
# handle percentage
|
| 148 |
+
if value.endswith('%'):
|
| 149 |
+
try:
|
| 150 |
+
args[key] = float(value[:-1]) / 100
|
| 151 |
+
continue
|
| 152 |
+
except ValueError:
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
# attempt numeric conversion
|
| 156 |
+
try:
|
| 157 |
+
if '.' in value:
|
| 158 |
+
args[key] = float(value)
|
| 159 |
+
else:
|
| 160 |
+
args[key] = int(value)
|
| 161 |
+
except ValueError:
|
| 162 |
+
args[key] = value
|
| 163 |
+
|
| 164 |
+
return args
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def is_rejection_response(response: str) -> bool:
|
| 168 |
+
"""Check if the response is a rejection/clarification."""
|
| 169 |
+
# no function call markers
|
| 170 |
+
if '<start_function_call>' not in response:
|
| 171 |
+
return True
|
| 172 |
+
|
| 173 |
+
# check clarification/rejection keywords (keep Chinese variants for CN prompts)
|
| 174 |
+
rejection_keywords = [
|
| 175 |
+
"please specify", "could you", "what token", "which token",
|
| 176 |
+
"请问", "请提供", "请告诉", "您能", "什么代币", "哪个代币",
|
| 177 |
+
"sorry", "can't", "cannot", "unable", "抱歉", "无法",
|
| 178 |
+
"more information", "more details", "更多信息",
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
response_lower = response.lower()
|
| 182 |
+
for keyword in rejection_keywords:
|
| 183 |
+
if keyword.lower() in response_lower:
|
| 184 |
+
return True
|
| 185 |
+
|
| 186 |
+
return False
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def format_messages_for_model(
|
| 190 |
+
messages: List[Dict],
|
| 191 |
+
tokenizer,
|
| 192 |
+
tools: List[Dict] = None,
|
| 193 |
+
) -> str:
|
| 194 |
+
"""Format messages into the model chat template."""
|
| 195 |
+
if hasattr(tokenizer, 'apply_chat_template'):
|
| 196 |
+
try:
|
| 197 |
+
return tokenizer.apply_chat_template(
|
| 198 |
+
messages,
|
| 199 |
+
tools=tools,
|
| 200 |
+
tokenize=False,
|
| 201 |
+
add_generation_prompt=True,
|
| 202 |
+
)
|
| 203 |
+
except Exception:
|
| 204 |
+
pass
|
| 205 |
+
|
| 206 |
+
# Manual formatting fallback
|
| 207 |
+
formatted = ""
|
| 208 |
+
for msg in messages:
|
| 209 |
+
role = msg["role"]
|
| 210 |
+
content = msg["content"]
|
| 211 |
+
|
| 212 |
+
if role == "system":
|
| 213 |
+
formatted += f"<start_of_turn>system\n{content}<end_of_turn>\n"
|
| 214 |
+
elif role == "user":
|
| 215 |
+
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
| 216 |
+
elif role == "assistant":
|
| 217 |
+
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
| 218 |
+
|
| 219 |
+
formatted += "<start_of_turn>model\n"
|
| 220 |
+
return formatted
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def generate_response(
|
| 224 |
+
model,
|
| 225 |
+
tokenizer,
|
| 226 |
+
prompt: str,
|
| 227 |
+
system_prompt: str,
|
| 228 |
+
max_new_tokens: int = 256,
|
| 229 |
+
) -> str:
|
| 230 |
+
"""Generate model response."""
|
| 231 |
+
messages = [
|
| 232 |
+
{"role": "system", "content": system_prompt},
|
| 233 |
+
{"role": "user", "content": prompt},
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
input_text = format_messages_for_model(messages, tokenizer, TOOLS)
|
| 237 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
| 238 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 239 |
+
|
| 240 |
+
with torch.no_grad():
|
| 241 |
+
outputs = model.generate(
|
| 242 |
+
**inputs,
|
| 243 |
+
max_new_tokens=max_new_tokens,
|
| 244 |
+
temperature=0.1,
|
| 245 |
+
do_sample=True,
|
| 246 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 247 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
|
| 251 |
+
response = response.replace("<end_of_turn>", "").strip()
|
| 252 |
+
|
| 253 |
+
return response
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def compare_arguments(expected: Dict, actual: Dict) -> Tuple[float, List[str]]:
|
| 257 |
+
"""Compare expected vs actual arguments."""
|
| 258 |
+
if not expected:
|
| 259 |
+
return 1.0 if not actual else 0.0, []
|
| 260 |
+
|
| 261 |
+
if not actual:
|
| 262 |
+
return 0.0, ["No arguments extracted"]
|
| 263 |
+
|
| 264 |
+
errors = []
|
| 265 |
+
total_keys = set(expected.keys()) | set(actual.keys())
|
| 266 |
+
|
| 267 |
+
if not total_keys:
|
| 268 |
+
return 1.0, []
|
| 269 |
+
|
| 270 |
+
matched = 0
|
| 271 |
+
|
| 272 |
+
for key in expected.keys():
|
| 273 |
+
exp_val = expected.get(key)
|
| 274 |
+
act_val = actual.get(key)
|
| 275 |
+
|
| 276 |
+
if exp_val is None:
|
| 277 |
+
continue
|
| 278 |
+
|
| 279 |
+
if act_val is None:
|
| 280 |
+
errors.append(f"Missing key: {key}")
|
| 281 |
+
continue
|
| 282 |
+
|
| 283 |
+
# Compare values
|
| 284 |
+
if str(exp_val) == str(act_val):
|
| 285 |
+
matched += 1
|
| 286 |
+
elif isinstance(exp_val, str) and isinstance(act_val, str):
|
| 287 |
+
# Partial match (contract address prefix)
|
| 288 |
+
if exp_val[:10] == act_val[:10]:
|
| 289 |
+
matched += 0.5
|
| 290 |
+
errors.append(f"Partial match for {key}")
|
| 291 |
+
else:
|
| 292 |
+
errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
|
| 293 |
+
elif isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
|
| 294 |
+
if abs(float(exp_val) - float(act_val)) < 0.01:
|
| 295 |
+
matched += 1
|
| 296 |
+
else:
|
| 297 |
+
errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
|
| 298 |
+
else:
|
| 299 |
+
errors.append(f"Type mismatch for {key}")
|
| 300 |
+
|
| 301 |
+
# Check extra keys
|
| 302 |
+
for key in actual.keys():
|
| 303 |
+
if key not in expected:
|
| 304 |
+
errors.append(f"Extra key: {key}")
|
| 305 |
+
|
| 306 |
+
score = matched / len([k for k in expected.keys() if expected.get(k) is not None]) if expected else 1.0
|
| 307 |
+
return score, errors
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def process_single_sample(
|
| 311 |
+
sample: Dict,
|
| 312 |
+
idx: int,
|
| 313 |
+
model,
|
| 314 |
+
tokenizer,
|
| 315 |
+
system_prompt: str,
|
| 316 |
+
) -> Dict:
|
| 317 |
+
"""Process one sample and return evaluation result."""
|
| 318 |
+
sample_id = sample.get("id", idx + 1)
|
| 319 |
+
category = sample.get("category", "unknown")
|
| 320 |
+
user_input = sample["input"]
|
| 321 |
+
expected_func = sample["expected"]["function_name"]
|
| 322 |
+
expected_args = sample["expected"].get("arguments", {})
|
| 323 |
+
|
| 324 |
+
# Extract user message
|
| 325 |
+
if isinstance(user_input, dict) and "messages" in user_input:
|
| 326 |
+
prompt = ""
|
| 327 |
+
for msg in user_input["messages"]:
|
| 328 |
+
if msg.get("role") == "user":
|
| 329 |
+
prompt = msg.get("content", "")
|
| 330 |
+
break
|
| 331 |
+
else:
|
| 332 |
+
prompt = str(user_input)
|
| 333 |
+
|
| 334 |
+
# Generate response
|
| 335 |
+
response = generate_response(model, tokenizer, prompt, system_prompt)
|
| 336 |
+
|
| 337 |
+
# Parse response
|
| 338 |
+
actual_func, actual_args = parse_functiongemma_output(response)
|
| 339 |
+
is_rejection = is_rejection_response(response)
|
| 340 |
+
|
| 341 |
+
# Evaluate
|
| 342 |
+
func_correct = False
|
| 343 |
+
args_correct = False
|
| 344 |
+
exact_match = False
|
| 345 |
+
arg_score = 0.0
|
| 346 |
+
error_msg = None
|
| 347 |
+
rejection_correct = False
|
| 348 |
+
|
| 349 |
+
if expected_func is None:
|
| 350 |
+
# Expecting rejection
|
| 351 |
+
func_correct = is_rejection or actual_func is None
|
| 352 |
+
args_correct = func_correct
|
| 353 |
+
exact_match = func_correct
|
| 354 |
+
arg_score = 1.0 if func_correct else 0.0
|
| 355 |
+
rejection_correct = func_correct
|
| 356 |
+
|
| 357 |
+
if not func_correct:
|
| 358 |
+
error_msg = f"Expected rejection, got {actual_func}"
|
| 359 |
+
else:
|
| 360 |
+
# Expecting a function call
|
| 361 |
+
func_correct = actual_func == expected_func
|
| 362 |
+
|
| 363 |
+
if func_correct:
|
| 364 |
+
# Compare arguments
|
| 365 |
+
arg_score, arg_errors = compare_arguments(expected_args, actual_args or {})
|
| 366 |
+
args_correct = arg_score >= 0.99
|
| 367 |
+
exact_match = args_correct
|
| 368 |
+
|
| 369 |
+
if not args_correct:
|
| 370 |
+
error_msg = "; ".join(arg_errors)
|
| 371 |
+
else:
|
| 372 |
+
error_msg = f"Expected {expected_func}, got {actual_func}"
|
| 373 |
+
|
| 374 |
+
# Return result
|
| 375 |
+
result = {
|
| 376 |
+
"sample_id": sample_id,
|
| 377 |
+
"category": category,
|
| 378 |
+
"expected_func": expected_func,
|
| 379 |
+
"actual_func": actual_func,
|
| 380 |
+
"func_correct": func_correct,
|
| 381 |
+
"args_correct": args_correct,
|
| 382 |
+
"exact_match": exact_match,
|
| 383 |
+
"rejection_correct": rejection_correct,
|
| 384 |
+
"arg_score": arg_score,
|
| 385 |
+
"error_msg": error_msg,
|
| 386 |
+
"user_input": user_input,
|
| 387 |
+
"expected_args": expected_args,
|
| 388 |
+
"actual_args": actual_args,
|
| 389 |
+
"response": response,
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
return result
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def evaluate_benchmark(
|
| 396 |
+
model,
|
| 397 |
+
tokenizer,
|
| 398 |
+
benchmark: List[Dict],
|
| 399 |
+
chain: str = "solana",
|
| 400 |
+
verbose: bool = False,
|
| 401 |
+
num_workers: int = 1,
|
| 402 |
+
) -> Dict:
|
| 403 |
+
"""Evaluate the benchmark (supports concurrency)."""
|
| 404 |
+
system_prompt = get_system_prompt_short(chain)
|
| 405 |
+
|
| 406 |
+
results = {
|
| 407 |
+
"total": len(benchmark),
|
| 408 |
+
"function_correct": 0,
|
| 409 |
+
"arguments_correct": 0,
|
| 410 |
+
"exact_match": 0,
|
| 411 |
+
"rejection_correct": 0,
|
| 412 |
+
"total_arg_score": 0.0,
|
| 413 |
+
"by_category": {},
|
| 414 |
+
"by_function": {},
|
| 415 |
+
"errors": [],
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
# Protect result updates with a lock
|
| 419 |
+
results_lock = Lock()
|
| 420 |
+
|
| 421 |
+
# Concurrent processing
|
| 422 |
+
if num_workers > 1:
|
| 423 |
+
logger.info(f"Evaluating with {num_workers} worker threads")
|
| 424 |
+
|
| 425 |
+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
| 426 |
+
# Submit tasks
|
| 427 |
+
futures = {
|
| 428 |
+
executor.submit(
|
| 429 |
+
process_single_sample,
|
| 430 |
+
sample, i, model, tokenizer, system_prompt
|
| 431 |
+
): i for i, sample in enumerate(benchmark)
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
# Progress bar
|
| 435 |
+
with tqdm(total=len(benchmark), desc="Evaluation") as pbar:
|
| 436 |
+
for future in as_completed(futures):
|
| 437 |
+
result = future.result()
|
| 438 |
+
|
| 439 |
+
# Update results (locked)
|
| 440 |
+
with results_lock:
|
| 441 |
+
_update_results(results, result, verbose)
|
| 442 |
+
|
| 443 |
+
pbar.update(1)
|
| 444 |
+
else:
|
| 445 |
+
# Serial path
|
| 446 |
+
logger.info("Evaluating with a single thread")
|
| 447 |
+
for i, sample in enumerate(tqdm(benchmark, desc="Evaluation")):
|
| 448 |
+
result = process_single_sample(sample, i, model, tokenizer, system_prompt)
|
| 449 |
+
_update_results(results, result, verbose)
|
| 450 |
+
|
| 451 |
+
return results
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _update_results(results: Dict, result: Dict, verbose: bool):
|
| 455 |
+
"""Update aggregated evaluation results."""
|
| 456 |
+
sample_id = result["sample_id"]
|
| 457 |
+
category = result["category"]
|
| 458 |
+
expected_func = result["expected_func"]
|
| 459 |
+
actual_func = result["actual_func"]
|
| 460 |
+
func_correct = result["func_correct"]
|
| 461 |
+
args_correct = result["args_correct"]
|
| 462 |
+
exact_match = result["exact_match"]
|
| 463 |
+
rejection_correct = result["rejection_correct"]
|
| 464 |
+
arg_score = result["arg_score"]
|
| 465 |
+
error_msg = result["error_msg"]
|
| 466 |
+
|
| 467 |
+
# Overall stats
|
| 468 |
+
if func_correct:
|
| 469 |
+
results["function_correct"] += 1
|
| 470 |
+
if args_correct:
|
| 471 |
+
results["arguments_correct"] += 1
|
| 472 |
+
if exact_match:
|
| 473 |
+
results["exact_match"] += 1
|
| 474 |
+
if rejection_correct:
|
| 475 |
+
results["rejection_correct"] += 1
|
| 476 |
+
results["total_arg_score"] += arg_score
|
| 477 |
+
|
| 478 |
+
# By category
|
| 479 |
+
if category not in results["by_category"]:
|
| 480 |
+
results["by_category"][category] = {
|
| 481 |
+
"total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
|
| 482 |
+
}
|
| 483 |
+
results["by_category"][category]["total"] += 1
|
| 484 |
+
if func_correct:
|
| 485 |
+
results["by_category"][category]["func_correct"] += 1
|
| 486 |
+
if exact_match:
|
| 487 |
+
results["by_category"][category]["exact_match"] += 1
|
| 488 |
+
results["by_category"][category]["arg_score"] += arg_score
|
| 489 |
+
|
| 490 |
+
# By function
|
| 491 |
+
func_key = expected_func or "None"
|
| 492 |
+
if func_key not in results["by_function"]:
|
| 493 |
+
results["by_function"][func_key] = {
|
| 494 |
+
"total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
|
| 495 |
+
}
|
| 496 |
+
results["by_function"][func_key]["total"] += 1
|
| 497 |
+
if func_correct:
|
| 498 |
+
results["by_function"][func_key]["func_correct"] += 1
|
| 499 |
+
if exact_match:
|
| 500 |
+
results["by_function"][func_key]["exact_match"] += 1
|
| 501 |
+
results["by_function"][func_key]["arg_score"] += arg_score
|
| 502 |
+
|
| 503 |
+
# Record errors
|
| 504 |
+
if error_msg and len(results["errors"]) < 10:
|
| 505 |
+
results["errors"].append({
|
| 506 |
+
"id": sample_id,
|
| 507 |
+
"category": category,
|
| 508 |
+
"input": result["user_input"],
|
| 509 |
+
"expected_func": expected_func,
|
| 510 |
+
"actual_func": actual_func,
|
| 511 |
+
"expected_args": result["expected_args"],
|
| 512 |
+
"actual_args": result["actual_args"],
|
| 513 |
+
"error": error_msg,
|
| 514 |
+
"response": result["response"][:200],
|
| 515 |
+
})
|
| 516 |
+
|
| 517 |
+
if verbose:
|
| 518 |
+
status = "✓" if exact_match else "✗"
|
| 519 |
+
# Extract user message preview for logs
|
| 520 |
+
user_input = result["user_input"]
|
| 521 |
+
if isinstance(user_input, dict):
|
| 522 |
+
user_msg = ""
|
| 523 |
+
if "messages" in user_input:
|
| 524 |
+
for msg in user_input["messages"]:
|
| 525 |
+
if msg.get("role") == "user":
|
| 526 |
+
user_msg = msg.get("content", "")
|
| 527 |
+
break
|
| 528 |
+
input_preview = user_msg[:50] if user_msg else str(user_input)[:50]
|
| 529 |
+
else:
|
| 530 |
+
input_preview = str(user_input)[:50]
|
| 531 |
+
logger.info(f"[{sample_id}] {status} {category}: {input_preview}...")
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def print_report(results: Dict):
|
| 535 |
+
"""Print evaluation report."""
|
| 536 |
+
total = results["total"]
|
| 537 |
+
|
| 538 |
+
print("\n" + "=" * 70)
|
| 539 |
+
print("FunctionGemma Evaluation Report")
|
| 540 |
+
print("=" * 70)
|
| 541 |
+
print(f"\nTotal samples: {total}")
|
| 542 |
+
|
| 543 |
+
print("\n" + "-" * 70)
|
| 544 |
+
print("Overall metrics")
|
| 545 |
+
print("-" * 70)
|
| 546 |
+
|
| 547 |
+
func_acc = results["function_correct"] / total * 100 if total > 0 else 0
|
| 548 |
+
arg_acc = results["arguments_correct"] / total * 100 if total > 0 else 0
|
| 549 |
+
exact_acc = results["exact_match"] / total * 100 if total > 0 else 0
|
| 550 |
+
avg_arg_score = results["total_arg_score"] / total * 100 if total > 0 else 0
|
| 551 |
+
|
| 552 |
+
# Rejection accuracy
|
| 553 |
+
rejection_samples = sum(1 for f in results["by_function"].values() if "None" in str(f))
|
| 554 |
+
rejection_total = results["by_function"].get("None", {}).get("total", 0)
|
| 555 |
+
rejection_acc = results["rejection_correct"] / rejection_total * 100 if rejection_total > 0 else 0
|
| 556 |
+
|
| 557 |
+
print(f"Function selection accuracy: {func_acc:.2f}%")
|
| 558 |
+
print(f"Argument accuracy: {arg_acc:.2f}%")
|
| 559 |
+
print(f"Exact match accuracy: {exact_acc:.2f}%")
|
| 560 |
+
print(f"Average argument score: {avg_arg_score:.2f}%")
|
| 561 |
+
print(f"Rejection accuracy: {rejection_acc:.2f}%")
|
| 562 |
+
|
| 563 |
+
print("\n" + "-" * 70)
|
| 564 |
+
print("By function")
|
| 565 |
+
print("-" * 70)
|
| 566 |
+
|
| 567 |
+
for func, stats in sorted(results["by_function"].items()):
|
| 568 |
+
func_total = stats["total"]
|
| 569 |
+
func_correct = stats["func_correct"] / func_total * 100 if func_total > 0 else 0
|
| 570 |
+
func_arg_score = stats["arg_score"] / func_total * 100 if func_total > 0 else 0
|
| 571 |
+
func_exact = stats["exact_match"] / func_total * 100 if func_total > 0 else 0
|
| 572 |
+
|
| 573 |
+
print(f"{func:15} | samples: {func_total:3} | func acc: {func_correct:6.2f}% | "
|
| 574 |
+
f"arg score: {func_arg_score:6.2f}% | exact: {func_exact:6.2f}%")
|
| 575 |
+
|
| 576 |
+
if results["errors"]:
|
| 577 |
+
print("\n" + "-" * 70)
|
| 578 |
+
print("Error samples")
|
| 579 |
+
print("-" * 70)
|
| 580 |
+
|
| 581 |
+
for err in results["errors"][:5]:
|
| 582 |
+
print(f"\nID: {err['id']} | category: {err['category']}")
|
| 583 |
+
print(f"Input: {err['input']}")
|
| 584 |
+
print(f"Expected: {err['expected_func']} | Actual: {err['actual_func']}")
|
| 585 |
+
print(f"Error: {err['error']}")
|
| 586 |
+
|
| 587 |
+
print("\n" + "=" * 70)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def main():
|
| 591 |
+
parser = argparse.ArgumentParser(description="FunctionGemma evaluation (v2)")
|
| 592 |
+
parser.add_argument("--model_path", type=str, required=True, help="Model path")
|
| 593 |
+
parser.add_argument("--lora_path", type=str, default=None, help="LoRA adapter path")
|
| 594 |
+
parser.add_argument("--benchmark_path", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Benchmark dataset path")
|
| 595 |
+
parser.add_argument("--output_path", type=str, default=None, help="Output path (defaults to results/ with timestamp)")
|
| 596 |
+
parser.add_argument("--chain", type=str, default="solana", help="Chain name")
|
| 597 |
+
parser.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization")
|
| 598 |
+
parser.add_argument("--load_in_4bit", action="store_true", help="Enable 4-bit quantization")
|
| 599 |
+
parser.add_argument("--verbose", action="store_true", help="Verbose logging")
|
| 600 |
+
parser.add_argument("--num_workers", type=int, default=4, help="Number of worker threads (default 4)")
|
| 601 |
+
args = parser.parse_args()
|
| 602 |
+
|
| 603 |
+
# Load model
|
| 604 |
+
model, tokenizer = load_model(
|
| 605 |
+
args.model_path,
|
| 606 |
+
lora_path=args.lora_path,
|
| 607 |
+
load_in_8bit=args.load_in_8bit,
|
| 608 |
+
load_in_4bit=args.load_in_4bit,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
# Load benchmark
|
| 612 |
+
benchmark_path = Path(args.benchmark_path)
|
| 613 |
+
logger.info(f"Loading benchmark: {benchmark_path}")
|
| 614 |
+
with open(benchmark_path, 'r', encoding='utf-8') as f:
|
| 615 |
+
benchmark = json.load(f)
|
| 616 |
+
|
| 617 |
+
logger.info(f"Benchmark samples: {len(benchmark)}")
|
| 618 |
+
|
| 619 |
+
# Evaluate
|
| 620 |
+
logger.info("Starting evaluation...")
|
| 621 |
+
results = evaluate_benchmark(
|
| 622 |
+
model, tokenizer, benchmark,
|
| 623 |
+
chain=args.chain,
|
| 624 |
+
verbose=args.verbose,
|
| 625 |
+
num_workers=args.num_workers,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
# Print report
|
| 629 |
+
print_report(results)
|
| 630 |
+
|
| 631 |
+
# Save results
|
| 632 |
+
output_path = Path(args.output_path) if args.output_path else DEFAULT_RESULTS_DIR / f"evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
| 633 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 634 |
+
|
| 635 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 636 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 637 |
+
logger.info(f"Evaluation saved to: {output_path}")
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
if __name__ == "__main__":
|
| 641 |
+
main()
|
src/generate_benchmark.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate the FunctionGemma evaluation benchmark.
|
| 4 |
+
|
| 5 |
+
Creates 100 high-quality samples to assess function-calling accuracy across:
|
| 6 |
+
- SEARCH_TOKEN calls
|
| 7 |
+
- EXECUTE_SWAP calls
|
| 8 |
+
- Incomplete requests (should ask back)
|
| 9 |
+
- Irrelevant requests (should refuse)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
import argparse
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Any, Optional
|
| 17 |
+
|
| 18 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 19 |
+
DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
|
| 20 |
+
|
| 21 |
+
# Token info
|
| 22 |
+
TOKENS = {
|
| 23 |
+
"SOL": {"ca": "So11111111111111111111111111111111111111112", "chain": "solana"},
|
| 24 |
+
"USDC": {"ca": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", "chain": "solana"},
|
| 25 |
+
"JUP": {"ca": "JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN", "chain": "solana"},
|
| 26 |
+
"RAY": {"ca": "4k3Dyjzvzp8eMZWUXbBCjEvwSkkk59S5iCNLY3QrkX6R", "chain": "solana"},
|
| 27 |
+
"BONK": {"ca": "DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263", "chain": "solana"},
|
| 28 |
+
"WIF": {"ca": "EKpQGSJtjMFqKZ9KQanSqYXRcF8fBopzLHYxdM65zcjm", "chain": "solana"},
|
| 29 |
+
"ETH": {"ca": "7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs", "chain": "solana"},
|
| 30 |
+
"BTC": {"ca": "9n4nbM75f5Ui33ZbPYXn59EwSgE8CGsHtAeTH5YFeJ9E", "chain": "solana"},
|
| 31 |
+
"POPCAT": {"ca": "7GCihgDB8fe6KNjn2MYtkzZcRjQy3t9GHdC8uHYmW2hr", "chain": "solana"},
|
| 32 |
+
"TRUMP": {"ca": "6p6xgHyF7AeE6TZkSmFsko444wqoP15icUSqi2jfGiPN", "chain": "solana"},
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
CHAINS = ["solana", "ethereum", "bsc", "base"]
|
| 36 |
+
|
| 37 |
+
# Tool definitions
|
| 38 |
+
TOOLS = [
|
| 39 |
+
{
|
| 40 |
+
"type": "function",
|
| 41 |
+
"function": {
|
| 42 |
+
"name": "SEARCH_TOKEN",
|
| 43 |
+
"description": "search token onchain",
|
| 44 |
+
"parameters": {
|
| 45 |
+
"type": "object",
|
| 46 |
+
"properties": {
|
| 47 |
+
"symbol": {"type": ["string", "null"], "description": "Symbol of the token"},
|
| 48 |
+
"address": {"type": ["string", "null"], "description": "Contract address of the token"},
|
| 49 |
+
"chain": {"type": "string", "enum": ["solana", "ethereum", "bsc", "base"], "description": "supported chains"},
|
| 50 |
+
"keyword": {"type": ["string", "null"], "description": "keyword to search for the token"}
|
| 51 |
+
},
|
| 52 |
+
"required": []
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"type": "function",
|
| 58 |
+
"function": {
|
| 59 |
+
"name": "EXECUTE_SWAP",
|
| 60 |
+
"description": "Swap tokens on the Solana blockchain. When the user specifies 'buy <token>', the default input token is SOL. When the user specifies 'sell <token>', the default output token is SOL.",
|
| 61 |
+
"parameters": {
|
| 62 |
+
"type": "object",
|
| 63 |
+
"properties": {
|
| 64 |
+
"inputTokenSymbol": {"type": ["string", "null"], "description": "Symbol of the token to sell."},
|
| 65 |
+
"inputTokenCA": {"type": ["string", "null"], "description": "Contract address of the token to sell."},
|
| 66 |
+
"outputTokenCA": {"type": ["string", "null"], "description": "Contract address of the token to buy."},
|
| 67 |
+
"inputTokenAmount": {"type": ["string", "null"], "description": "Exact amount of the input token to swap."},
|
| 68 |
+
"inputTokenPercentage": {"type": ["number", "null"], "description": "Percentage of the input token balance to swap."},
|
| 69 |
+
"outputTokenAmount": {"type": ["string", "null"], "description": "Expected amount of the output token to receive."}
|
| 70 |
+
},
|
| 71 |
+
"required": ["inputTokenCA", "outputTokenCA", "inputTokenAmount", "inputTokenPercentage"]
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def create_benchmark_item(
|
| 79 |
+
user_input: str,
|
| 80 |
+
expected_function: Optional[str],
|
| 81 |
+
expected_args: Optional[Dict] = None,
|
| 82 |
+
category: str = "function_call",
|
| 83 |
+
description: str = ""
|
| 84 |
+
) -> Dict:
|
| 85 |
+
"""Create one benchmark sample."""
|
| 86 |
+
return {
|
| 87 |
+
"id": None, # assigned later
|
| 88 |
+
"category": category,
|
| 89 |
+
"description": description,
|
| 90 |
+
"input": {
|
| 91 |
+
"messages": [
|
| 92 |
+
{"role": "developer", "content": "You are a model that can do function calling with the following functions"},
|
| 93 |
+
{"role": "user", "content": user_input}
|
| 94 |
+
],
|
| 95 |
+
"tools": TOOLS
|
| 96 |
+
},
|
| 97 |
+
"expected": {
|
| 98 |
+
"function_name": expected_function,
|
| 99 |
+
"arguments": expected_args
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def generate_search_token_benchmarks() -> List[Dict]:
|
| 105 |
+
"""Generate SEARCH_TOKEN cases."""
|
| 106 |
+
benchmarks = []
|
| 107 |
+
|
| 108 |
+
# 1) search by symbol (English)
|
| 109 |
+
test_cases = [
|
| 110 |
+
("Search for BONK token", "BONK", "solana", None, None),
|
| 111 |
+
("Find WIF on solana", "WIF", "solana", None, None),
|
| 112 |
+
("Look up JUP token", "JUP", "solana", None, None),
|
| 113 |
+
("Search ETH on ethereum", "ETH", "ethereum", None, None),
|
| 114 |
+
("Find USDC token on base", "USDC", "base", None, None),
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
for query, symbol, chain, address, keyword in test_cases:
|
| 118 |
+
expected_args = {"symbol": symbol, "chain": chain}
|
| 119 |
+
if address:
|
| 120 |
+
expected_args["address"] = address
|
| 121 |
+
if keyword:
|
| 122 |
+
expected_args["keyword"] = keyword
|
| 123 |
+
benchmarks.append(create_benchmark_item(
|
| 124 |
+
query, "SEARCH_TOKEN", expected_args,
|
| 125 |
+
"search_by_symbol", f"Search {symbol} by symbol"
|
| 126 |
+
))
|
| 127 |
+
|
| 128 |
+
# 2) search by symbol (Chinese)
|
| 129 |
+
cn_cases = [
|
| 130 |
+
("帮我搜索 BONK 代币", "BONK", "solana"),
|
| 131 |
+
("查一下 WIF 这个币", "WIF", "solana"),
|
| 132 |
+
("找一下 JUP 代币信息", "JUP", "solana"),
|
| 133 |
+
("搜索 RAY 代币", "RAY", "solana"),
|
| 134 |
+
("查询 POPCAT 代币", "POPCAT", "solana"),
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
for query, symbol, chain in cn_cases:
|
| 138 |
+
benchmarks.append(create_benchmark_item(
|
| 139 |
+
query, "SEARCH_TOKEN", {"symbol": symbol, "chain": chain},
|
| 140 |
+
"search_by_symbol_cn", f"Search {symbol} by symbol (Chinese)"
|
| 141 |
+
))
|
| 142 |
+
|
| 143 |
+
# 3) search by address
|
| 144 |
+
for token, info in list(TOKENS.items())[:5]:
|
| 145 |
+
query = f"Search token at address {info['ca']}"
|
| 146 |
+
benchmarks.append(create_benchmark_item(
|
| 147 |
+
query, "SEARCH_TOKEN", {"address": info['ca'], "chain": info['chain']},
|
| 148 |
+
"search_by_address", f"Search {token} by address"
|
| 149 |
+
))
|
| 150 |
+
|
| 151 |
+
# 4) search by keyword
|
| 152 |
+
keyword_cases = [
|
| 153 |
+
("Search for dog themed tokens", "dog", "solana"),
|
| 154 |
+
("Find meme coins", "meme", "solana"),
|
| 155 |
+
("Look for cat tokens on base", "cat", "base"),
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
for query, keyword, chain in keyword_cases:
|
| 159 |
+
benchmarks.append(create_benchmark_item(
|
| 160 |
+
query, "SEARCH_TOKEN", {"keyword": keyword, "chain": chain},
|
| 161 |
+
"search_by_keyword", f"Search by keyword: {keyword}"
|
| 162 |
+
))
|
| 163 |
+
|
| 164 |
+
return benchmarks
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def generate_execute_swap_benchmarks() -> List[Dict]:
|
| 168 |
+
"""Generate EXECUTE_SWAP cases."""
|
| 169 |
+
benchmarks = []
|
| 170 |
+
|
| 171 |
+
# 1) buy token (fixed amount)
|
| 172 |
+
buy_cases = [
|
| 173 |
+
("Buy 1 SOL worth of BONK", "SOL", "BONK", "1", None),
|
| 174 |
+
("Purchase 5 SOL of WIF", "SOL", "WIF", "5", None),
|
| 175 |
+
("Buy 10 USDC worth of JUP", "USDC", "JUP", "10", None),
|
| 176 |
+
("I want to buy 2 SOL of RAY", "SOL", "RAY", "2", None),
|
| 177 |
+
("Get me 0.5 SOL of POPCAT", "SOL", "POPCAT", "0.5", None),
|
| 178 |
+
]
|
| 179 |
+
|
| 180 |
+
for query, input_token, output_token, amount, percentage in buy_cases:
|
| 181 |
+
input_ca = TOKENS[input_token]["ca"]
|
| 182 |
+
output_ca = TOKENS[output_token]["ca"]
|
| 183 |
+
benchmarks.append(create_benchmark_item(
|
| 184 |
+
query, "EXECUTE_SWAP",
|
| 185 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
|
| 186 |
+
"buy_with_amount", f"Buy {output_token} with {amount} {input_token}"
|
| 187 |
+
))
|
| 188 |
+
|
| 189 |
+
# 2) buy token (percentage)
|
| 190 |
+
buy_pct_cases = [
|
| 191 |
+
("Buy BONK with 50% of my SOL", "SOL", "BONK", None, 0.5),
|
| 192 |
+
("Use 30% of my USDC to buy WIF", "USDC", "WIF", None, 0.3),
|
| 193 |
+
("Spend 100% of my SOL on JUP", "SOL", "JUP", None, 1.0),
|
| 194 |
+
("Put 25% of my ETH into RAY", "ETH", "RAY", None, 0.25),
|
| 195 |
+
("Use half of my BTC to get BONK", "BTC", "BONK", None, 0.5),
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
for query, input_token, output_token, amount, percentage in buy_pct_cases:
|
| 199 |
+
input_ca = TOKENS[input_token]["ca"]
|
| 200 |
+
output_ca = TOKENS[output_token]["ca"]
|
| 201 |
+
benchmarks.append(create_benchmark_item(
|
| 202 |
+
query, "EXECUTE_SWAP",
|
| 203 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
|
| 204 |
+
"buy_with_percentage", f"Buy {output_token} with {int(percentage*100)}% {input_token}"
|
| 205 |
+
))
|
| 206 |
+
|
| 207 |
+
# 3) sell token (fixed amount)
|
| 208 |
+
sell_cases = [
|
| 209 |
+
("Sell 1000 BONK", "BONK", "SOL", "1000", None),
|
| 210 |
+
("Sell 500 WIF for SOL", "WIF", "SOL", "500", None),
|
| 211 |
+
("Convert 100 JUP to SOL", "JUP", "SOL", "100", None),
|
| 212 |
+
("Dump 2000 RAY", "RAY", "SOL", "2000", None),
|
| 213 |
+
("Sell 50 USDC", "USDC", "SOL", "50", None),
|
| 214 |
+
]
|
| 215 |
+
|
| 216 |
+
for query, input_token, output_token, amount, percentage in sell_cases:
|
| 217 |
+
input_ca = TOKENS[input_token]["ca"]
|
| 218 |
+
output_ca = TOKENS[output_token]["ca"]
|
| 219 |
+
benchmarks.append(create_benchmark_item(
|
| 220 |
+
query, "EXECUTE_SWAP",
|
| 221 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
|
| 222 |
+
"sell_with_amount", f"Sell {amount} {input_token}"
|
| 223 |
+
))
|
| 224 |
+
|
| 225 |
+
# 4) sell token (percentage)
|
| 226 |
+
sell_pct_cases = [
|
| 227 |
+
("Sell 50% of my BONK", "BONK", "SOL", None, 0.5),
|
| 228 |
+
("Dump all my WIF", "WIF", "SOL", None, 1.0),
|
| 229 |
+
("Sell 30% of my JUP holdings", "JUP", "SOL", None, 0.3),
|
| 230 |
+
("Get rid of 75% of my RAY", "RAY", "SOL", None, 0.75),
|
| 231 |
+
("Sell a quarter of my USDC", "USDC", "SOL", None, 0.25),
|
| 232 |
+
]
|
| 233 |
+
|
| 234 |
+
for query, input_token, output_token, amount, percentage in sell_pct_cases:
|
| 235 |
+
input_ca = TOKENS[input_token]["ca"]
|
| 236 |
+
output_ca = TOKENS[output_token]["ca"]
|
| 237 |
+
benchmarks.append(create_benchmark_item(
|
| 238 |
+
query, "EXECUTE_SWAP",
|
| 239 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
|
| 240 |
+
"sell_with_percentage", f"Sell {int(percentage*100)}% {input_token}"
|
| 241 |
+
))
|
| 242 |
+
|
| 243 |
+
# 5) Chinese buy/sell requests (content kept)
|
| 244 |
+
cn_swap_cases = [
|
| 245 |
+
("用 1 个 SOL 买 BONK", "SOL", "BONK", "1", None),
|
| 246 |
+
("把 50% 的 USDC 换成 WIF", "USDC", "WIF", None, 0.5),
|
| 247 |
+
("卖掉 1000 个 BONK", "BONK", "SOL", "1000", None),
|
| 248 |
+
("把所有 JUP 都卖了", "JUP", "SOL", None, 1.0),
|
| 249 |
+
("用 2 SOL 购买 RAY", "SOL", "RAY", "2", None),
|
| 250 |
+
("出售 30% 的 WIF", "WIF", "SOL", None, 0.3),
|
| 251 |
+
("买入 5 SOL 的 POPCAT", "SOL", "POPCAT", "5", None),
|
| 252 |
+
("清仓 ETH", "ETH", "SOL", None, 1.0),
|
| 253 |
+
]
|
| 254 |
+
|
| 255 |
+
for query, input_token, output_token, amount, percentage in cn_swap_cases:
|
| 256 |
+
input_ca = TOKENS[input_token]["ca"]
|
| 257 |
+
output_ca = TOKENS[output_token]["ca"]
|
| 258 |
+
benchmarks.append(create_benchmark_item(
|
| 259 |
+
query, "EXECUTE_SWAP",
|
| 260 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
|
| 261 |
+
"swap_chinese", f"Swap request in Chinese"
|
| 262 |
+
))
|
| 263 |
+
|
| 264 |
+
# 6) swap between tokens
|
| 265 |
+
swap_cases = [
|
| 266 |
+
("Swap 100 USDC for BONK", "USDC", "BONK", "100", None),
|
| 267 |
+
("Exchange 50 JUP for WIF", "JUP", "WIF", "50", None),
|
| 268 |
+
("Convert all my ETH to USDC", "ETH", "USDC", None, 1.0),
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
for query, input_token, output_token, amount, percentage in swap_cases:
|
| 272 |
+
input_ca = TOKENS[input_token]["ca"]
|
| 273 |
+
output_ca = TOKENS[output_token]["ca"]
|
| 274 |
+
benchmarks.append(create_benchmark_item(
|
| 275 |
+
query, "EXECUTE_SWAP",
|
| 276 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
|
| 277 |
+
"token_to_token", f"Swap {input_token} to {output_token}"
|
| 278 |
+
))
|
| 279 |
+
|
| 280 |
+
return benchmarks
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def generate_incomplete_benchmarks() -> List[Dict]:
|
| 284 |
+
"""Generate incomplete requests (should ask clarification)."""
|
| 285 |
+
benchmarks = []
|
| 286 |
+
|
| 287 |
+
incomplete_cases = [
|
| 288 |
+
("I want to buy some tokens", "incomplete_no_token", "Missing token name"),
|
| 289 |
+
("Sell my holdings", "incomplete_no_token", "Missing which token to sell"),
|
| 290 |
+
("Search for a token", "incomplete_no_info", "Missing token info"),
|
| 291 |
+
("Buy something", "incomplete_vague", "Too vague"),
|
| 292 |
+
("我想买币", "incomplete_cn", "Missing token (Chinese)"),
|
| 293 |
+
("帮我卖掉", "incomplete_cn", "Missing token and amount (Chinese)"),
|
| 294 |
+
("Swap tokens", "incomplete_swap", "Missing swap details"),
|
| 295 |
+
("I want to trade", "incomplete_trade", "Missing trade details"),
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
+
for query, category, description in incomplete_cases:
|
| 299 |
+
benchmarks.append(create_benchmark_item(
|
| 300 |
+
query, None, None, category, description
|
| 301 |
+
))
|
| 302 |
+
|
| 303 |
+
return benchmarks
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def generate_irrelevant_benchmarks() -> List[Dict]:
|
| 307 |
+
"""Generate irrelevant requests (should not call any function)."""
|
| 308 |
+
benchmarks = []
|
| 309 |
+
|
| 310 |
+
irrelevant_cases = [
|
| 311 |
+
("What's the weather today?", "irrelevant_weather", "Weather query"),
|
| 312 |
+
("Tell me a joke", "irrelevant_joke", "Joke request"),
|
| 313 |
+
("What time is it?", "irrelevant_time", "Time query"),
|
| 314 |
+
("Who is the president?", "irrelevant_general", "General knowledge"),
|
| 315 |
+
("今天天气怎么样?", "irrelevant_cn", "Weather (Chinese)"),
|
| 316 |
+
("给我讲个笑话", "irrelevant_cn", "Joke (Chinese)"),
|
| 317 |
+
("Hello, how are you?", "irrelevant_greeting", "Greeting"),
|
| 318 |
+
("What is Bitcoin?", "irrelevant_info", "Info request (no action)"),
|
| 319 |
+
]
|
| 320 |
+
|
| 321 |
+
for query, category, description in irrelevant_cases:
|
| 322 |
+
benchmarks.append(create_benchmark_item(
|
| 323 |
+
query, None, None, category, description
|
| 324 |
+
))
|
| 325 |
+
|
| 326 |
+
return benchmarks
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def generate_benchmark_dataset(output_path: str = str(DEFAULT_BENCHMARK_PATH)):
|
| 330 |
+
"""Generate the full benchmark dataset."""
|
| 331 |
+
|
| 332 |
+
print("=" * 60)
|
| 333 |
+
print("Generating FunctionGemma benchmark dataset")
|
| 334 |
+
print("=" * 60)
|
| 335 |
+
|
| 336 |
+
# Collect all cases
|
| 337 |
+
all_benchmarks = []
|
| 338 |
+
|
| 339 |
+
# SEARCH_TOKEN cases
|
| 340 |
+
search_benchmarks = generate_search_token_benchmarks()
|
| 341 |
+
print(f"SEARCH_TOKEN cases: {len(search_benchmarks)}")
|
| 342 |
+
all_benchmarks.extend(search_benchmarks)
|
| 343 |
+
|
| 344 |
+
# EXECUTE_SWAP cases
|
| 345 |
+
swap_benchmarks = generate_execute_swap_benchmarks()
|
| 346 |
+
print(f"EXECUTE_SWAP cases: {len(swap_benchmarks)}")
|
| 347 |
+
all_benchmarks.extend(swap_benchmarks)
|
| 348 |
+
|
| 349 |
+
# Incomplete requests
|
| 350 |
+
incomplete_benchmarks = generate_incomplete_benchmarks()
|
| 351 |
+
print(f"Incomplete request cases: {len(incomplete_benchmarks)}")
|
| 352 |
+
all_benchmarks.extend(incomplete_benchmarks)
|
| 353 |
+
|
| 354 |
+
# Irrelevant requests
|
| 355 |
+
irrelevant_benchmarks = generate_irrelevant_benchmarks()
|
| 356 |
+
print(f"Irrelevant request cases: {len(irrelevant_benchmarks)}")
|
| 357 |
+
all_benchmarks.extend(irrelevant_benchmarks)
|
| 358 |
+
|
| 359 |
+
# Pad to 100 if needed
|
| 360 |
+
while len(all_benchmarks) < 100:
|
| 361 |
+
# Add a few variants
|
| 362 |
+
extra_cases = [
|
| 363 |
+
("Buy 3 SOL of TRUMP", "SOL", "TRUMP", "3", None, "EXECUTE_SWAP"),
|
| 364 |
+
("Search for TRUMP token", "TRUMP", "solana", None, None, "SEARCH_TOKEN"),
|
| 365 |
+
]
|
| 366 |
+
for case in extra_cases:
|
| 367 |
+
if len(all_benchmarks) >= 100:
|
| 368 |
+
break
|
| 369 |
+
if case[5] == "EXECUTE_SWAP":
|
| 370 |
+
input_ca = TOKENS[case[1]]["ca"]
|
| 371 |
+
output_ca = TOKENS[case[2]]["ca"]
|
| 372 |
+
all_benchmarks.append(create_benchmark_item(
|
| 373 |
+
case[0], "EXECUTE_SWAP",
|
| 374 |
+
{"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": case[3], "inputTokenPercentage": case[4]},
|
| 375 |
+
"extra", "Extra test case"
|
| 376 |
+
))
|
| 377 |
+
else:
|
| 378 |
+
all_benchmarks.append(create_benchmark_item(
|
| 379 |
+
case[0], "SEARCH_TOKEN",
|
| 380 |
+
{"symbol": case[1], "chain": case[2]},
|
| 381 |
+
"extra", "Extra test case"
|
| 382 |
+
))
|
| 383 |
+
|
| 384 |
+
# Limit to 100
|
| 385 |
+
all_benchmarks = all_benchmarks[:100]
|
| 386 |
+
|
| 387 |
+
# Assign ids
|
| 388 |
+
for i, item in enumerate(all_benchmarks):
|
| 389 |
+
item["id"] = i + 1
|
| 390 |
+
|
| 391 |
+
# Shuffle
|
| 392 |
+
random.seed(42)
|
| 393 |
+
random.shuffle(all_benchmarks)
|
| 394 |
+
|
| 395 |
+
# Re-assign ids
|
| 396 |
+
for i, item in enumerate(all_benchmarks):
|
| 397 |
+
item["id"] = i + 1
|
| 398 |
+
|
| 399 |
+
print(f"\nTotal: {len(all_benchmarks)} cases")
|
| 400 |
+
|
| 401 |
+
# Category stats
|
| 402 |
+
categories = {}
|
| 403 |
+
for item in all_benchmarks:
|
| 404 |
+
cat = item["category"]
|
| 405 |
+
categories[cat] = categories.get(cat, 0) + 1
|
| 406 |
+
|
| 407 |
+
print("\nCategory distribution:")
|
| 408 |
+
for cat, count in sorted(categories.items()):
|
| 409 |
+
print(f" - {cat}: {count}")
|
| 410 |
+
|
| 411 |
+
# Function stats
|
| 412 |
+
func_counts = {"SEARCH_TOKEN": 0, "EXECUTE_SWAP": 0, "None": 0}
|
| 413 |
+
for item in all_benchmarks:
|
| 414 |
+
func = item["expected"]["function_name"]
|
| 415 |
+
if func:
|
| 416 |
+
func_counts[func] = func_counts.get(func, 0) + 1
|
| 417 |
+
else:
|
| 418 |
+
func_counts["None"] += 1
|
| 419 |
+
|
| 420 |
+
print("\nFunction distribution:")
|
| 421 |
+
for func, count in func_counts.items():
|
| 422 |
+
print(f" - {func}: {count}")
|
| 423 |
+
|
| 424 |
+
# Save
|
| 425 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 426 |
+
json.dump(all_benchmarks, f, ensure_ascii=False, indent=2)
|
| 427 |
+
|
| 428 |
+
print(f"\nBenchmark saved to: {output_path}")
|
| 429 |
+
|
| 430 |
+
# Show examples
|
| 431 |
+
print("\n" + "=" * 60)
|
| 432 |
+
print("Examples:")
|
| 433 |
+
print("=" * 60)
|
| 434 |
+
|
| 435 |
+
for i, item in enumerate(all_benchmarks[:3]):
|
| 436 |
+
print(f"\n--- Example {i+1} ---")
|
| 437 |
+
print(f"ID: {item['id']}")
|
| 438 |
+
print(f"Category: {item['category']}")
|
| 439 |
+
print(f"Input: {item['input']['messages'][1]['content']}")
|
| 440 |
+
print(f"Expected function: {item['expected']['function_name']}")
|
| 441 |
+
if item['expected']['arguments']:
|
| 442 |
+
print(f"Expected args: {json.dumps(item['expected']['arguments'], ensure_ascii=False)}")
|
| 443 |
+
|
| 444 |
+
return all_benchmarks
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def main():
|
| 448 |
+
parser = argparse.ArgumentParser(description="Generate FunctionGemma benchmark dataset")
|
| 449 |
+
parser.add_argument("--output", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Output file path")
|
| 450 |
+
args = parser.parse_args()
|
| 451 |
+
|
| 452 |
+
output_path = Path(args.output)
|
| 453 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 454 |
+
|
| 455 |
+
generate_benchmark_dataset(str(output_path))
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
if __name__ == "__main__":
|
| 459 |
+
main()
|
src/prepare_dataset.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Data preprocessing script.
|
| 4 |
+
|
| 5 |
+
Convert the generated dataset into a format directly consumable by SFTTrainer.
|
| 6 |
+
FunctionGemma expects a specific chat template structure.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Dict, Any
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 19 |
+
DEFAULT_INPUT = PROJECT_ROOT / "data" / "training_data.json"
|
| 20 |
+
DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "prepared_dataset.json"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def convert_tool_calls_to_text(tool_calls: List[Dict]) -> str:
|
| 24 |
+
"""Convert tool_calls into plain text (FunctionGemma format)."""
|
| 25 |
+
if not tool_calls:
|
| 26 |
+
return ""
|
| 27 |
+
|
| 28 |
+
result_parts = []
|
| 29 |
+
for tc in tool_calls:
|
| 30 |
+
func = tc.get("function", {})
|
| 31 |
+
name = func.get("name", "")
|
| 32 |
+
args = func.get("arguments", {})
|
| 33 |
+
|
| 34 |
+
# FunctionGemma format: functionName(arguments)
|
| 35 |
+
args_str = json.dumps(args, ensure_ascii=False)
|
| 36 |
+
result_parts.append(f"{name}({args_str})")
|
| 37 |
+
|
| 38 |
+
return "\n".join(result_parts)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def convert_messages_for_sft(messages: List[Dict], tools: List[Dict] = None) -> List[Dict]:
|
| 42 |
+
"""
|
| 43 |
+
Convert message format for SFTTrainer.
|
| 44 |
+
|
| 45 |
+
Input:
|
| 46 |
+
[
|
| 47 |
+
{"role": "developer", "content": "..."},
|
| 48 |
+
{"role": "user", "content": "..."},
|
| 49 |
+
{"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
Output:
|
| 53 |
+
[
|
| 54 |
+
{"role": "system", "content": "..."}, # developer -> system
|
| 55 |
+
{"role": "user", "content": "..."},
|
| 56 |
+
{"role": "assistant", "content": "..."} # tool_calls flattened to text
|
| 57 |
+
]
|
| 58 |
+
"""
|
| 59 |
+
converted = []
|
| 60 |
+
|
| 61 |
+
# Build tools description
|
| 62 |
+
tools_description = ""
|
| 63 |
+
if tools:
|
| 64 |
+
tools_desc_parts = []
|
| 65 |
+
for tool in tools:
|
| 66 |
+
if tool.get("type") == "function":
|
| 67 |
+
func = tool.get("function", {})
|
| 68 |
+
name = func.get("name", "")
|
| 69 |
+
desc = func.get("description", "")
|
| 70 |
+
params = func.get("parameters", {})
|
| 71 |
+
tools_desc_parts.append(f"- {name}: {desc}")
|
| 72 |
+
if tools_desc_parts:
|
| 73 |
+
tools_description = "\n\nAvailable tools:\n" + "\n".join(tools_desc_parts)
|
| 74 |
+
|
| 75 |
+
for msg in messages:
|
| 76 |
+
role = msg.get("role", "")
|
| 77 |
+
|
| 78 |
+
if role == "developer":
|
| 79 |
+
# developer -> system
|
| 80 |
+
content = msg.get("content", "")
|
| 81 |
+
if tools_description:
|
| 82 |
+
content = content + tools_description
|
| 83 |
+
converted.append({
|
| 84 |
+
"role": "system",
|
| 85 |
+
"content": content
|
| 86 |
+
})
|
| 87 |
+
|
| 88 |
+
elif role == "user":
|
| 89 |
+
converted.append({
|
| 90 |
+
"role": "user",
|
| 91 |
+
"content": msg.get("content", "")
|
| 92 |
+
})
|
| 93 |
+
|
| 94 |
+
elif role == "assistant":
|
| 95 |
+
if "tool_calls" in msg:
|
| 96 |
+
# Convert tool_calls to text
|
| 97 |
+
tool_calls_text = convert_tool_calls_to_text(msg["tool_calls"])
|
| 98 |
+
converted.append({
|
| 99 |
+
"role": "assistant",
|
| 100 |
+
"content": tool_calls_text
|
| 101 |
+
})
|
| 102 |
+
else:
|
| 103 |
+
converted.append({
|
| 104 |
+
"role": "assistant",
|
| 105 |
+
"content": msg.get("content", "")
|
| 106 |
+
})
|
| 107 |
+
|
| 108 |
+
elif role == "tool":
|
| 109 |
+
# Tool response
|
| 110 |
+
converted.append({
|
| 111 |
+
"role": "tool",
|
| 112 |
+
"content": msg.get("content", "")
|
| 113 |
+
})
|
| 114 |
+
|
| 115 |
+
return converted
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages"):
|
| 119 |
+
"""
|
| 120 |
+
Prepare dataset.
|
| 121 |
+
|
| 122 |
+
format_type:
|
| 123 |
+
- "messages": output {"messages": [...]}
|
| 124 |
+
- "text": output {"text": "..."} (flattened text)
|
| 125 |
+
"""
|
| 126 |
+
print(f"Loading dataset: {input_path}")
|
| 127 |
+
|
| 128 |
+
with open(input_path, 'r', encoding='utf-8') as f:
|
| 129 |
+
data = json.load(f)
|
| 130 |
+
|
| 131 |
+
print(f"Raw samples: {len(data)}")
|
| 132 |
+
|
| 133 |
+
prepared_data = []
|
| 134 |
+
|
| 135 |
+
for i, item in enumerate(data):
|
| 136 |
+
messages = item.get("messages", [])
|
| 137 |
+
tools = item.get("tools", [])
|
| 138 |
+
|
| 139 |
+
# Convert messages
|
| 140 |
+
converted_messages = convert_messages_for_sft(messages, tools)
|
| 141 |
+
|
| 142 |
+
if format_type == "messages":
|
| 143 |
+
prepared_data.append({
|
| 144 |
+
"messages": converted_messages
|
| 145 |
+
})
|
| 146 |
+
elif format_type == "text":
|
| 147 |
+
# Convert to plain text
|
| 148 |
+
text_parts = []
|
| 149 |
+
for msg in converted_messages:
|
| 150 |
+
role = msg["role"]
|
| 151 |
+
content = msg["content"]
|
| 152 |
+
if role == "system":
|
| 153 |
+
text_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
|
| 154 |
+
elif role == "user":
|
| 155 |
+
text_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
|
| 156 |
+
elif role == "assistant":
|
| 157 |
+
text_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
|
| 158 |
+
|
| 159 |
+
prepared_data.append({
|
| 160 |
+
"text": "\n".join(text_parts)
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
print(f"Processed samples: {len(prepared_data)}")
|
| 164 |
+
|
| 165 |
+
# Save
|
| 166 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 167 |
+
json.dump(prepared_data, f, ensure_ascii=False, indent=2)
|
| 168 |
+
|
| 169 |
+
print(f"Saved to: {output_path}")
|
| 170 |
+
|
| 171 |
+
# Show example
|
| 172 |
+
print("\n" + "=" * 60)
|
| 173 |
+
print("Example:")
|
| 174 |
+
print("=" * 60)
|
| 175 |
+
|
| 176 |
+
if format_type == "messages":
|
| 177 |
+
example = prepared_data[0]
|
| 178 |
+
for msg in example["messages"]:
|
| 179 |
+
print(f"\n[{msg['role']}]")
|
| 180 |
+
print(msg["content"][:200] + "..." if len(msg["content"]) > 200 else msg["content"])
|
| 181 |
+
else:
|
| 182 |
+
print(prepared_data[0]["text"][:500] + "...")
|
| 183 |
+
|
| 184 |
+
return prepared_data
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def main():
|
| 188 |
+
parser = argparse.ArgumentParser(description="Dataset preparation")
|
| 189 |
+
parser.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path")
|
| 190 |
+
parser.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path")
|
| 191 |
+
parser.add_argument("--format", type=str, choices=["messages", "text"], default="messages", help="Output format")
|
| 192 |
+
|
| 193 |
+
args = parser.parse_args()
|
| 194 |
+
|
| 195 |
+
prepare_dataset(args.input, args.output, args.format)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
main()
|
src/train.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
FunctionGemma SFT fine-tuning script.
|
| 4 |
+
|
| 5 |
+
Runs TRL SFTTrainer for FunctionGemma with two modes:
|
| 6 |
+
1) LoRA (recommended): faster, lower memory, less overfit
|
| 7 |
+
2) Full-parameter: higher cost, maximal capacity
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
# LoRA (default)
|
| 11 |
+
python -m src.train \
|
| 12 |
+
--model_path /path/to/model \
|
| 13 |
+
--dataset_path ./data/training_data.json \
|
| 14 |
+
--bf16
|
| 15 |
+
|
| 16 |
+
# Full-parameter
|
| 17 |
+
python -m src.train \
|
| 18 |
+
--model_path /path/to/model \
|
| 19 |
+
--dataset_path ./data/training_data.json \
|
| 20 |
+
--no-use-lora \
|
| 21 |
+
--bf16
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import json
|
| 26 |
+
import argparse
|
| 27 |
+
import logging
|
| 28 |
+
from datetime import datetime
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Optional
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
from datasets import Dataset, load_dataset
|
| 34 |
+
from transformers import (
|
| 35 |
+
AutoModelForCausalLM,
|
| 36 |
+
AutoTokenizer,
|
| 37 |
+
TrainingArguments,
|
| 38 |
+
BitsAndBytesConfig,
|
| 39 |
+
)
|
| 40 |
+
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
|
| 41 |
+
from trl import SFTTrainer, SFTConfig
|
| 42 |
+
|
| 43 |
+
# Paths and logging
|
| 44 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 45 |
+
DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json"
|
| 46 |
+
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs"
|
| 47 |
+
|
| 48 |
+
logging.basicConfig(
|
| 49 |
+
level=logging.INFO,
|
| 50 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 51 |
+
)
|
| 52 |
+
logger = logging.getLogger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_args():
|
| 56 |
+
"""Parse CLI arguments."""
|
| 57 |
+
parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")
|
| 58 |
+
|
| 59 |
+
# Model
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--model_path",
|
| 62 |
+
type=str,
|
| 63 |
+
default="google/functiongemma-270m-it",
|
| 64 |
+
help="Model path or HF model id"
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--tokenizer_path",
|
| 68 |
+
type=str,
|
| 69 |
+
default=None,
|
| 70 |
+
help="Tokenizer path (defaults to model_path)"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Dataset
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--dataset_path",
|
| 76 |
+
type=str,
|
| 77 |
+
default=str(DEFAULT_DATA_PATH),
|
| 78 |
+
help="Training dataset path"
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--val_split",
|
| 82 |
+
type=float,
|
| 83 |
+
default=0.1,
|
| 84 |
+
help="Validation split ratio"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Output
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--output_dir",
|
| 90 |
+
type=str,
|
| 91 |
+
default=str(DEFAULT_OUTPUT_DIR),
|
| 92 |
+
help="Root output directory"
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--run_name",
|
| 96 |
+
type=str,
|
| 97 |
+
default=None,
|
| 98 |
+
help="Run name for logging and saving"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Fine-tuning mode
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--use_lora",
|
| 104 |
+
action="store_true",
|
| 105 |
+
default=True,
|
| 106 |
+
help="Enable LoRA (recommended). Add --no-use-lora for full-parameter finetune"
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune")
|
| 109 |
+
|
| 110 |
+
# LoRA (only when use_lora=True)
|
| 111 |
+
parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
|
| 112 |
+
parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
|
| 113 |
+
parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--target_modules",
|
| 116 |
+
type=str,
|
| 117 |
+
nargs="+",
|
| 118 |
+
default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 119 |
+
help="Target modules for LoRA"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Training (aligned with FunctionGemma guidance)
|
| 123 |
+
parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
|
| 124 |
+
parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
|
| 125 |
+
parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
|
| 126 |
+
parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size")
|
| 127 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps")
|
| 128 |
+
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
|
| 129 |
+
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
|
| 130 |
+
parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)")
|
| 131 |
+
parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
|
| 132 |
+
parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)")
|
| 133 |
+
|
| 134 |
+
# Precision & optimization
|
| 135 |
+
parser.add_argument("--bf16", action="store_true", help="Use BF16")
|
| 136 |
+
parser.add_argument("--fp16", action="store_true", help="Use FP16")
|
| 137 |
+
parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)")
|
| 138 |
+
parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant")
|
| 139 |
+
parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
|
| 140 |
+
parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")
|
| 141 |
+
|
| 142 |
+
# Logging & saving
|
| 143 |
+
parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
|
| 144 |
+
parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps")
|
| 145 |
+
parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps")
|
| 146 |
+
parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")
|
| 147 |
+
|
| 148 |
+
# Misc
|
| 149 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 150 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
|
| 151 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub")
|
| 152 |
+
parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")
|
| 153 |
+
|
| 154 |
+
return parser.parse_args()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1):
|
| 158 |
+
"""Load and normalize dataset structure for SFT."""
|
| 159 |
+
logger.info(f"Loading dataset: {dataset_path}")
|
| 160 |
+
|
| 161 |
+
# Load JSON dataset
|
| 162 |
+
with open(dataset_path, 'r', encoding='utf-8') as f:
|
| 163 |
+
data = json.load(f)
|
| 164 |
+
|
| 165 |
+
logger.info(f"Dataset size: {len(data)} samples")
|
| 166 |
+
|
| 167 |
+
# Normalize nested structures:
|
| 168 |
+
# if an item has input.messages/tools, lift them to top-level
|
| 169 |
+
processed_data = []
|
| 170 |
+
for idx, item in enumerate(data):
|
| 171 |
+
if 'input' in item and 'messages' in item['input']:
|
| 172 |
+
# Deep copy messages to avoid mutating original
|
| 173 |
+
messages = json.loads(json.dumps(item['input']['messages']))
|
| 174 |
+
|
| 175 |
+
# Fix tool_calls formatting if present
|
| 176 |
+
for msg in messages:
|
| 177 |
+
if 'tool_calls' in msg and msg['tool_calls']:
|
| 178 |
+
for tc in msg['tool_calls']:
|
| 179 |
+
if 'function' in tc and 'arguments' in tc['function']:
|
| 180 |
+
args = tc['function']['arguments']
|
| 181 |
+
# ensure arguments is a string
|
| 182 |
+
if not isinstance(args, str):
|
| 183 |
+
tc['function']['arguments'] = json.dumps(args)
|
| 184 |
+
|
| 185 |
+
# Convert expected field into assistant response if present
|
| 186 |
+
if 'expected' in item and item['expected']:
|
| 187 |
+
expected = item['expected']
|
| 188 |
+
# If last message is not assistant, append one
|
| 189 |
+
if messages[-1]['role'] != 'assistant':
|
| 190 |
+
# Decide between function call or refusal
|
| 191 |
+
function_name = expected.get('function_name')
|
| 192 |
+
arguments = expected.get('arguments')
|
| 193 |
+
response = expected.get('response', '')
|
| 194 |
+
|
| 195 |
+
if function_name is not None and arguments is not None:
|
| 196 |
+
# Case 1: function call -> add tool_calls
|
| 197 |
+
arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
|
| 198 |
+
|
| 199 |
+
assistant_msg = {
|
| 200 |
+
"role": "assistant",
|
| 201 |
+
"content": None,
|
| 202 |
+
"tool_calls": [{
|
| 203 |
+
"id": f"call_{hash(function_name + arguments_str) % 1000000}", # generate unique id
|
| 204 |
+
"type": "function",
|
| 205 |
+
"function": {
|
| 206 |
+
"name": function_name,
|
| 207 |
+
"arguments": arguments_str
|
| 208 |
+
}
|
| 209 |
+
}]
|
| 210 |
+
}
|
| 211 |
+
messages.append(assistant_msg)
|
| 212 |
+
logger.debug(f"Added assistant tool_calls: {function_name}")
|
| 213 |
+
elif function_name is None and arguments is None and response:
|
| 214 |
+
# Case 2: refusal -> plain text response
|
| 215 |
+
assistant_msg = {
|
| 216 |
+
"role": "assistant",
|
| 217 |
+
"content": response
|
| 218 |
+
}
|
| 219 |
+
messages.append(assistant_msg)
|
| 220 |
+
logger.debug(f"Added assistant refusal response: {response[:50]}")
|
| 221 |
+
else:
|
| 222 |
+
logger.warning(f"Unknown expected format: {expected}")
|
| 223 |
+
|
| 224 |
+
processed_item = {
|
| 225 |
+
'messages': messages
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
# include tools if present
|
| 229 |
+
if 'tools' in item['input']:
|
| 230 |
+
processed_item['tools'] = item['input']['tools']
|
| 231 |
+
|
| 232 |
+
# preserve id
|
| 233 |
+
if 'id' in item:
|
| 234 |
+
processed_item['id'] = item['id']
|
| 235 |
+
|
| 236 |
+
# Final check: tool_calls arguments must be strings
|
| 237 |
+
for msg in processed_item['messages']:
|
| 238 |
+
if 'tool_calls' in msg and msg['tool_calls']:
|
| 239 |
+
for tc in msg['tool_calls']:
|
| 240 |
+
if 'function' in tc and 'arguments' in tc['function']:
|
| 241 |
+
if not isinstance(tc['function']['arguments'], str):
|
| 242 |
+
logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}")
|
| 243 |
+
tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
|
| 244 |
+
|
| 245 |
+
processed_data.append(processed_item)
|
| 246 |
+
|
| 247 |
+
elif 'messages' in item:
|
| 248 |
+
# Already proper format, just normalize tool_calls
|
| 249 |
+
messages = json.loads(json.dumps(item['messages']))
|
| 250 |
+
for msg in messages:
|
| 251 |
+
if 'tool_calls' in msg and msg['tool_calls']:
|
| 252 |
+
for tc in msg['tool_calls']:
|
| 253 |
+
if 'function' in tc and 'arguments' in tc['function']:
|
| 254 |
+
if not isinstance(tc['function']['arguments'], str):
|
| 255 |
+
tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
|
| 256 |
+
item_copy = dict(item)
|
| 257 |
+
item_copy['messages'] = messages
|
| 258 |
+
processed_data.append(item_copy)
|
| 259 |
+
else:
|
| 260 |
+
logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")
|
| 261 |
+
|
| 262 |
+
logger.info(f"Processed dataset size: {len(processed_data)}")
|
| 263 |
+
|
| 264 |
+
# Validate format
|
| 265 |
+
tool_calls_count = 0
|
| 266 |
+
for item in processed_data:
|
| 267 |
+
for msg in item['messages']:
|
| 268 |
+
if 'tool_calls' in msg and msg['tool_calls']:
|
| 269 |
+
tool_calls_count += 1
|
| 270 |
+
for tc in msg['tool_calls']:
|
| 271 |
+
if 'function' in tc and 'arguments' in tc['function']:
|
| 272 |
+
if not isinstance(tc['function']['arguments'], str):
|
| 273 |
+
logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}")
|
| 274 |
+
logger.info(f"Messages containing tool_calls: {tool_calls_count}")
|
| 275 |
+
|
| 276 |
+
# Convert to Hugging Face Dataset
|
| 277 |
+
dataset = Dataset.from_list(processed_data)
|
| 278 |
+
|
| 279 |
+
# Split train/val
|
| 280 |
+
if val_split > 0:
|
| 281 |
+
dataset = dataset.train_test_split(test_size=val_split, seed=42)
|
| 282 |
+
train_dataset = dataset['train']
|
| 283 |
+
eval_dataset = dataset['test']
|
| 284 |
+
logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
|
| 285 |
+
else:
|
| 286 |
+
train_dataset = dataset
|
| 287 |
+
eval_dataset = None
|
| 288 |
+
logger.info(f"Train: {len(train_dataset)}, no eval split")
|
| 289 |
+
|
| 290 |
+
return train_dataset, eval_dataset
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def get_quantization_config(use_4bit: bool, use_8bit: bool):
|
| 294 |
+
"""Build quantization config if requested."""
|
| 295 |
+
if use_4bit:
|
| 296 |
+
logger.info("Using 4-bit quantization (QLoRA)")
|
| 297 |
+
return BitsAndBytesConfig(
|
| 298 |
+
load_in_4bit=True,
|
| 299 |
+
bnb_4bit_quant_type="nf4",
|
| 300 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 301 |
+
bnb_4bit_use_double_quant=True,
|
| 302 |
+
)
|
| 303 |
+
elif use_8bit:
|
| 304 |
+
logger.info("Using 8-bit quantization")
|
| 305 |
+
return BitsAndBytesConfig(
|
| 306 |
+
load_in_8bit=True,
|
| 307 |
+
)
|
| 308 |
+
return None
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def load_model_and_tokenizer(args):
|
| 312 |
+
"""Load model and tokenizer."""
|
| 313 |
+
logger.info(f"Loading model: {args.model_path}")
|
| 314 |
+
|
| 315 |
+
tokenizer_path = args.tokenizer_path or args.model_path
|
| 316 |
+
|
| 317 |
+
# Load tokenizer
|
| 318 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 319 |
+
tokenizer_path,
|
| 320 |
+
trust_remote_code=True,
|
| 321 |
+
padding_side="right",
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# Ensure pad token exists
|
| 325 |
+
if tokenizer.pad_token is None:
|
| 326 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 327 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 328 |
+
|
| 329 |
+
# Quantization config
|
| 330 |
+
quantization_config = get_quantization_config(args.use_4bit, args.use_8bit)
|
| 331 |
+
|
| 332 |
+
# Model kwargs
|
| 333 |
+
model_kwargs = {
|
| 334 |
+
"trust_remote_code": True,
|
| 335 |
+
"device_map": "auto",
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
if quantization_config:
|
| 339 |
+
model_kwargs["quantization_config"] = quantization_config
|
| 340 |
+
|
| 341 |
+
# Precision
|
| 342 |
+
if args.bf16 and not (args.use_4bit or args.use_8bit):
|
| 343 |
+
model_kwargs["torch_dtype"] = torch.bfloat16
|
| 344 |
+
elif args.fp16 and not (args.use_4bit or args.use_8bit):
|
| 345 |
+
model_kwargs["torch_dtype"] = torch.float16
|
| 346 |
+
|
| 347 |
+
# Flash Attention
|
| 348 |
+
if args.use_flash_attention:
|
| 349 |
+
model_kwargs["attn_implementation"] = "flash_attention_2"
|
| 350 |
+
logger.info("Using Flash Attention 2")
|
| 351 |
+
|
| 352 |
+
# Load model
|
| 353 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 354 |
+
args.model_path,
|
| 355 |
+
**model_kwargs
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Prepare for k-bit training when quantized
|
| 359 |
+
if args.use_4bit or args.use_8bit:
|
| 360 |
+
model = prepare_model_for_kbit_training(model)
|
| 361 |
+
|
| 362 |
+
# Gradient checkpointing
|
| 363 |
+
if args.gradient_checkpointing:
|
| 364 |
+
model.gradient_checkpointing_enable()
|
| 365 |
+
logger.info("Enabled gradient checkpointing")
|
| 366 |
+
|
| 367 |
+
logger.info(f"Model parameters: {model.num_parameters():,}")
|
| 368 |
+
|
| 369 |
+
return model, tokenizer
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def get_lora_config(args):
|
| 373 |
+
"""Build LoRA config."""
|
| 374 |
+
logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
|
| 375 |
+
logger.info(f"Target modules: {args.target_modules}")
|
| 376 |
+
|
| 377 |
+
return LoraConfig(
|
| 378 |
+
r=args.lora_r,
|
| 379 |
+
lora_alpha=args.lora_alpha,
|
| 380 |
+
lora_dropout=args.lora_dropout,
|
| 381 |
+
target_modules=args.target_modules,
|
| 382 |
+
bias="none",
|
| 383 |
+
task_type=TaskType.CAUSAL_LM,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def formatting_func(example):
|
| 388 |
+
"""
|
| 389 |
+
Format function: pass data through for SFTTrainer.
|
| 390 |
+
|
| 391 |
+
Dataset format:
|
| 392 |
+
{
|
| 393 |
+
"messages": [
|
| 394 |
+
{"role": "developer", "content": "..."},
|
| 395 |
+
{"role": "user", "content": "..."},
|
| 396 |
+
{"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
|
| 397 |
+
],
|
| 398 |
+
"tools": [...]
|
| 399 |
+
}
|
| 400 |
+
"""
|
| 401 |
+
# Return as-is; SFTTrainer applies chat template
|
| 402 |
+
return example
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def main():
|
| 406 |
+
args = parse_args()
|
| 407 |
+
|
| 408 |
+
# Set run name
|
| 409 |
+
if args.run_name is None:
|
| 410 |
+
args.run_name = f"functiongemma-lora-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 411 |
+
|
| 412 |
+
# Create output directory
|
| 413 |
+
output_dir = os.path.join(args.output_dir, args.run_name)
|
| 414 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 415 |
+
|
| 416 |
+
logger.info("=" * 60)
|
| 417 |
+
logger.info("FunctionGemma SFT LoRA training")
|
| 418 |
+
logger.info("=" * 60)
|
| 419 |
+
logger.info(f"Output dir: {output_dir}")
|
| 420 |
+
|
| 421 |
+
# Save config
|
| 422 |
+
config_path = os.path.join(output_dir, "training_config.json")
|
| 423 |
+
with open(config_path, 'w') as f:
|
| 424 |
+
json.dump(vars(args), f, indent=2)
|
| 425 |
+
logger.info(f"Config saved to: {config_path}")
|
| 426 |
+
|
| 427 |
+
# Load dataset
|
| 428 |
+
train_dataset, eval_dataset = load_and_prepare_dataset(
|
| 429 |
+
args.dataset_path,
|
| 430 |
+
args.val_split
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Load model + tokenizer
|
| 434 |
+
model, tokenizer = load_model_and_tokenizer(args)
|
| 435 |
+
|
| 436 |
+
# Build LoRA config if enabled
|
| 437 |
+
if args.use_lora:
|
| 438 |
+
logger.info("=" * 60)
|
| 439 |
+
logger.info("LoRA fine-tuning mode")
|
| 440 |
+
logger.info("=" * 60)
|
| 441 |
+
lora_config = get_lora_config(args)
|
| 442 |
+
else:
|
| 443 |
+
logger.info("=" * 60)
|
| 444 |
+
logger.info("Full-parameter fine-tuning mode")
|
| 445 |
+
logger.info("Warning: full fine-tuning needs more memory and time!")
|
| 446 |
+
logger.info("=" * 60)
|
| 447 |
+
lora_config = None
|
| 448 |
+
|
| 449 |
+
# SFTTrainer config
|
| 450 |
+
training_args = SFTConfig(
|
| 451 |
+
output_dir=output_dir,
|
| 452 |
+
run_name=args.run_name,
|
| 453 |
+
|
| 454 |
+
# Sequence length / packing
|
| 455 |
+
max_length=args.max_seq_length,
|
| 456 |
+
packing=False,
|
| 457 |
+
|
| 458 |
+
# Training
|
| 459 |
+
num_train_epochs=args.num_train_epochs,
|
| 460 |
+
max_steps=args.max_steps,
|
| 461 |
+
per_device_train_batch_size=args.per_device_train_batch_size,
|
| 462 |
+
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
| 463 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 464 |
+
|
| 465 |
+
# Optimizer
|
| 466 |
+
learning_rate=args.learning_rate,
|
| 467 |
+
weight_decay=args.weight_decay,
|
| 468 |
+
warmup_ratio=args.warmup_ratio,
|
| 469 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
| 470 |
+
optim="adamw_torch_fused",
|
| 471 |
+
|
| 472 |
+
# Precision
|
| 473 |
+
bf16=args.bf16,
|
| 474 |
+
fp16=args.fp16,
|
| 475 |
+
|
| 476 |
+
# Logging / saving
|
| 477 |
+
logging_steps=args.logging_steps,
|
| 478 |
+
save_steps=args.save_steps,
|
| 479 |
+
eval_steps=args.eval_steps if eval_dataset else None,
|
| 480 |
+
eval_strategy="steps" if eval_dataset else "no",
|
| 481 |
+
save_total_limit=args.save_total_limit,
|
| 482 |
+
load_best_model_at_end=True if eval_dataset else False,
|
| 483 |
+
|
| 484 |
+
# Misc
|
| 485 |
+
seed=args.seed,
|
| 486 |
+
report_to=["tensorboard"],
|
| 487 |
+
|
| 488 |
+
# Hub
|
| 489 |
+
push_to_hub=args.push_to_hub,
|
| 490 |
+
hub_model_id=args.hub_model_id,
|
| 491 |
+
|
| 492 |
+
# Gradient checkpointing
|
| 493 |
+
gradient_checkpointing=args.gradient_checkpointing,
|
| 494 |
+
gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
# Create SFTTrainer
|
| 498 |
+
# Dataset should include 'messages' and 'tools'; SFTTrainer applies chat template automatically
|
| 499 |
+
trainer = SFTTrainer(
|
| 500 |
+
model=model,
|
| 501 |
+
args=training_args,
|
| 502 |
+
train_dataset=train_dataset,
|
| 503 |
+
eval_dataset=eval_dataset,
|
| 504 |
+
processing_class=tokenizer, # newer TRL uses processing_class instead of tokenizer
|
| 505 |
+
peft_config=lora_config,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# Parameter stats
|
| 509 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 510 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 511 |
+
trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0
|
| 512 |
+
|
| 513 |
+
logger.info("=" * 60)
|
| 514 |
+
logger.info("Model parameter stats:")
|
| 515 |
+
logger.info(f" Total params: {total_params:,}")
|
| 516 |
+
logger.info(f" Trainable params: {trainable_params:,}")
|
| 517 |
+
logger.info(f" Trainable ratio: {trainable_percentage:.2f}%")
|
| 518 |
+
logger.info(f" Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
|
| 519 |
+
logger.info("=" * 60)
|
| 520 |
+
|
| 521 |
+
# Train
|
| 522 |
+
logger.info("Start training...")
|
| 523 |
+
|
| 524 |
+
if args.resume_from_checkpoint:
|
| 525 |
+
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
|
| 526 |
+
else:
|
| 527 |
+
trainer.train()
|
| 528 |
+
|
| 529 |
+
# Save final model
|
| 530 |
+
logger.info("Saving final model...")
|
| 531 |
+
final_model_path = os.path.join(output_dir, "final_model")
|
| 532 |
+
trainer.save_model(final_model_path)
|
| 533 |
+
tokenizer.save_pretrained(final_model_path)
|
| 534 |
+
|
| 535 |
+
logger.info("=" * 60)
|
| 536 |
+
logger.info("Training done.")
|
| 537 |
+
logger.info(f"Model saved at: {final_model_path}")
|
| 538 |
+
|
| 539 |
+
if args.use_lora:
|
| 540 |
+
# LoRA: also save adapter
|
| 541 |
+
lora_path = os.path.join(output_dir, "lora_adapter")
|
| 542 |
+
model.save_pretrained(lora_path)
|
| 543 |
+
tokenizer.save_pretrained(lora_path)
|
| 544 |
+
logger.info(f"LoRA adapter saved to: {lora_path}")
|
| 545 |
+
logger.info("")
|
| 546 |
+
logger.info("Usage:")
|
| 547 |
+
logger.info(f" 1. LoRA adapter: {lora_path}")
|
| 548 |
+
logger.info(f" 2. Merge adapters with your base model before inference")
|
| 549 |
+
else:
|
| 550 |
+
# Full fine-tune: final_model is ready to use
|
| 551 |
+
logger.info("")
|
| 552 |
+
logger.info("Usage:")
|
| 553 |
+
logger.info(f" Use model directly from: {final_model_path}")
|
| 554 |
+
|
| 555 |
+
logger.info("=" * 60)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
if __name__ == "__main__":
|
| 559 |
+
main()
|