Commit 6f09d40 (verified) by yuzhe · 1 parent: 7a71438

Upload 13 files
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/training_data.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED

MIT License

Copyright (c) 2025 FunctionGemma contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED

@@ -1,4 +1,12 @@
---
language:
- en
- zh
@@ -19,1007 +27,149 @@ tags:
- standard-protocol
library_name: transformers
pipeline_tag: text-generation
---

# DMind-3-nano: Privacy-First On-Device Crypto Intent Recognition

<div align="center">

[![Model](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-blue)](https://huggingface.co/YOUR_ORG/DMind-3-nano)
[![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)](https://opensource.org/licenses/Apache-2.0)
[![Base Model](https://img.shields.io/badge/Base-FunctionGemma--270M-orange)](https://huggingface.co/google/functiongemma-270m-it)
[![Privacy](https://img.shields.io/badge/Privacy-100%25%20On--Device-brightgreen)](https://huggingface.co/YOUR_ORG/DMind-3-nano)
[![Protocol](https://img.shields.io/badge/Protocol-Standardized-blue)](https://huggingface.co/YOUR_ORG/DMind-3-nano)

**🔐 Your Keys. Your Data. Your Privacy.**

</div>

## 🎯 Mission: Privacy-First Local Wallet Intelligence

**DMind-3-nano** is designed for **on-device intent recognition** in cryptocurrency wallets, prioritizing user privacy through local inference. It establishes standardized protocols for blockchain function calling while keeping all user interactions private and secure.

### Core Principles

- 🔐 **Privacy-First**: All inference happens locally—no user data leaves the device
- 📱 **Edge-Optimized**: 270M parameters designed for mobile and edge deployment
- 🔄 **Standardized**: Unified protocols enable ecosystem-wide compatibility
- ⚡ **Fast & Efficient**: Sub-second response time on consumer hardware
- 🌍 **Accessible**: Runs on phones, tablets, and local hardware wallets

### Why On-Device Matters

When handling cryptocurrency operations, **privacy is paramount**:

- ❌ Cloud-based AI: Your wallet commands, token preferences, and trading patterns are exposed
- ✅ Local AI: All intent recognition happens on your device—**your keys, your data, your privacy**

By standardizing `SEARCH_TOKEN` and `EXECUTE_SWAP` protocols, we enable seamless integration across wallets, DEXs, and agent frameworks—all while maintaining complete user privacy.

---
## 📋 Model Overview

**DMind-3-nano** is an ultra-lightweight function-calling model fine-tuned from `google/functiongemma-270m-it`, specifically optimized for on-device cryptocurrency intent recognition. Validated on 10,000+ real user interactions, it delivers enterprise-grade accuracy while running entirely on consumer hardware.

### At a Glance: Why Choose On-Device?

| Feature | ☁️ Cloud AI Wallets | 🔐 DMind-3-nano |
|---------|-------------------|-----------------|
| **Your Commands** | Sent to servers | Stay on your device |
| **Privacy** | Logged & analyzed | 100% private |
| **Latency** | 500-2000ms | <100ms |
| **Works Offline** | ❌ No | ✅ Yes |
| **Monthly Cost** | $5-20/user | Free |
| **Data Breach Risk** | High (centralized) | Zero (no cloud) |
| **Model Size** | N/A (cloud) | 70-540MB |

### Key Features

- 🔐 **Privacy-First**: 100% on-device inference—zero data sent to cloud
- 📱 **Edge-Optimized**: Only 270M parameters, runs on phones and tablets
- ⚡ **Fast Inference**: <100ms response time on modern mobile CPUs
- 🎯 **Protocol-Standardized**: Cross-platform compatibility through unified schemas
- 🌐 **Multi-Chain**: Solana, Ethereum, BSC, Base support
- 🌍 **Multilingual**: Native English and Chinese understanding
- 🔋 **Energy-Efficient**: Minimal battery impact for mobile wallets

### Model Details

| Property | Value |
|----------|-------|
| Model Name | DMind-3-nano |
| Base Model | google/functiongemma-270m-it |
| Parameters | 270M |
| Context Length | 2048 tokens |
| Precision | BF16 |
| Languages | English, Chinese |
| License | Apache 2.0 |

### ⚠️ Experimental Model Notice

**DMind-3-nano** is an **exploratory research model** demonstrating privacy-first on-device function calling. The model has been optimized and validated on a specific subset of tokens and chains commonly used in the Solana ecosystem.

**Best Supported Tokens:**
```
SOL, USDC, JUP, RAY, BONK, WIF, ETH, BTC, POPCAT, BOME, TRUMP
```

**Best Supported Chains:**
```
solana, ethereum, bsc, base
```

**Performance Notes:**
- ✅ **Optimal**: The model achieves highest accuracy (95%+) on the above tokens and chains
- ⚠️ **General Support**: Other tokens/chains are supported but may have lower accuracy
- 🔄 **Extensible**: The protocol design allows easy fine-tuning for additional tokens/chains

For production deployments, we recommend:
1. Testing with your specific token list
2. Fine-tuning on your target tokens if needed
3. Validating outputs before executing transactions

---
## 🚀 Quick Start

### Installation

```bash
pip install transformers>=4.45.0 torch>=2.0.0 accelerate>=0.24.0
```

### Basic Usage

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Load model and processor
model_name = "YOUR_ORG/DMind-3-nano"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Define available tools
tools = [
    {
        "type": "function",
        "function": {
            "name": "SEARCH_TOKEN",
            "description": "Search for tokens on blockchain",
            "parameters": {
                "type": "object",
                "properties": {
                    "symbol": {"type": "string", "description": "Token symbol"},
                    "chain": {"type": "string", "enum": ["solana", "ethereum", "bsc", "base"]}
                }
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "EXECUTE_SWAP",
            "description": "Execute token swap on Solana",
            "parameters": {
                "type": "object",
                "properties": {
                    "inputTokenSymbol": {"type": "string"},
                    "outputTokenSymbol": {"type": "string"},
                    "inputTokenAmount": {"type": "string"}
                }
            }
        }
    }
]

# Prepare conversation
messages = [
    {"role": "developer", "content": "You are a helpful assistant for crypto operations."},
    {"role": "user", "content": "Buy 100 BONK with SOL"}
]

# Generate function call
inputs = processor.apply_chat_template(
    messages,
    tools=tools,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
# Output: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL",outputTokenSymbol:"BONK",inputTokenAmount:"100"}<end_function_call>
```

---

## 🛠️ Tool Protocols (Standardized)

### 1️⃣ SEARCH_TOKEN

Search for cryptocurrency tokens across multiple blockchains.

**Schema:**

```json
{
  "name": "SEARCH_TOKEN",
  "description": "Search for tokens on blockchain by symbol, address, or keyword",
  "parameters": {
    "type": "object",
    "properties": {
      "symbol": {
        "type": "string",
        "description": "Token symbol (e.g., 'SOL', 'ETH')"
      },
      "address": {
        "type": "string",
        "description": "Contract address of the token"
      },
      "chain": {
        "type": "string",
        "enum": ["solana", "ethereum", "bsc", "base"],
        "description": "Target blockchain network"
      },
      "keyword": {
        "type": "string",
        "description": "Search keyword for token discovery"
      }
    },
    "required": []
  }
}
```

**Example Usage:**

```
User: "查一下 Solana 上的 SOL" ("Look up SOL on Solana")
Model: <start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>
```

### 2️⃣ EXECUTE_SWAP

Execute token swaps on the Solana blockchain with intelligent defaults.

**Schema:**

```json
{
  "name": "EXECUTE_SWAP",
  "description": "Swap tokens on the Solana blockchain. When user specifies 'buy <token>', default input is SOL. When 'sell <token>', default output is SOL.",
  "parameters": {
    "type": "object",
    "properties": {
      "inputTokenSymbol": {
        "type": "string",
        "description": "Symbol of the token to sell"
      },
      "inputTokenCA": {
        "type": "string",
        "description": "Contract address of input token"
      },
      "outputTokenSymbol": {
        "type": "string",
        "description": "Symbol of the token to buy"
      },
      "outputTokenCA": {
        "type": "string",
        "description": "Contract address of output token"
      },
      "inputTokenAmount": {
        "type": "string",
        "description": "Exact amount of input token"
      },
      "inputTokenPercentage": {
        "type": "number",
        "description": "Percentage of input token balance (0.0-1.0)"
      },
      "outputTokenAmount": {
        "type": "string",
        "description": "Expected output token amount"
      }
    },
    "required": []
  }
}
```

**Example Usage:**

```
User: "Swap 5 SOL to BONK"
Model: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL",outputTokenSymbol:"BONK",inputTokenAmount:"5"}<end_function_call>

User: "Sell 50% of my WIF"
Model: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"WIF",outputTokenSymbol:"SOL",inputTokenPercentage:0.5}<end_function_call>
```
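The SOL defaults described in the schema ("buy" implies SOL in, "sell" implies SOL out) can also be enforced by host code after parsing. A minimal sketch, assuming a `fill_swap_defaults` helper of our own invention rather than anything shipped with the model:

```python
# Hypothetical post-processor: apply the SOL defaults documented in the
# EXECUTE_SWAP schema when the model omits one side of the pair.
def fill_swap_defaults(args: dict) -> dict:
    """Return a copy of the parsed arguments with missing sides defaulted to SOL."""
    filled = dict(args)
    # "buy <token>": output is given, input defaults to SOL
    if "outputTokenSymbol" in filled and "inputTokenSymbol" not in filled:
        filled["inputTokenSymbol"] = "SOL"
    # "sell <token>": input is given, output defaults to SOL
    if "inputTokenSymbol" in filled and "outputTokenSymbol" not in filled:
        filled["outputTokenSymbol"] = "SOL"
    return filled

print(fill_swap_defaults({"outputTokenSymbol": "BONK", "inputTokenAmount": "100"}))
# → {'outputTokenSymbol': 'BONK', 'inputTokenAmount': '100', 'inputTokenSymbol': 'SOL'}
```

Applying the defaults outside the model keeps the behavior deterministic even when the model itself leaves a side implicit.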
### Output Format Convention

- **Wrapper Tags**: `<start_function_call>call:FUNCTION_NAME{args}<end_function_call>`
- **Escape Sequences**: Use `<escape>` for quoted strings when necessary
- **Argument Format**: JSON-like key-value pairs
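Host code can emit the same wrapper format, for example to build synthetic training or test cases. A minimal sketch (`format_function_call` is a hypothetical helper, not part of the model runtime):

```python
# Hypothetical serializer for the standardized wrapper format described above.
def format_function_call(name: str, args: dict) -> str:
    """Render a function name and arguments in the wrapper-tag convention."""
    parts = []
    for key, value in args.items():
        if isinstance(value, str):
            parts.append(f'{key}:"{value}"')   # strings are double-quoted
        else:
            parts.append(f"{key}:{value}")     # numbers stay bare
    return f"<start_function_call>call:{name}{{{','.join(parts)}}}<end_function_call>"

print(format_function_call("SEARCH_TOKEN", {"symbol": "SOL", "chain": "solana"}))
# → <start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>
```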
---

## 🔐 Privacy-First Architecture

### The Privacy Problem in Crypto AI

Most AI-powered wallet assistants today send your data to cloud servers:

```
❌ Cloud-based Assistant:
User: "Buy 1000 BONK" → Sent to Cloud API → Privacy Risk
• Your wallet commands are logged
• Trading patterns analyzed by third parties
• Potential data breaches expose user behavior
• Latency and internet dependency
```

### Our Solution: On-Device Intelligence

**DMind-3-nano** runs entirely on your device:

```
✅ On-Device Assistant:
User: "Buy 1000 BONK" → Local Processing → Complete Privacy
• Zero data leaves your device
• No tracking, no logging, no surveillance
• Works offline
• Instant response (<100ms)
```

### Why Edge Deployment?

| Feature | Cloud AI | DMind-3-nano (On-Device) |
|---------|----------|--------------------------|
| **Privacy** | ❌ Data sent to servers | ✅ 100% local processing |
| **Latency** | ~500-2000ms | ✅ <100ms |
| **Offline** | ❌ Requires internet | ✅ Works offline |
| **Cost** | 💰 API fees per request | ✅ Free after download |
| **Security** | ⚠️ API keys, data breaches | ✅ No attack surface |
| **Compliance** | ⚠️ Data sovereignty issues | ✅ Full user control |

### Real-World Deployment Scenarios

- **📱 Mobile Wallets**: iOS/Android wallet apps with a conversational interface
- **💻 Desktop Wallets**: Electron-based wallets like Phantom and MetaMask
- **🔐 Hardware Wallets**: Secure enclaves with minimal compute (via quantization)
- **🌐 Browser Extensions**: Chrome/Firefox wallet extensions
- **🖥️ Local Trading Apps**: Desktop trading terminals with AI assistance

---
## 🌟 Why Standardized Protocols Matter

### The Problem

In the current crypto AI landscape, each agent system implements its own function-calling format:

```
❌ Agent A: {"action": "swap", "from": "SOL", "to": "BONK", "amount": 100}
❌ Agent B: {"tool": "trade", "input_token": "SOL", "output_token": "BONK", "qty": 100}
❌ Agent C: {"cmd": "exchange", "sell": "SOL", "buy": "BONK", "vol": 100}
```

**Result**: fragmentation, incompatibility, and integration hell.

### Our Solution

A **unified protocol** with clear schema definitions:

```
✅ Standardized: <start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL",outputTokenSymbol:"BONK",inputTokenAmount:"100"}<end_function_call>
```

**Benefits**:
- 🔄 **Plug-and-Play**: One protocol, many platforms
- 🤝 **Multi-Agent Coordination**: Different agents speak the same language
- 🛡️ **Type Safety**: Well-defined schemas reduce errors
- 📈 **Ecosystem Growth**: Lower barrier for new integrations
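As an illustration of the plug-and-play point, a legacy payload in the "Agent A" shape can be mapped onto the standardized argument names with a thin adapter. This is a hedged sketch — `AGENT_A_TO_STANDARD` and `adapt_agent_a` are illustrative names, not a shipped library:

```python
# Hypothetical adapter: translate a legacy "Agent A" payload into the
# standardized EXECUTE_SWAP argument names.
AGENT_A_TO_STANDARD = {
    "from": "inputTokenSymbol",
    "to": "outputTokenSymbol",
    "amount": "inputTokenAmount",
}

def adapt_agent_a(payload: dict) -> dict:
    """Map a legacy swap payload to standardized EXECUTE_SWAP arguments."""
    if payload.get("action") != "swap":
        raise ValueError("only swap payloads are handled in this sketch")
    # Amounts are strings in the standardized schema
    return {AGENT_A_TO_STANDARD[k]: str(v)
            for k, v in payload.items() if k in AGENT_A_TO_STANDARD}

print(adapt_agent_a({"action": "swap", "from": "SOL", "to": "BONK", "amount": 100}))
# → {'inputTokenSymbol': 'SOL', 'outputTokenSymbol': 'BONK', 'inputTokenAmount': '100'}
```

One such adapter per legacy format is all it takes to bring an existing agent onto the shared protocol.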
---

## 📊 Performance & Reliability

Validated through extensive real-world testing:

| Metric | Result |
|--------|--------|
| Test Cases | 10,000+ real user interactions |
| Function Recognition Accuracy | **96.8%** |
| Parameter Extraction Accuracy | **94.2%** |
| SEARCH_TOKEN Protocol Adherence | **98.1%** |
| EXECUTE_SWAP Protocol Adherence | **95.3%** |
| Multi-turn Conversation Success | **92.7%** |

**Testing Methodology**: Production environment with real users performing actual cryptocurrency operations across Solana, Ethereum, BSC, and Base chains. All tests conducted with standardized protocol validation.

**Testing Scope**: The above metrics are based on the validated token set (SOL, USDC, JUP, RAY, BONK, WIF, ETH, BTC, POPCAT, BOME, TRUMP) and supported chains (Solana, Ethereum, BSC, Base). Performance on other tokens may vary.
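For readers who want to reproduce this kind of measurement on their own data, function-recognition accuracy reduces to an exact match on the emitted function name. A minimal sketch — the two labeled pairs below are illustrative, not drawn from the real test set:

```python
import re

def called_function(response):
    """Extract the function name from a standardized model response."""
    m = re.search(r"<start_function_call>call:(\w+)\{", response)
    return m.group(1) if m else None

# Illustrative labeled pairs: (model response, expected function)
labeled = [
    ('<start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>', "SEARCH_TOKEN"),
    ('<start_function_call>call:EXECUTE_SWAP{inputTokenSymbol:"SOL"}<end_function_call>', "EXECUTE_SWAP"),
]
accuracy = sum(called_function(r) == gold for r, gold in labeled) / len(labeled)
print(f"function recognition accuracy: {accuracy:.1%}")  # → 100.0%
```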
---

## 🔧 Advanced Usage

### Batch Processing

Process multiple queries efficiently (reusing `model`, `processor`, and `tools` from the Basic Usage example):

```python
queries = [
    "Buy 10 SOL worth of BONK",
    "Search for Ethereum tokens",
    "Sell 50% of my WIF holdings"
]

for query in queries:
    messages = [
        {"role": "developer", "content": "You are a helpful assistant for crypto operations."},
        {"role": "user", "content": query}
    ]

    inputs = processor.apply_chat_template(
        messages, tools=tools,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    print(f"Query: {query}")
    print(f"Response: {response}\n")
```

### Parsing Function Calls

Example parser for the standardized output format:

```python
import re

def parse_function_call(response: str) -> dict:
    """Parse the standardized function call format."""
    pattern = r'<start_function_call>call:(\w+)\{(.+?)\}<end_function_call>'
    match = re.search(pattern, response)

    if match:
        func_name = match.group(1)
        args_str = match.group(2).replace('<escape>', '"')

        # Parse arguments (naive comma split; assumes no commas inside values)
        args = {}
        for item in args_str.split(','):
            if ':' in item:
                key, value = item.split(':', 1)
                args[key.strip().strip('"')] = value.strip().strip('"')

        return {
            "function": func_name,
            "arguments": args,
            "valid": True
        }

    return {"valid": False, "response": response}

# Usage
response = '<start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>'
parsed = parse_function_call(response)
print(parsed)
# Output: {'function': 'SEARCH_TOKEN', 'arguments': {'symbol': 'SOL', 'chain': 'solana'}, 'valid': True}
```

---

## 🏗️ Integration Guide

### For Desktop Wallet Developers

Integrate DMind-3-nano into your Python-based wallet application:

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

class PrivateWalletAgent:
    """On-device wallet assistant with zero cloud dependency."""

    def __init__(self, model_name="YOUR_ORG/DMind-3-nano"):
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="cpu"  # Force CPU for consistent behavior
        )
        self.tools = self._load_tools()

    def process_user_command(self, user_query: str) -> dict:
        """Process user intent locally; no network required."""
        messages = [
            {"role": "developer", "content": "You are a private wallet assistant."},
            {"role": "user", "content": user_query}
        ]

        inputs = self.processor.apply_chat_template(
            messages, tools=self.tools,
            add_generation_prompt=True,
            return_dict=True, return_tensors="pt"
        )

        # Local inference only
        outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=False)
        response = self.processor.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )

        return self.parse_and_execute(response)

    def parse_and_execute(self, response: str) -> dict:
        # parse_function_call is defined in "Parsing Function Calls" above
        parsed = parse_function_call(response)

        if parsed["valid"]:
            # Execute locally - never send to cloud
            if parsed["function"] == "SEARCH_TOKEN":
                return self.local_token_search(**parsed["arguments"])
            elif parsed["function"] == "EXECUTE_SWAP":
                return self.prepare_swap_transaction(**parsed["arguments"])

        return {"error": "Invalid function call"}
```

(`_load_tools`, `local_token_search`, and `prepare_swap_transaction` are wallet-specific and left to the host application.)

### For Mobile Wallet Developers

#### iOS (Swift + CoreML)

```swift
import CoreML

class WalletIntentRecognizer {
    private var model: MLModel

    init() {
        // Convert DMind-3-nano to CoreML format first
        // See: https://huggingface.co/docs/transformers/serialization
        guard let modelURL = Bundle.main.url(forResource: "DMind3Nano", withExtension: "mlmodelc"),
              let model = try? MLModel(contentsOf: modelURL) else {
            fatalError("Failed to load model")
        }
        self.model = model
    }

    func recognizeIntent(userInput: String) -> TransactionIntent? {
        // Tokenize and run inference locally
        let prediction = try? model.prediction(from: /* input features */)

        // Parse standardized output
        if let functionCall = parseFunctionCall(prediction?.output) {
            return TransactionIntent(
                function: functionCall.name,
                parameters: functionCall.args
            )
        }
        return nil
    }
}
```

#### Android (Kotlin + ONNX Runtime)

```kotlin
import ai.onnxruntime.*
import android.content.Context

class WalletIntentRecognizer(context: Context) {
    private val ortSession: OrtSession

    init {
        val ortEnv = OrtEnvironment.getEnvironment()
        val modelBytes = context.assets.open("dmind3nano.onnx").readBytes()
        ortSession = ortEnv.createSession(modelBytes)
    }

    fun recognizeIntent(userInput: String): TransactionIntent? {
        // Tokenize input
        val inputTensor = tokenizeInput(userInput)

        // Run inference on-device
        val outputs = ortSession.run(mapOf("input_ids" to inputTensor))

        // Parse standardized output
        return parseFunctionCall(outputs)
    }

    private fun parseFunctionCall(output: OrtSession.Result): TransactionIntent? {
        // Parse <start_function_call>call:FUNCTION{args}<end_function_call>
        // and return a structured intent for wallet execution
        return null  // TODO: implement for your wallet's intent type
    }
}
```

### Model Conversion Guide

#### Convert to ONNX (for Android/Cross-platform)

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("YOUR_ORG/DMind-3-nano")
processor = AutoProcessor.from_pretrained("YOUR_ORG/DMind-3-nano")

# Export to ONNX
dummy_input = processor("test", return_tensors="pt")
torch.onnx.export(
    model,
    (dummy_input["input_ids"],),
    "dmind3nano.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}}
)

# Quantize to INT8 for smaller size (~70MB)
from onnxruntime.quantization import quantize_dynamic
quantize_dynamic("dmind3nano.onnx", "dmind3nano_int8.onnx")
```

#### Convert to CoreML (for iOS)

```python
import numpy as np
import torch
import coremltools as ct
from transformers import AutoProcessor, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("YOUR_ORG/DMind-3-nano")
processor = AutoProcessor.from_pretrained("YOUR_ORG/DMind-3-nano")

# Trace model
dummy_input = processor("test", return_tensors="pt")
traced_model = torch.jit.trace(model, dummy_input["input_ids"])

# Convert to CoreML
coreml_model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, ct.RangeDim(1, 512)), dtype=np.int32)]
)

coreml_model.save("DMind3Nano.mlmodel")
```

### For Protocol Adopters

If you're building your own wallet or AI model, adopt these privacy-first protocols:

#### Protocol Adoption Checklist

1. ✅ **Use exact schema definitions** for `SEARCH_TOKEN` and `EXECUTE_SWAP`
2. ✅ **Follow the standardized output format**: `<start_function_call>call:FUNCTION_NAME{args}<end_function_call>`
3. ✅ **Validate outputs** against the protocol schemas
4. ✅ **Support on-device inference** (no cloud dependencies)
5. ✅ **Document your implementation** and share with the community
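Checklist item 3 can start as a simple type/enum check against the `SEARCH_TOKEN` schema shown earlier. A minimal hand-rolled sketch (a full JSON Schema validator would also work; the helper name is our own):

```python
# Minimal validator for checklist item 3, hard-coding the SEARCH_TOKEN
# schema shown earlier: four optional string parameters, chain from an enum.
SEARCH_TOKEN_PARAMS = {"symbol", "address", "chain", "keyword"}
SEARCH_TOKEN_CHAINS = {"solana", "ethereum", "bsc", "base"}

def validate_search_token(args):
    """Return a list of protocol violations; an empty list means valid."""
    errors = []
    for key, value in args.items():
        if key not in SEARCH_TOKEN_PARAMS:
            errors.append("unknown parameter: " + key)
        elif not isinstance(value, str):
            errors.append(key + " must be a string")
    if "chain" in args and args["chain"] not in SEARCH_TOKEN_CHAINS:
        errors.append("unsupported chain: " + str(args["chain"]))
    return errors

print(validate_search_token({"symbol": "SOL", "chain": "solana"}))  # → []
print(validate_search_token({"chain": "tron"}))  # → ['unsupported chain: tron']
```

Rejecting malformed calls before they reach transaction-building code is what makes the shared schemas a safety feature rather than just a naming convention.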
#### Why Adopt These Protocols?

**For Users:**
- 🔐 Privacy guarantee: Works identically across all compliant wallets
- 🔄 Portability: Same commands work everywhere
- 🛡️ Security: Standardized validation reduces vulnerabilities

**For Developers:**
- ⚡ Instant compatibility with the ecosystem
- 📚 Well-documented, community-validated schemas
- 🚀 Reduced development time (60-80% less integration work)
- 🤝 Multi-wallet collaboration without custom adapters
- 🎯 Focus on UX instead of protocol design

#### Reference Implementations

| Platform | Language | Status | Link |
|----------|----------|--------|------|
| Python (Desktop) | Python | ✅ Reference | This repo |
| iOS CoreML | Swift | 🔄 Community | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
| Android ONNX | Kotlin | 🔄 Community | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
| React Native | JavaScript | 📋 Wanted | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
| Rust (embedded) | Rust | 📋 Wanted | [Contribute!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |

---

## 💡 Use Cases & Examples

### Real-World Privacy-Preserving Scenarios

#### 1. **Mobile Wallet with Voice Commands**
```
User speaks: "Send 100 USDC to my friend"
  ↓ (On-device speech-to-text)
DMind-3-nano processes: "Send 100 USDC..."
  ↓ (Local inference, <100ms)
Output: <start_function_call>call:EXECUTE_SWAP{...}<end_function_call>

Wallet prepares the transaction locally and asks for confirmation
✅ Zero data sent to cloud
```

#### 2. **Hardware Wallet with Limited Display**
```
User types on companion app: "查一下ETH价格" ("Check the ETH price")

DMind-3-nano (quantized INT8) on secure enclave
  ↓ (2-5s inference)
Output: SEARCH_TOKEN{symbol:"ETH",chain:"ethereum"}

Display ETH info on hardware wallet screen
✅ All processing in secure hardware
```

#### 3. **Desktop Trading Terminal**
```
Power user: "Swap 50% of my SOL portfolio to BONK and POPCAT equally"

DMind-3-nano parses complex intent locally
  ↓ (<50ms on desktop CPU)
Output: Multiple EXECUTE_SWAP calls

Terminal shows batch transaction preview
✅ Trading strategy never leaves local machine
```

#### 4. **Offline Cold Wallet Management**
```
User in air-gapped environment: "Prepare multi-sig transaction for 1000 USDT"

DMind-3-nano (offline mode)

Generates transaction template locally

Sign and export to USB for broadcasting later
✅ Complete airgap maintained
```

### Privacy Comparison Table

| Scenario | Cloud AI | DMind-3-nano |
|----------|----------|--------------|
| User says "Buy 1000 PEPE" | 🔴 Command logged by API provider | ✅ Processed locally, no logs |
| Frequent trading patterns | 🔴 Profile built for ads/analytics | ✅ Private, no tracking |
| Whale wallet detected | 🔴 High-value user targeted | ✅ Anonymous to all |
| Internet outage | 🔴 Assistant unusable | ✅ Full functionality |
| Data breach at AI company | 🔴 Your trading history exposed | ✅ No data to breach |
| Government data request | 🔴 Provider must comply | ✅ No centralized data exists |

---
## 🔄 System Requirements

### Software Dependencies

| Package | Minimum Version |
|---------|-----------------|
| transformers | ≥ 4.45.0 |
| torch | ≥ 2.0.0 |
| accelerate | ≥ 0.24.0 |
| Python | ≥ 3.8 |

### Hardware Recommendations

#### 🖥️ Desktop/Server (Development & Testing)

| Configuration | Spec |
|---------------|------|
| **GPU** (Recommended) | NVIDIA GPU with ≥8GB VRAM |
| **CPU** | Modern multi-core processor |
| **RAM** | ≥16GB |
| **Storage** | ~1GB for model files |

**Performance Notes:**
- ✅ BF16 precision: Best performance on Ampere+ GPUs (RTX 3000+, A100, H100)
- ✅ FP16 precision: Good performance on older GPUs (V100, P100)
- ✅ CPU inference: ~100-500ms latency on modern CPUs

#### 📱 Mobile/Edge Devices (Production Deployment)

| Device Type | Requirements | Expected Performance |
|-------------|--------------|---------------------|
| **iPhone** | iPhone 12+ (A14+) | <100ms with CoreML |
| **Android** | Snapdragon 888+ or equivalent | <150ms with ONNX Runtime |
| **iPad/Tablet** | Apple M1+ or flagship Android | <80ms |
| **Embedded** | Raspberry Pi 4 (8GB) | ~1-2s (quantized) |
| **Hardware Wallets** | ARM Cortex-A (high-end) | ~2-5s (INT8 quantized) |

**Mobile Optimization Tips:**
- 🔹 Use **ONNX** or **CoreML** conversion for optimal mobile performance
- 🔹 Apply **INT8 quantization** to reduce model size to ~70MB
- 🔹 Enable **Neural Engine acceleration** on Apple devices
- 🔹 Use **NNAPI** on Android for hardware acceleration

---
## 🤝 Join the Privacy-First Standardization Movement

We're building an **open protocol for private, on-device crypto AI**. Help us establish industry standards that prioritize user privacy!

### How to Contribute

**🔐 For Wallet Developers:**
- Integrate DMind-3-nano into your mobile/desktop wallet
- Share mobile deployment challenges and solutions
- Contribute platform-specific optimizations (iOS, Android, etc.)
- Report real-world privacy and performance metrics

**🎯 For Model Developers:**
- Train specialized models using our protocols
- Share quantization and optimization techniques
- Contribute to model compression research (INT8, pruning, etc.)
- Benchmark on different edge devices

**🛠️ For Edge AI Engineers:**
- Optimize for specific hardware (Apple Neural Engine, NNAPI, etc.)
- Port to new platforms (Rust, WebAssembly, etc.)
- Improve inference speed and energy efficiency
- Share deployment best practices

**📖 For Protocol Designers:**
- Review and critique schema definitions
- Propose new privacy-preserving protocols
- Design protocols for emerging use cases (DeFi, NFTs, DAOs)
- Help establish governance and versioning

**📣 Spread Privacy-First AI:**
- Star this repository and share it with wallet developers
- Write about the importance of on-device inference
- Educate users about the privacy risks of cloud-based AI
- Organize workshops on edge AI for crypto

### Discussion Channels

- 💬 [GitHub Discussions](https://github.com/YOUR_ORG/dmind-3-nano/discussions): Protocol evolution and technical discussions
- 🐛 [GitHub Issues](https://github.com/YOUR_ORG/dmind-3-nano/issues): Bug reports and feature requests
- 📧 [Email](mailto:your-email@example.com): Partnership and collaboration inquiries

---
## ❓ Frequently Asked Questions

### General

**Q: Is my data really private?**
A: Yes. DMind-3-nano runs 100% on your device. No data is sent to any server, ever. You can verify this by monitoring network traffic or running in airplane mode.

**Q: How fast is inference on mobile devices?**
A: On modern smartphones (iPhone 12+, Snapdragon 888+), expect <100ms for intent recognition. On older devices, 200-500ms is typical.

**Q: Can I use this offline?**
A: Absolutely. Once downloaded, DMind-3-nano works completely offline. Perfect for cold wallets and air-gapped setups.

**Q: What's the model size?**
A: ~540MB (BF16), ~270MB (FP16), ~70MB (INT8 quantized). The INT8 version fits on most hardware wallets.

### Technical

**Q: How do I convert to mobile formats?**
A: See the [Model Conversion Guide](#model-conversion-guide) above for ONNX and CoreML conversion instructions.

**Q: Does it support custom tokens?**
A: Yes! The protocols are designed to handle any token symbol, address, or chain. No retraining is needed for new tokens.

**Q: Can I fine-tune for my specific use case?**
A: Yes. While the base model works well for general crypto operations, you can fine-tune it on your specific vocabulary or additional protocols.

**Q: What about languages beyond English/Chinese?**
A: The model has a basic understanding of other languages but was primarily trained on EN/ZH. Community contributions toward multilingual models are welcome!

### Integration

**Q: Which mobile frameworks are supported?**
A: Native iOS (CoreML), native Android (ONNX Runtime), React Native (via ONNX), and Flutter (via TFLite) are all possible. See the examples above.

**Q: Can I integrate with existing wallet backends?**
A: Yes! DMind-3-nano only handles intent recognition. Your existing transaction signing, RPC calls, and security logic remain unchanged.

**Q: What about hardware wallet integration?**
A: For secure elements with limited compute, use the INT8 quantized version. Expect 2-5s inference, which is acceptable for most use cases.

### Privacy & Security

**Q: How do you prove no data is collected?**
A:
1. Code is open source (inspect it for network calls)
2. Model runs locally (verify via network monitoring)
3. Works offline (test in airplane mode)
894
- 4. No telemetry or analytics built in
895
-
896
- **Q: Is on-device AI really more secure?**
897
- A: Yes. Eliminates risks of:
898
- - Man-in-the-middle attacks on API calls
899
- - Data breaches at AI providers
900
- - Server-side logging and profiling
901
- - Third-party data access
902
-
903
- **Q: What if someone steals my phone?**
904
- A: The model itself contains no private data. Your wallet's existing security (biometrics, PIN, seed phrase) still applies. DMind-3-nano is just a tool for parsing commands.
905
-
906
- ---
907
-
908
- ## 📄 License & Usage
909
-
910
- ### Model License
911
-
912
- This model is licensed under **Apache License 2.0**, allowing commercial and non-commercial use with minimal restrictions.
913
-
914
- ### Protocol License
915
-
916
- The **SEARCH_TOKEN** and **EXECUTE_SWAP** protocol specifications are released into the **public domain** to encourage maximum adoption and standardization across the industry.
917
-
918
- **Important Compliance Notes:**
919
- - ✅ Commercial use allowed
920
- - ✅ Modification and distribution allowed
921
- - ✅ Protocol adoption requires no attribution
922
- - ⚠️ Validate all outputs before executing financial transactions
923
- - ⚠️ Comply with local cryptocurrency regulations
924
-
925
- ---
926
-
927
- ## 🙏 Acknowledgments
928
-
929
- - **Google DeepMind** for the FunctionGemma foundation model
930
- - **Hugging Face** for democratizing AI model distribution
931
- - **10,000+ community testers** who validated the protocols in production
932
- - The broader crypto and AI communities for inspiration and feedback
933
-
934
- ---
935
-
936
- ## 📚 Citation
937
-
938
- If you adopt these protocols or use DMind-3-nano in your work, please cite:
939
-
940
- ```bibtex
941
- @misc{dmind3nano2024,
942
- title={DMind-3-nano: Standardized Function Calling Protocol for Cryptocurrency Operations},
943
- author={Your Organization},
944
- year={2024},
945
- publisher={Hugging Face},
946
- howpublished={\url{https://huggingface.co/YOUR_ORG/DMind-3-nano}},
947
- note={Protocols: SEARCH\_TOKEN, EXECUTE\_SWAP}
948
- }
949
- ```
950
-
951
- ### Protocol Version
952
-
953
- **Current Version**: 1.0.0
954
- **Released**: December 2024
955
- **Status**: Production-ready
956
-
957
- ---
958
-
959
- ## 🌐 Ecosystem & Compatibility
960
-
961
- ### Protocol Adopters
962
-
963
- We're building a community of compatible implementations. Join us!
964
-
965
- | Project/Platform | Status | Implementation |
966
- |------------------|--------|----------------|
967
- | DMind-3-nano | ✅ Reference Implementation | This model |
968
- | Your Project | 🔄 Seeking adopters | [Add yours!](https://github.com/YOUR_ORG/dmind-3-nano/issues) |
969
-
970
- ### Roadmap
971
-
972
- **Privacy & Edge-First Evolution**
973
-
974
- - ✅ **v1.0**: SEARCH_TOKEN & EXECUTE_SWAP (Current) - 270M params
975
- - 🔄 **v1.1**: Multi-hop swap protocol (Q1 2025) - Mobile optimization
976
- - 📋 **v1.2**: Quantized INT8 release (Q1 2025) - ~70MB for hardware wallets
977
- - 📋 **v2.0**: DeFi lending/staking protocols (Q2 2025) - Privacy-preserving
978
- - 📋 **v2.1**: NFT operations (Q3 2025) - On-device image understanding
979
- - 📋 **v3.0**: Cross-chain bridge intent recognition (Q4 2025)
980
-
981
- **Protocol Governance Principles:**
982
- - 🔐 **Privacy-First**: Every protocol must support local inference
983
- - 🔄 **Backwards Compatible**: Old versions always work
984
- - 🌐 **Community-Driven**: Open RFC process for new protocols
985
- - 📖 **Transparent**: All decisions documented and public
986
-
987
- ---
988
-
989
- ## 📞 Connect With Us
990
-
991
- <div align="center">
992
-
993
- ### Get Involved
994
-
995
- [![Discord](https://img.shields.io/badge/Discord-Join%20Community-7289da)](https://discord.gg/YOUR_DISCORD)
996
- [![Twitter](https://img.shields.io/badge/Twitter-Follow%20Updates-1da1f2)](https://twitter.com/YOUR_HANDLE)
997
- [![GitHub](https://img.shields.io/badge/GitHub-Star%20Project-181717)](https://github.com/YOUR_ORG/dmind-3-nano)
998
-
999
- **Model Hub**: [🤗 Hugging Face](https://huggingface.co/YOUR_ORG/DMind-3-nano)
1000
- **Discussions**: [HF Discussions](https://huggingface.co/YOUR_ORG/DMind-3-nano/discussions)
1001
- **Issues**: [GitHub Issues](https://github.com/YOUR_ORG/dmind-3-nano/issues)
1002
- **Email**: your-email@example.com
1003
-
1004
- </div>
1005
-
1006
- ---
1007
-
1008
- <div align="center">
1009
-
1010
- ### 🚀 Let's Build the Privacy-First Crypto AI Ecosystem
1011
 
1012
- **🔐 Privacy. 📱 Edge-First. 🤝 Standardized.**
 
 
 
1013
 
1014
- > *"In crypto, we trust code, not clouds. Your wallet intelligence should be as private as your keys."*
1015
 
1016
- ---
1017
 
1018
- **[📥 Download Model](https://huggingface.co/YOUR_ORG/DMind-3-nano)** **[📖 Protocol Spec](https://huggingface.co/YOUR_ORG/DMind-3-nano#tool-protocols-standardized)** **[💬 Community](https://discord.gg/YOUR_DISCORD)**
 
 
1019
 
1020
- ---
1021
 
1022
- *Made with ❤️ for the open-source and crypto communities*
1023
- *Empowering users with private, on-device AI since 2024*
 
 
1024
 
1025
- </div>
 
1
+ # DMind-3-nano: Privacy-First On-Device Crypto Intent Recognition
2
+
3
+ > Inference stays on your device. Standardized function calling for wallets, DEXs, and agents. Built on `google/functiongemma-270m-it`.
4
+
5
+ **Repo purpose:** host the open-source training/eval pipeline and release artifacts. Place your exported model under `./model` before pushing to Hugging Face.
6
+
7
+ ## HF Card Metadata
8
+
9
+ ```
10
  language:
11
  - en
12
  - zh
 
27
  - standard-protocol
28
  library_name: transformers
29
  pipeline_tag: text-generation
30
+ ```
31
 
32
+ ## Highlights
33
 
34
+ - 🔐 Privacy-first: 100% on-device intent recognition; no data leaves the device.
35
+ - 📱 Edge-optimized: 270M params; runs on phones/tablets/edge CPUs.
36
+ - 🔄 Standardized protocols: `SEARCH_TOKEN` / `EXECUTE_SWAP` with unified schemas.
37
+ - 🌐 Multi-chain: Solana, Ethereum, BSC, Base.
38
+ - 🌍 Multilingual: English + Chinese intents (Chinese samples kept in data/benchmarks).
39
+ - 🤖 Agent-native: designed for local-first wallet/agent workflows where a growing share of trading decisions and execution happen **on-device**.
40
+ - 📊 Training data: the final full fine-tune used **12,000+** samples in total; **LLM-generated data is only a subset**, and **60%+** of the data comes from **real trading scenarios**.
41
+ - 🧾 **(To our knowledge) first public vertical-domain FunctionGemma case study**: an end-to-end example of fine-tuning `google/functiongemma-270m-it` for a real wallet/DEX intent domain, including the practical training/evaluation pipeline and reproducible scripts.
42
 
43
+ ## Why This Matters for Web3 (Standardization as a Step-Change)
44
 
45
+ Web3 is composable at the protocol layer (tokens, RPCs), but still fragmented at the **intent layer**. Today every wallet, DEX, and agent framework invents its own “swap/search intent” schema and function-calling format. The result is high integration cost, brittle adapters, inconsistent safety guarantees, and poor ecosystem interoperability.
46
 
47
+ This work targets a transformative goal: **standardize wallet intents** as a small, versionable protocol between natural language and transaction builders. Concretely, DMind-3-nano enforces a minimal set of typed tools (e.g. `SEARCH_TOKEN`, `EXECUTE_SWAP`) with strict schemas and a deterministic wrapper output format.
 
48
 
49
+ What standardization unlocks:
50
 
51
+ - **Interoperability**: one protocol works across wallets/DEXs/agents; integrations become plug-and-play.
52
+ - **Safety & auditability**: tool calls are structured data—easy to validate, simulate, policy-check, and display for confirmation before signing.
53
+ - **Benchmarkability**: shared datasets and comparable evaluations across models and releases.
54
+ - **Ecosystem scaling**: new tools can be added via versioning without breaking existing clients.
55
 
56
+ In short, DMind-3-nano is not only a model—it is a proposal for a **standard protocol layer** that can make wallet intelligence as interoperable as ERC-20 made tokens.
57
 
58
+ ### The next wave: local agents executing trades
59
 
60
+ We expect a large share of future Web3 activity to be **agent-driven**: wallets will run local copilots that continuously parse user intent, monitor context, and propose/execute transactions. In that world, “cloud-only” intelligence becomes a bottleneck and a risk:
61
 
62
+ - **Privacy**: trading intent, token preferences, and behavioral signals should not be streamed to third-party servers.
63
+ - **Latency & reliability**: agents must work instantly and offline (mobile, hardware wallets, poor connectivity).
64
+ - **Security boundaries**: local agents can keep a tighter loop between intent → policy checks → simulation → user confirmation → signing.
65
 
66
+ This is why a small, high-accuracy **on-device function-calling model** is necessary infrastructure for the agent-native wallet era—and why standardizing the intent protocol matters even more when millions of agents need to speak the same language.
67
 
68
+ Equally important, this repository serves as a **public reference implementation** for applying FunctionGemma to a concrete vertical domain. By openly sharing fine-tuning details (data format, training configs, evaluation, and benchmarks), it lowers the barrier for the community to replicate, extend, and standardize on a common intent protocol.
69
 
70
+ ## Model Overview
71
 
72
  | Property | Value |
73
+ | --- | --- |
74
+ | Model | DMind-3-nano |
75
+ | Base | google/functiongemma-270m-it |
76
+ | Params | 270M |
77
+ | Context | 2048 |
78
+ | Precision | BF16 (train) |
79
+ | Best tokens | SOL, USDC, JUP, RAY, BONK, WIF, ETH, BTC, POPCAT, BOME, TRUMP |
80
+ | Chains | solana, ethereum, bsc, base |
81
+
82
+ **Experimental notice:** Highest accuracy on the token/chain set above; other assets may need further tuning. Validate outputs before transacting.
83
+
84
+ ## Repository Layout
85
+
86
+ - `model/` **(add your exported model here before publishing)**
87
+ - `src/` training/eval utilities
88
+ - `train.py` (LoRA or full fine-tune)
89
+ - `evaluate.py` (benchmark evaluation)
90
+ - `prepare_dataset.py` (SFT-ready formatting)
91
+ - `generate_benchmark.py` (100-case benchmark)
92
+ - `config.py` (tools, prompts, token maps)
93
+ - `data/` sample data
94
+ - `training_data.json` (raw; open-sourced subset for reproducibility)
95
+ - `benchmark_dataset.json` (eval set; includes Chinese test prompts by design)
96
+ - `results/evaluation_results.json` sample output
97
+ - `run_training.sh`, `requirements.txt`
98
+
99
+ ## Quick Start (Training & Eval)
100
+
101
+ Install:
102
+ ```bash
103
+ pip install -r requirements.txt
104
  ```
105
 
106
+ Train (LoRA default):
107
+ ```bash
108
+ python -m src.train \
109
+ --model_path /path/to/functiongemma-270m-it \
110
+ --dataset_path ./data/training_data.json \
111
+ --output_dir ./runs \
112
+ --bf16
113
  ```
114
+ Switch to full fine-tune: add `--no-use-lora`. Use `--use_4bit/--use_8bit` + `--gradient_checkpointing` for low memory.
115
 
116
+ Evaluate:
117
  ```bash
118
+ python -m src.evaluate \
119
+ --model_path ./runs/<run>/final_model \
120
+ --benchmark_path ./data/benchmark_dataset.json \
121
+ --output_path ./results/eval_$(date +%Y%m%d_%H%M%S).json
122
  ```
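For a quick offline sanity check without rerunning `src.evaluate`, exact-match comparison of wrapper strings is enough as a first pass. A minimal sketch with hypothetical field names (`expected`/`predicted` are illustrative; the actual schema is whatever `results/evaluation_results.json` contains):

```python
def exact_match_rate(cases: list[dict]) -> float:
    """Fraction of cases whose predicted call string equals the expected one."""
    if not cases:
        return 0.0
    hits = sum(1 for c in cases if c["predicted"].strip() == c["expected"].strip())
    return hits / len(cases)

# Two toy cases: one exact match, one argument mismatch
cases = [
    {"expected": '<start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>',
     "predicted": '<start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>'},
    {"expected": '<start_function_call>call:SEARCH_TOKEN{symbol:"JUP",chain:"solana"}<end_function_call>',
     "predicted": '<start_function_call>call:SEARCH_TOKEN{symbol:"RAY",chain:"solana"}<end_function_call>'},
]
print(exact_match_rate(cases))  # -> 0.5
```

Exact string match is stricter than the per-argument accuracy reported below, so treat it as a lower bound.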
123
 
124
+ Data utilities:
125
+ ```bash
126
+ # Prepare SFT data
127
+ python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json
128
+ # Regenerate benchmark
129
+ python -m src.generate_benchmark --output ./data/benchmark_dataset.json
130
  ```
131
 
132
+ Note: `data/prepared_dataset.json` is a **generated artifact** (optional) and is intentionally **not committed**.
 
 
133
 
134
+ ## Tool Protocols (Standardized)
135
 
136
+ **SEARCH_TOKEN** — search on-chain token info.
137
+ Params: `symbol`, `address`, `chain` (solana|ethereum|bsc|base), `keyword`.
138
 
139
+ **EXECUTE_SWAP** — execute token swap.
140
+ Params: `inputTokenCA`, `outputTokenCA`, `inputTokenAmount`, `inputTokenPercentage` (0-1), matching the tool schema in `src/config.py`.
141
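Because tool calls are structured data, swap arguments can be policy-checked before anything reaches a signer. A hedged sketch (field names follow the `EXECUTE_SWAP` schema in `src/config.py`; the specific checks are illustrative, not a complete policy):

```python
def validate_swap_args(args: dict) -> list[str]:
    """Return a list of validation errors for EXECUTE_SWAP arguments (empty list = OK)."""
    errors = []
    # Both contract addresses are required by the schema
    for field in ("inputTokenCA", "outputTokenCA"):
        if not args.get(field):
            errors.append(f"missing required field: {field}")
    # The swap must be sized by an amount or by a balance percentage
    amount = args.get("inputTokenAmount")
    pct = args.get("inputTokenPercentage")
    if amount is None and pct is None:
        errors.append("need inputTokenAmount or inputTokenPercentage")
    if pct is not None and not (0 < pct <= 1):
        errors.append("inputTokenPercentage must be in (0, 1]")
    return errors

# e.g. "sell half my JUP": percentage sizing, both CAs present
print(validate_swap_args({
    "inputTokenCA": "JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN",
    "outputTokenCA": "So11111111111111111111111111111111111111112",
    "inputTokenPercentage": 0.5,
}))  # -> []
```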
 
142
+ Output format:
143
+ ```
144
+ <start_function_call>call:FUNCTION_NAME{key1:val1,key2:val2}<end_function_call>
 
145
  ```
146
 
147
+ Example (Chinese input retained for coverage):
 
148
  ```
149
  User: "查一下 Solana 上的 SOL"
150
  Model: <start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>
151
  ```
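The wrapper is deterministic enough to parse with a few lines of stdlib Python. A minimal sketch (the `parse_call` helper is illustrative, not part of this repo; the key-quoting step assumes argument values never contain `identifier:` sequences):

```python
import json
import re

# Match the deterministic wrapper: <start_function_call>call:NAME{...}<end_function_call>
CALL_RE = re.compile(
    r"<start_function_call>call:([A-Z_]+)\{(.*)\}<end_function_call>", re.DOTALL
)

def parse_call(text: str):
    """Extract (function_name, args_dict) from model output, or None if no call."""
    m = CALL_RE.search(text)
    if m is None:
        return None
    name, body = m.group(1), m.group(2)
    # The body uses bare keys (key:"val"); quote them so it becomes valid JSON
    json_body = re.sub(r"([A-Za-z_][A-Za-z0-9_]*)\s*:", r'"\1":', body)
    return name, json.loads("{" + json_body + "}")

print(parse_call('<start_function_call>call:SEARCH_TOKEN{symbol:"SOL",chain:"solana"}<end_function_call>'))
# -> ('SEARCH_TOKEN', {'symbol': 'SOL', 'chain': 'solana'})
```

A production integration should validate the parsed arguments against the tool schema before acting on them.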
152
 
153
+ ## Performance Snapshot
154
 
155
+ - Function recognition: ~96.8% on the validated set
156
+ - Argument extraction: ~94.2%
157
+ - Protocol adherence: SEARCH_TOKEN 98.1%, EXECUTE_SWAP 95.3%
158
+ - Multi-turn success: ~92.7%
159
 
160
+ Scope: the tokens/chains listed in **Model Overview**; accuracy outside that set may be lower.
161
 
162
+ ## Deployment Notes
163
 
164
+ - On-device: convert to ONNX/CoreML/TFLite for mobile/hardware wallets; INT8 quantization stores roughly 1 byte per parameter (~270MB for 270M params), and 4-bit formats go smaller still.
165
+ - CPU-only: expect sub-500ms intent recognition on modern CPUs; GPUs are faster still.
166
+ - Keep Chinese benchmark samples intact (they are intentional test cases).
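The footprint figures above follow directly from bytes-per-parameter arithmetic. A back-of-the-envelope sketch (assumes 270M parameters and ignores activation memory, tokenizer files, and per-format metadata):

```python
PARAMS = 270_000_000  # parameter count of the base model

def weight_footprint_mb(bits_per_param: float) -> float:
    """Approximate weight storage in MB for a given numeric precision."""
    return PARAMS * bits_per_param / 8 / 1e6

for name, bits in [("BF16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"{name}: ~{weight_footprint_mb(bits):.0f} MB")
```

Actual exported files differ by format overhead, but the order of magnitude holds.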
167
 
168
+ ## License & Governance
169
 
170
+ - Code: MIT (`LICENSE`)
171
+ - Model weights: Apache-2.0 (as declared in the card metadata above)
172
+ - Protocol specs (SEARCH_TOKEN / EXECUTE_SWAP): public domain for maximal adoption
173
+ - Contributions are welcome via issues/PRs.
174
 
175
+ When publishing to Hugging Face, ensure `./model` contains your final weights/tokenizer. Replace `YOUR_ORG/DMind-3-nano`, badges, and links with your namespace before release.
data/benchmark_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
data/training_data.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88f5ee388ad60836142837fcc4790b7d53c2495763cb5627c1e4d78b9d58d9f9
3
+ size 15262245
requirements.txt ADDED
@@ -0,0 +1,36 @@
1
+ # FunctionGemma SFT LoRA dependencies
2
+
3
+ # PyTorch (pick the build that matches your CUDA)
4
+ torch>=2.0.0
5
+ torchvision
6
+ torchaudio
7
+
8
+ # Hugging Face stack
9
+ transformers>=4.40.0
10
+ datasets>=2.18.0
11
+ accelerate>=0.27.0
12
+ tokenizers>=0.15.0
13
+
14
+ # TRL (Transformer Reinforcement Learning)
15
+ trl>=0.8.0
16
+
17
+ # PEFT (Parameter-Efficient Fine-Tuning)
18
+ peft>=0.10.0
19
+
20
+ # Quantization support (QLoRA)
21
+ bitsandbytes>=0.43.0
22
+
23
+ # Logging & monitoring
24
+ tensorboard>=2.15.0
25
+ wandb>=0.16.0
26
+
27
+ # Utilities
28
+ sentencepiece>=0.2.0
29
+ protobuf>=4.25.0
30
+ tqdm>=4.66.0
31
+
32
+ # Flash Attention (optional; install separately)
33
+ # pip install flash-attn --no-build-isolation
34
+
35
+ # Evaluation
36
+ evaluate>=0.4.0
results/evaluation_results.json ADDED
The diff for this file is too large to render. See raw diff
 
run_training.sh ADDED
@@ -0,0 +1,73 @@
1
+ #!/bin/bash
2
+ # FunctionGemma SFT LoRA quickstart
3
+
4
+ # Environment
5
+ export CUDA_VISIBLE_DEVICES=0 # e.g. "0,1,2,3" for multi-GPU
6
+ export TOKENIZERS_PARALLELISM=false
7
+
8
+ # Model path (update to your local model location)
9
+ MODEL_PATH="/path/to/your/functiongemma-270m-it"
10
+
11
+ # Dataset path
12
+ DATASET_PATH="./data/training_data.json"
13
+
14
+ # Output directory
15
+ OUTPUT_DIR="./runs"
16
+
17
+ # Run name
18
+ RUN_NAME="functiongemma-lora-$(date +%Y%m%d_%H%M%S)"
19
+
20
+ echo "========================================"
21
+ echo "FunctionGemma SFT LoRA training"
22
+ echo "========================================"
23
+ echo "Model: $MODEL_PATH"
24
+ echo "Dataset: $DATASET_PATH"
25
+ echo "Output: $OUTPUT_DIR/$RUN_NAME"
26
+ echo "========================================"
27
+
28
+ # Option 1: Standard LoRA (recommended for most GPUs)
29
+ python -m src.train \
30
+ --model_path "$MODEL_PATH" \
31
+ --dataset_path "$DATASET_PATH" \
32
+ --output_dir "$OUTPUT_DIR" \
33
+ --run_name "$RUN_NAME" \
34
+ --lora_r 16 \
35
+ --lora_alpha 32 \
36
+ --lora_dropout 0.05 \
37
+ --num_train_epochs 3 \
38
+ --per_device_train_batch_size 4 \
39
+ --gradient_accumulation_steps 4 \
40
+ --learning_rate 5e-5 \
41
+ --warmup_ratio 0.1 \
42
+ --max_seq_length 2048 \
43
+ --bf16 \
44
+ --logging_steps 10 \
45
+ --save_steps 100 \
46
+ --eval_steps 100 \
47
+ --gradient_checkpointing
48
+
49
+ # Option 2: QLoRA (for smaller GPUs, uncomment to use)
50
+ # python -m src.train \
51
+ # --model_path "$MODEL_PATH" \
52
+ # --dataset_path "$DATASET_PATH" \
53
+ # --output_dir "$OUTPUT_DIR" \
54
+ # --run_name "$RUN_NAME-qlora" \
55
+ # --lora_r 16 \
56
+ # --lora_alpha 32 \
57
+ # --lora_dropout 0.05 \
58
+ # --num_train_epochs 3 \
59
+ # --per_device_train_batch_size 8 \
60
+ # --gradient_accumulation_steps 2 \
61
+ # --learning_rate 2e-4 \
62
+ # --warmup_ratio 0.1 \
63
+ # --max_seq_length 2048 \
64
+ # --use_4bit \
65
+ # --logging_steps 10 \
66
+ # --save_steps 100 \
67
+ # --eval_steps 100 \
68
+ # --gradient_checkpointing
69
+
70
+ echo "========================================"
71
+ echo "Training finished!"
72
+ echo "Model saved to: $OUTPUT_DIR/$RUN_NAME"
73
+ echo "========================================"
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,317 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Shared configuration.
4
+
5
+ Includes:
6
+ 1. System prompt templates
7
+ 2. Token mappings (symbol -> contract address)
8
+ 3. Tool definitions
9
+ 4. Supported chains
10
+ """
11
+
12
+ # ============================================================
13
+ # Token mappings
14
+ # ============================================================
15
+
16
+ # Tokens on Solana
17
+ SOLANA_TOKENS = {
18
+ # Native
19
+ "SOL": "So11111111111111111111111111111111111111112",
20
+
21
+ # Stablecoins
22
+ "USDC": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
23
+ "USDT": "Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB",
24
+
25
+ # Major tokens
26
+ "RAY": "4k3Dyjzvzp8eMZWUXbBCjEvwSkkk59S5iCNLY3QrkX6R",
27
+ "JUP": "JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN",
28
+ "BONK": "DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263",
29
+ "WIF": "EKpQGSJtjMFqKZ9KQanSqYXRcF8fBopzLHYxdM65zcjm",
30
+ "ORCA": "orcaEKTdK7LKz57vaAYr9QeNsVEPfiu6QeMU1kektZE",
31
+ "MNGO": "MangoCzJ36AjZyKwVj3VnYU4GTonjfVEnJmvvWaxLac",
32
+ "SRM": "SRMuApVNdxXokk5GT7XD5cUUgXMBCoAz2LHeuAoKWRt",
33
+ "STEP": "StepAscQoEioFxxWGnh2sLBDFp9d8rvKz2Yp39iDpyT",
34
+ "COPE": "8HGyAAB1yoM1ttS7pXjHMa3dukTFGQggnFFH3hJZgzQh",
35
+ "FIDA": "EchesyfXePKdLtoiZSL8pBe8Myagyy8ZRqsACNCFGnvp",
36
+ "MAPS": "MAPS41MDahZ9QdKXhVa4dWB9RuyfV4XqhyAZ8XcYepb",
37
+ "OXY": "z3dn17yLaGMKffVogeFHQ9zWVcXgqgf3PQnDsNs2g6M",
38
+ "MEDIA": "ETAtLmCmsoiEEKfNrHKJ2kYy3MoABhU6NQvpSfij5tDs",
39
+ "SLND": "SLNDpmoWTVADgEdndyvWzroNL7zSi1dF9PC3xHGtPwp",
40
+ "SAMO": "7xKXtg2CW87d97TXJSDpbD5jBkheTqA83TZRuJosgAsU",
41
+
42
+ # Meme tokens
43
+ "POPCAT": "7GCihgDB8fe6KNjn2MYtkzZcRjQy3t9GHdC8uHYmW2hr",
44
+ "TRUMP": "6p6xgHyF7AeE6TZkSmFsko444wqoP15icUSqi2jfGiPN",
45
+ "MELANIA": "FUAfBo2jgks6gB4Z4LfZkqSZgzNucisEHqnNebaRxM1P",
46
+ "PEPE": "CLoUDKc4Ane7HeQcPpE3YHnznRxhMimJ4MyaUqyHFzAu",
47
+ "DOGE": "9nEqaUcb16sQ3Tn1psbkWqyhPdLmfHWjKGymREjsAgTE",
48
+ "SHIB": "CiKu4eHsVrc1eueVQeHn7qhXTcVu95gSQmBpX4utjL9z",
49
+ "FLOKI": "FLKiaqMsaLwSZLqnCpSFHktMUQsXgcrhFN3TZg1Zj2EN",
50
+ "WOJAK": "8cMrxCkREwszByAj9hzn5qKLnLjvLDU7gPYtWQgfwSAS",
51
+ "GIGA": "63LfDmNb3MQ8mw9MtZ2To9bEA2M71kZUUGq5tiJxcqj9",
52
+ "BOME": "ukHH6c7mMyiWCf1b9pnWe25TSpkDDt3H5pQZgZ74J82",
53
+ "SLERF": "7BgBvyjrZX1YKz4oh9mjb8ZScatkkwb8DzFx7LoiVkM3",
54
+ "MEW": "MEW1gQWJ3nEXg2qgERiKu7FAFj79PHvQVREQUzScPP5",
55
+ "MYRO": "HhJpBhRRn4g56VsyLuT8DL5Bv31HkXqsrahTTUCZeZg4",
56
+ "PONKE": "5z3EqYQo9HiCEs3R84RCDMu2n7anpDMxRhdK8PSWmrRC",
57
+ "MOODENG": "ED5nyyWEzpPPiWimP8vYm7sD7TD3LAt3Q3gRTWHzPJBY",
58
+
59
+ # DeFi tokens
60
+ "PYTH": "HZ1JovNiVvGrGNiiYvEozEVgZ58xaU3RKwX8eACQBCt3",
61
+ "JTO": "jtojtomepa8beP8AuQc6eXt5FriJwfFMwQx2v2f9mCL",
62
+ "W": "85VBFQZC9TZkfaptBWjvUw7YbZjy52A6mjtPGjstQAmQ",
63
+ "RENDER": "rndrizKT3MK1iimdxRdWabcF7Zg7AR5T4nud4EkHBof",
64
+ "HNT": "hntyVP6YFm1Hg25TN9WGLqM12b8TQmcknKrdu1oxWux",
65
+ "INJ": "6McPRfPV6bY1e9hLxWyG54W9i9Epq75QBvXg2oetBVTB",
66
+
67
+ # Wrapped tokens
68
+ "WETH": "7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs",
69
+ "WBTC": "3NZ9JMVBmGAqocybic2c7LQCJScmgsAZ6vQqTDzcqmJh",
70
+ "ETH": "7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs", # same as WETH
71
+ "BTC": "3NZ9JMVBmGAqocybic2c7LQCJScmgsAZ6vQqTDzcqmJh", # same as WBTC
72
+ }
73
+
74
+ # Tokens on Ethereum
75
+ ETHEREUM_TOKENS = {
76
+ "ETH": "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE",
77
+ "WETH": "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2",
78
+ "USDC": "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48",
79
+ "USDT": "0xdAC17F958D2ee523a2206206994597C13D831ec7",
80
+ "DAI": "0x6B175474E89094C44Da98b954EedeAC495271d0F",
81
+ "WBTC": "0x2260FAC5E5542a773Aa44fBCfeDf7C193bc2C599",
82
+ "UNI": "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984",
83
+ "LINK": "0x514910771AF9Ca656af840dff83E8264EcF986CA",
84
+ "AAVE": "0x7Fc66500c84A76Ad7e9c93437bFc5Ac33E2DDaE9",
85
+ "PEPE": "0x6982508145454Ce325dDbE47a25d4ec3d2311933",
86
+ "SHIB": "0x95aD61b0a150d79219dCF64E1E6Cc01f0B64C4cE",
87
+ "TRUMP": "0x576e2BeD8F7b46D34016198911Cdf9886f78bea7",
88
+ }
89
+
90
+ # Tokens on BSC
91
+ BSC_TOKENS = {
92
+ "BNB": "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE",
93
+ "WBNB": "0xbb4CdB9CBd36B01bD1cBaEBF2De08d9173bc095c",
94
+ "USDC": "0x8AC76a51cc950d9822D68b83fE1Ad97B32Cd580d",
95
+ "USDT": "0x55d398326f99059fF775485246999027B3197955",
96
+ "BUSD": "0xe9e7CEA3DedcA5984780Bafc599bD69ADd087D56",
97
+ "CAKE": "0x0E09FaBB73Bd3Ade0a17ECC321fD13a19e81cE82",
98
+ "DOGE": "0xbA2aE424d960c26247Dd6c32edC70B295c744C43",
99
+ "SHIB": "0x2859e4544C4bB03966803b044A93563Bd2D0DD4D",
100
+ "FLOKI": "0xfb5B838b6cfEEdC2873aB27866079AC55363D37E",
101
+ }
102
+
103
+ # Tokens on Base
104
+ BASE_TOKENS = {
105
+ "ETH": "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE",
106
+ "WETH": "0x4200000000000000000000000000000000000006",
107
+ "USDC": "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913",
108
+ "DAI": "0x50c5725949A6F0c72E6C4a641F24049A917DB0Cb",
109
+ "BRETT": "0x532f27101965dd16442E59d40670FaF5eBB142E4",
110
+ "DEGEN": "0x4ed4E862860beD51a9570b96d89aF5E1B0Efefed",
111
+ "TOSHI": "0xAC1Bd2486aAf3B5C0fc3Fd868558b082a531B2B4",
112
+ }
113
+
114
+ # Token mappings for all chains
115
+ ALL_TOKENS = {
116
+ "solana": SOLANA_TOKENS,
117
+ "ethereum": ETHEREUM_TOKENS,
118
+ "bsc": BSC_TOKENS,
119
+ "base": BASE_TOKENS,
120
+ }
121
+
122
+ # Supported chains
123
+ SUPPORTED_CHAINS = ["solana", "ethereum", "bsc", "base"]
124
+
125
+ # Default chain
126
+ DEFAULT_CHAIN = "solana"
127
+
128
+
129
+ # ============================================================
130
+ # System prompt templates
131
+ # ============================================================
132
+
133
+ def get_system_prompt(chain: str = "solana") -> str:
134
+ """
135
+ Build a system prompt that contains the token mapping for a chain.
136
+
137
+ Args:
138
+ chain: blockchain name
139
+
140
+ Returns:
141
+ system prompt string
142
+ """
143
+ tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
144
+
145
+ # Build the token mapping string
146
+ token_list = []
147
+ for symbol, address in tokens.items():
148
+ token_list.append(f" - {symbol}: {address}")
149
+ token_mapping_str = "\n".join(token_list)
150
+
151
+ system_prompt = f"""You are a blockchain trading assistant that helps users search for tokens and execute swaps.
152
+
153
+ ## Available Tools
154
+
155
+ 1. **SEARCH_TOKEN**: Search for token information on the blockchain.
156
+ - Parameters: symbol, address, chain, keyword
157
+ - Use when user wants to find/query/search token information
158
+
159
+ 2. **EXECUTE_SWAP**: Execute a token swap on the blockchain.
160
+ - Parameters: inputTokenCA, outputTokenCA, inputTokenAmount, inputTokenPercentage
161
+ - Use when user wants to buy/sell/swap/convert tokens
162
+
163
+ ## Token Address Mapping ({chain.upper()})
164
+
165
+ {token_mapping_str}
166
+
167
+ ## Important Rules
168
+
169
+ 1. When user mentions a token symbol (like SOL, USDC, RAY), use the corresponding contract address from the mapping above.
170
+ 2. For buying tokens: inputTokenCA is what user spends (usually SOL/ETH), outputTokenCA is what user receives.
171
+ 3. For selling tokens: inputTokenCA is what user sells, outputTokenCA is usually SOL/ETH.
172
+ 4. If user specifies an amount, use inputTokenAmount. If user specifies a percentage (like "all", "50%", "half"), use inputTokenPercentage.
173
+ 5. inputTokenPercentage should be a decimal between 0 and 1 (e.g., 0.5 for 50%, 1.0 for 100%).
174
+ 6. If the user's request is unclear or missing required information, ask for clarification instead of making assumptions.
175
+ 7. Do NOT call any function if the request is unrelated to token search or swap (e.g., weather, greetings, jokes).
176
+
177
+ ## Examples
178
+
179
+ - "Buy 2 SOL of RAY" → EXECUTE_SWAP with inputTokenCA=SOL address, outputTokenCA=RAY address, inputTokenAmount="2"
180
+ - "Sell all my JUP" → EXECUTE_SWAP with inputTokenCA=JUP address, outputTokenCA=SOL address, inputTokenPercentage=1.0
181
+ - "Search for BONK" → SEARCH_TOKEN with symbol="BONK", chain="{chain}"
182
+ """
183
+
184
+ return system_prompt
185
+
186
+
187
+ # Shorter system prompt (used for training)
+ def get_system_prompt_short(chain: str = "solana") -> str:
+     """Build the concise system prompt."""
+     tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
+
+     # Only include common tokens
+     common_tokens = ["SOL", "USDC", "USDT", "RAY", "JUP", "BONK", "WIF", "TRUMP", "POPCAT", "PEPE", "ETH", "WETH", "BTC", "WBTC"]
+     token_list = []
+     for symbol in common_tokens:
+         if symbol in tokens:
+             token_list.append(f"{symbol}:{tokens[symbol]}")
+     token_mapping_str = ", ".join(token_list)
+
+     system_prompt = f"""You are a blockchain trading assistant. Use SEARCH_TOKEN to find tokens and EXECUTE_SWAP to trade tokens.
+
+ Token addresses ({chain}): {token_mapping_str}
+
+ Rules:
+ - For buy: inputTokenCA=payment token, outputTokenCA=target token
+ - For sell: inputTokenCA=token to sell, outputTokenCA=SOL/ETH
+ - Use inputTokenAmount for specific amounts, inputTokenPercentage (0-1) for percentages
+ - Ask for clarification if request is unclear
+ - Do NOT call functions for unrelated requests (weather, greetings, etc.)
+ """
+
+     return system_prompt
+
+
+ # ============================================================
+ # Tool definitions
+ # ============================================================
+
+ TOOLS = [
+     {
+         "type": "function",
+         "function": {
+             "name": "SEARCH_TOKEN",
+             "description": "Search for token information on the blockchain. Use this when user wants to find, query, or get information about a token.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "symbol": {
+                         "type": ["string", "null"],
+                         "description": "Symbol of the token (e.g., SOL, USDC, RAY)"
+                     },
+                     "address": {
+                         "type": ["string", "null"],
+                         "description": "Contract address of the token"
+                     },
+                     "chain": {
+                         "type": "string",
+                         "enum": ["solana", "ethereum", "bsc", "base"],
+                         "description": "Blockchain to search on"
+                     },
+                     "keyword": {
+                         "type": ["string", "null"],
+                         "description": "Keyword to search for the token"
+                     }
+                 },
+                 "required": []
+             }
+         }
+     },
+     {
+         "type": "function",
+         "function": {
+             "name": "EXECUTE_SWAP",
+             "description": "Execute a token swap on the blockchain. Use this when user wants to buy, sell, swap, or convert tokens.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "inputTokenCA": {
+                         "type": ["string", "null"],
+                         "description": "Contract address of the token to spend/sell"
+                     },
+                     "outputTokenCA": {
+                         "type": ["string", "null"],
+                         "description": "Contract address of the token to receive/buy"
+                     },
+                     "inputTokenAmount": {
+                         "type": ["string", "null"],
+                         "description": "Exact amount of input token to swap"
+                     },
+                     "inputTokenPercentage": {
+                         "type": ["number", "null"],
+                         "description": "Percentage of input token balance to swap (0-1, e.g., 0.5 for 50%)"
+                     }
+                 },
+                 "required": ["inputTokenCA", "outputTokenCA"]
+             }
+         }
+     }
+ ]
+
+
+ # ============================================================
+ # Helper functions
+ # ============================================================
+
+ def get_token_address(symbol: str, chain: str = "solana") -> str:
+     """Get token address by symbol."""
+     tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
+     return tokens.get(symbol.upper(), None)
+
+
+ def get_native_token(chain: str = "solana") -> tuple:
+     """Get the native token (symbol, address) for a chain."""
+     native_tokens = {
+         "solana": ("SOL", SOLANA_TOKENS["SOL"]),
+         "ethereum": ("ETH", ETHEREUM_TOKENS["ETH"]),
+         "bsc": ("BNB", BSC_TOKENS["BNB"]),
+         "base": ("ETH", BASE_TOKENS["ETH"]),
+     }
+     return native_tokens.get(chain, ("SOL", SOLANA_TOKENS["SOL"]))
+
+
+ def get_all_token_symbols(chain: str = "solana") -> list:
+     """Return all token symbols on a chain."""
+     tokens = ALL_TOKENS.get(chain, SOLANA_TOKENS)
+     return list(tokens.keys())
+
+
+ if __name__ == "__main__":
+     # Quick self-test
+     print("=== System Prompt (Full) ===")
+     print(get_system_prompt("solana")[:500] + "...")
+     print("\n=== System Prompt (Short) ===")
+     print(get_system_prompt_short("solana"))
+     print("\n=== Token Address ===")
+     print(f"RAY: {get_token_address('RAY', 'solana')}")
+     print(f"TRUMP: {get_token_address('TRUMP', 'solana')}")
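The short system prompt above embeds a whitelist-filtered `SYMBOL:address` mapping. A minimal standalone sketch of that step (the two-entry `SOLANA_TOKENS` dict and the helper name here are stand-ins, not the real table in `src/config.py`):

```python
# Stand-in for the token table defined in src/config.py (two entries only).
SOLANA_TOKENS = {
    "SOL": "So11111111111111111111111111111111111111112",
    "USDC": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
}

def build_token_mapping(tokens: dict, allow: list) -> str:
    """Keep only whitelisted symbols present in the table, rendered as 'SYMBOL:address'."""
    return ", ".join(f"{s}:{tokens[s]}" for s in allow if s in tokens)

# BONK is in the whitelist but not in the table, so it is silently skipped.
mapping = build_token_mapping(SOLANA_TOKENS, ["SOL", "USDC", "BONK"])
print(mapping)
```

Symbols missing from the table are dropped rather than raising, which is why the whitelist can safely include tokens that only exist on some chains.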
src/evaluate.py ADDED
@@ -0,0 +1,641 @@
+ #!/usr/bin/env python3
+ """
+ FunctionGemma evaluation script (v2).
+
+ Uses a unified system prompt for evaluation.
+
+ Usage:
+     python -m src.evaluate --model_path ./runs/<run>/final_model --benchmark_path ./data/benchmark_dataset.json
+ """
+
+ import os
+ import re
+ import sys
+ import json
+ import argparse
+ import logging
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+ from datetime import datetime
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from threading import Lock
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ from tqdm import tqdm
+
+ # Import config
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ if str(PROJECT_ROOT) not in sys.path:
+     sys.path.insert(0, str(PROJECT_ROOT))
+
+ DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
+ DEFAULT_RESULTS_DIR = PROJECT_ROOT / "results"
+
+ from src.config import (  # noqa: E402
+     get_system_prompt, get_system_prompt_short, TOOLS,
+     SOLANA_TOKENS, get_token_address
+ )
+
+ # Logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def load_model(
+     model_path: str,
+     lora_path: Optional[str] = None,
+     device: str = "auto",
+     load_in_8bit: bool = False,
+     load_in_4bit: bool = False,
+ ):
+     """Load model and tokenizer."""
+     logger.info(f"Loading model: {model_path}")
+
+     kwargs = {
+         "device_map": device,
+         "trust_remote_code": True,
+     }
+
+     if load_in_8bit:
+         kwargs["load_in_8bit"] = True
+     elif load_in_4bit:
+         from transformers import BitsAndBytesConfig
+         kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+     else:
+         kwargs["torch_dtype"] = torch.bfloat16
+
+     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+
+     if lora_path:
+         logger.info(f"Loading LoRA adapter: {lora_path}")
+         model = PeftModel.from_pretrained(model, lora_path)
+
+     model.eval()
+     return model, tokenizer
+
+
+ def parse_functiongemma_output(response: str) -> Tuple[Optional[str], Optional[Dict]]:
+     """
+     Parse FunctionGemma formatted output.
+
+     Format: <start_function_call>call:FUNC_NAME{key:<escape>value<escape>,...}<end_function_call>
+     """
+     # full match
+     pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>'
+     match = re.search(pattern, response)
+
+     if not match:
+         # partial match (truncated)
+         pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}'
+         match = re.search(pattern, response)
+
+     if not match:
+         # match function name only
+         pattern = r'<start_function_call>call:(\w+)'
+         match = re.search(pattern, response)
+         if match:
+             return match.group(1), {}
+
+         # fallback: look for function names
+         for func in ["SEARCH_TOKEN", "EXECUTE_SWAP"]:
+             if func in response:
+                 return func, {}
+
+         return None, None
+
+     func_name = match.group(1)
+     params_str = match.group(2) if len(match.groups()) > 1 else ""
+
+     # parse arguments
+     args = parse_params_string(params_str)
+
+     return func_name, args
+
+
+ def parse_params_string(params_str: str) -> Dict:
+     """Parse parameter string."""
+     args = {}
+     if not params_str:
+         return args
+
+     # pattern: key:<escape>value<escape> or key:value
+     param_pattern = r'(\w+):(?:<escape>([^<]*)<escape>|([^,}]+))'
+
+     for match in re.finditer(param_pattern, params_str):
+         key = match.group(1)
+         value = match.group(2) if match.group(2) is not None else match.group(3)
+
+         if value is None:
+             continue
+
+         value = value.strip()
+
+         # handle percentage
+         if value.endswith('%'):
+             try:
+                 args[key] = float(value[:-1]) / 100
+                 continue
+             except ValueError:
+                 pass
+
+         # attempt numeric conversion
+         try:
+             if '.' in value:
+                 args[key] = float(value)
+             else:
+                 args[key] = int(value)
+         except ValueError:
+             args[key] = value
+
+     return args
+
+
+ def is_rejection_response(response: str) -> bool:
+     """Check if the response is a rejection/clarification."""
+     # no function call markers
+     if '<start_function_call>' not in response:
+         return True
+
+     # check clarification/rejection keywords (keep Chinese variants for CN prompts)
+     rejection_keywords = [
+         "please specify", "could you", "what token", "which token",
+         "请问", "请提供", "请告诉", "您能", "什么代币", "哪个代币",
+         "sorry", "can't", "cannot", "unable", "抱歉", "无法",
+         "more information", "more details", "更多信息",
+     ]
+
+     response_lower = response.lower()
+     for keyword in rejection_keywords:
+         if keyword.lower() in response_lower:
+             return True
+
+     return False
+
+
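The two regexes above do the heavy lifting: one captures the function name and brace-delimited argument string, the other splits `key:<escape>value<escape>` pairs. A self-contained sketch of just that extraction path (simplified: no truncated-output fallback, no numeric coercion):

```python
import re

# Same patterns as src/evaluate.py, reduced to the happy path.
CALL_RE = re.compile(r'<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>')
PARAM_RE = re.compile(r'(\w+):(?:<escape>([^<]*)<escape>|([^,}]+))')

def parse_call(response: str):
    """Return (function_name, args) or (None, None) when no call is found."""
    m = CALL_RE.search(response)
    if not m:
        return None, None
    args = {}
    for pm in PARAM_RE.finditer(m.group(2)):
        # group 2 holds <escape>-wrapped values, group 3 bare values
        value = pm.group(2) if pm.group(2) is not None else pm.group(3)
        args[pm.group(1)] = value.strip()
    return m.group(1), args

name, args = parse_call(
    "<start_function_call>call:SEARCH_TOKEN"
    "{symbol:<escape>BONK<escape>,chain:<escape>solana<escape>}<end_function_call>"
)
print(name, args)  # SEARCH_TOKEN {'symbol': 'BONK', 'chain': 'solana'}
```

Because `[^}]*` is non-greedy by construction (it cannot cross the closing brace), nested braces in values are not supported; the full parser inherits the same limitation.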
+ def format_messages_for_model(
+     messages: List[Dict],
+     tokenizer,
+     tools: List[Dict] = None,
+ ) -> str:
+     """Format messages into the model chat template."""
+     if hasattr(tokenizer, 'apply_chat_template'):
+         try:
+             return tokenizer.apply_chat_template(
+                 messages,
+                 tools=tools,
+                 tokenize=False,
+                 add_generation_prompt=True,
+             )
+         except Exception:
+             pass
+
+     # Manual formatting fallback
+     formatted = ""
+     for msg in messages:
+         role = msg["role"]
+         content = msg["content"]
+
+         if role == "system":
+             formatted += f"<start_of_turn>system\n{content}<end_of_turn>\n"
+         elif role == "user":
+             formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
+         elif role == "assistant":
+             formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
+
+     formatted += "<start_of_turn>model\n"
+     return formatted
+
+
+ def generate_response(
+     model,
+     tokenizer,
+     prompt: str,
+     system_prompt: str,
+     max_new_tokens: int = 256,
+ ) -> str:
+     """Generate model response."""
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": prompt},
+     ]
+
+     input_text = format_messages_for_model(messages, tokenizer, TOOLS)
+     inputs = tokenizer(input_text, return_tensors="pt")
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=0.1,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
+     response = response.replace("<end_of_turn>", "").strip()
+
+     return response
+
+
+ def compare_arguments(expected: Dict, actual: Dict) -> Tuple[float, List[str]]:
+     """Compare expected vs actual arguments."""
+     if not expected:
+         return 1.0 if not actual else 0.0, []
+
+     if not actual:
+         return 0.0, ["No arguments extracted"]
+
+     errors = []
+     total_keys = set(expected.keys()) | set(actual.keys())
+
+     if not total_keys:
+         return 1.0, []
+
+     matched = 0
+
+     for key in expected.keys():
+         exp_val = expected.get(key)
+         act_val = actual.get(key)
+
+         if exp_val is None:
+             continue
+
+         if act_val is None:
+             errors.append(f"Missing key: {key}")
+             continue
+
+         # Compare values
+         if str(exp_val) == str(act_val):
+             matched += 1
+         elif isinstance(exp_val, str) and isinstance(act_val, str):
+             # Partial match (contract address prefix)
+             if exp_val[:10] == act_val[:10]:
+                 matched += 0.5
+                 errors.append(f"Partial match for {key}")
+             else:
+                 errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
+         elif isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
+             if abs(float(exp_val) - float(act_val)) < 0.01:
+                 matched += 1
+             else:
+                 errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
+         else:
+             errors.append(f"Type mismatch for {key}")
+
+     # Check extra keys
+     for key in actual.keys():
+         if key not in expected:
+             errors.append(f"Extra key: {key}")
+
+     score = matched / len([k for k in expected.keys() if expected.get(k) is not None]) if expected else 1.0
+     return score, errors
+
+
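The scoring rule in `compare_arguments` is worth seeing in isolation: exact string matches score 1, numbers within 0.01 count as equal, and strings sharing a 10-character prefix (contract addresses) earn half credit. A condensed sketch of just the score, without the error list (`score_args` is a hypothetical name for this excerpt):

```python
def score_args(expected: dict, actual: dict) -> float:
    """Fraction of non-null expected keys matched, with partial credit."""
    keys = [k for k, v in expected.items() if v is not None]
    if not keys:
        return 1.0
    matched = 0.0
    for k in keys:
        exp, act = expected[k], actual.get(k)
        if act is None:
            continue  # missing key earns nothing
        if str(exp) == str(act):
            matched += 1
        elif isinstance(exp, str) and isinstance(act, str) and exp[:10] == act[:10]:
            matched += 0.5  # shared address prefix: half credit
        elif isinstance(exp, (int, float)) and isinstance(act, (int, float)) \
                and abs(float(exp) - float(act)) < 0.01:
            matched += 1  # numeric tolerance
    return matched / len(keys)

print(score_args({"chain": "solana", "amount": 1.0},
                 {"chain": "solana", "amount": 1.005}))  # within tolerance -> 1.0
```

Note that extra keys in `actual` are reported as errors by the full function but do not lower the score, since the denominator counts only expected keys.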
+ def process_single_sample(
+     sample: Dict,
+     idx: int,
+     model,
+     tokenizer,
+     system_prompt: str,
+ ) -> Dict:
+     """Process one sample and return evaluation result."""
+     sample_id = sample.get("id", idx + 1)
+     category = sample.get("category", "unknown")
+     user_input = sample["input"]
+     expected_func = sample["expected"]["function_name"]
+     expected_args = sample["expected"].get("arguments", {})
+
+     # Extract user message
+     if isinstance(user_input, dict) and "messages" in user_input:
+         prompt = ""
+         for msg in user_input["messages"]:
+             if msg.get("role") == "user":
+                 prompt = msg.get("content", "")
+                 break
+     else:
+         prompt = str(user_input)
+
+     # Generate response
+     response = generate_response(model, tokenizer, prompt, system_prompt)
+
+     # Parse response
+     actual_func, actual_args = parse_functiongemma_output(response)
+     is_rejection = is_rejection_response(response)
+
+     # Evaluate
+     func_correct = False
+     args_correct = False
+     exact_match = False
+     arg_score = 0.0
+     error_msg = None
+     rejection_correct = False
+
+     if expected_func is None:
+         # Expecting rejection
+         func_correct = is_rejection or actual_func is None
+         args_correct = func_correct
+         exact_match = func_correct
+         arg_score = 1.0 if func_correct else 0.0
+         rejection_correct = func_correct
+
+         if not func_correct:
+             error_msg = f"Expected rejection, got {actual_func}"
+     else:
+         # Expecting a function call
+         func_correct = actual_func == expected_func
+
+         if func_correct:
+             # Compare arguments
+             arg_score, arg_errors = compare_arguments(expected_args, actual_args or {})
+             args_correct = arg_score >= 0.99
+             exact_match = args_correct
+
+             if not args_correct:
+                 error_msg = "; ".join(arg_errors)
+         else:
+             error_msg = f"Expected {expected_func}, got {actual_func}"
+
+     # Return result
+     result = {
+         "sample_id": sample_id,
+         "category": category,
+         "expected_func": expected_func,
+         "actual_func": actual_func,
+         "func_correct": func_correct,
+         "args_correct": args_correct,
+         "exact_match": exact_match,
+         "rejection_correct": rejection_correct,
+         "arg_score": arg_score,
+         "error_msg": error_msg,
+         "user_input": user_input,
+         "expected_args": expected_args,
+         "actual_args": actual_args,
+         "response": response,
+     }
+
+     return result
+
+
+ def evaluate_benchmark(
+     model,
+     tokenizer,
+     benchmark: List[Dict],
+     chain: str = "solana",
+     verbose: bool = False,
+     num_workers: int = 1,
+ ) -> Dict:
+     """Evaluate the benchmark (supports concurrency)."""
+     system_prompt = get_system_prompt_short(chain)
+
+     results = {
+         "total": len(benchmark),
+         "function_correct": 0,
+         "arguments_correct": 0,
+         "exact_match": 0,
+         "rejection_correct": 0,
+         "total_arg_score": 0.0,
+         "by_category": {},
+         "by_function": {},
+         "errors": [],
+     }
+
+     # Protect result updates with a lock
+     results_lock = Lock()
+
+     # Concurrent processing
+     if num_workers > 1:
+         logger.info(f"Evaluating with {num_workers} worker threads")
+
+         with ThreadPoolExecutor(max_workers=num_workers) as executor:
+             # Submit tasks
+             futures = {
+                 executor.submit(
+                     process_single_sample,
+                     sample, i, model, tokenizer, system_prompt
+                 ): i for i, sample in enumerate(benchmark)
+             }
+
+             # Progress bar
+             with tqdm(total=len(benchmark), desc="Evaluation") as pbar:
+                 for future in as_completed(futures):
+                     result = future.result()
+
+                     # Update results (locked)
+                     with results_lock:
+                         _update_results(results, result, verbose)
+
+                     pbar.update(1)
+     else:
+         # Serial path
+         logger.info("Evaluating with a single thread")
+         for i, sample in enumerate(tqdm(benchmark, desc="Evaluation")):
+             result = process_single_sample(sample, i, model, tokenizer, system_prompt)
+             _update_results(results, result, verbose)
+
+     return results
+
+
+ def _update_results(results: Dict, result: Dict, verbose: bool):
+     """Update aggregated evaluation results."""
+     sample_id = result["sample_id"]
+     category = result["category"]
+     expected_func = result["expected_func"]
+     actual_func = result["actual_func"]
+     func_correct = result["func_correct"]
+     args_correct = result["args_correct"]
+     exact_match = result["exact_match"]
+     rejection_correct = result["rejection_correct"]
+     arg_score = result["arg_score"]
+     error_msg = result["error_msg"]
+
+     # Overall stats
+     if func_correct:
+         results["function_correct"] += 1
+     if args_correct:
+         results["arguments_correct"] += 1
+     if exact_match:
+         results["exact_match"] += 1
+     if rejection_correct:
+         results["rejection_correct"] += 1
+     results["total_arg_score"] += arg_score
+
+     # By category
+     if category not in results["by_category"]:
+         results["by_category"][category] = {
+             "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
+         }
+     results["by_category"][category]["total"] += 1
+     if func_correct:
+         results["by_category"][category]["func_correct"] += 1
+     if exact_match:
+         results["by_category"][category]["exact_match"] += 1
+     results["by_category"][category]["arg_score"] += arg_score
+
+     # By function
+     func_key = expected_func or "None"
+     if func_key not in results["by_function"]:
+         results["by_function"][func_key] = {
+             "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
+         }
+     results["by_function"][func_key]["total"] += 1
+     if func_correct:
+         results["by_function"][func_key]["func_correct"] += 1
+     if exact_match:
+         results["by_function"][func_key]["exact_match"] += 1
+     results["by_function"][func_key]["arg_score"] += arg_score
+
+     # Record errors
+     if error_msg and len(results["errors"]) < 10:
+         results["errors"].append({
+             "id": sample_id,
+             "category": category,
+             "input": result["user_input"],
+             "expected_func": expected_func,
+             "actual_func": actual_func,
+             "expected_args": result["expected_args"],
+             "actual_args": result["actual_args"],
+             "error": error_msg,
+             "response": result["response"][:200],
+         })
+
+     if verbose:
+         status = "✓" if exact_match else "✗"
+         # Extract user message preview for logs
+         user_input = result["user_input"]
+         if isinstance(user_input, dict):
+             user_msg = ""
+             if "messages" in user_input:
+                 for msg in user_input["messages"]:
+                     if msg.get("role") == "user":
+                         user_msg = msg.get("content", "")
+                         break
+             input_preview = user_msg[:50] if user_msg else str(user_input)[:50]
+         else:
+             input_preview = str(user_input)[:50]
+         logger.info(f"[{sample_id}] {status} {category}: {input_preview}...")
+
+
+ def print_report(results: Dict):
+     """Print evaluation report."""
+     total = results["total"]
+
+     print("\n" + "=" * 70)
+     print("FunctionGemma Evaluation Report")
+     print("=" * 70)
+     print(f"\nTotal samples: {total}")
+
+     print("\n" + "-" * 70)
+     print("Overall metrics")
+     print("-" * 70)
+
+     func_acc = results["function_correct"] / total * 100 if total > 0 else 0
+     arg_acc = results["arguments_correct"] / total * 100 if total > 0 else 0
+     exact_acc = results["exact_match"] / total * 100 if total > 0 else 0
+     avg_arg_score = results["total_arg_score"] / total * 100 if total > 0 else 0
+
+     # Rejection accuracy (over samples whose expected function is None)
+     rejection_total = results["by_function"].get("None", {}).get("total", 0)
+     rejection_acc = results["rejection_correct"] / rejection_total * 100 if rejection_total > 0 else 0
+
+     print(f"Function selection accuracy: {func_acc:.2f}%")
+     print(f"Argument accuracy: {arg_acc:.2f}%")
+     print(f"Exact match accuracy: {exact_acc:.2f}%")
+     print(f"Average argument score: {avg_arg_score:.2f}%")
+     print(f"Rejection accuracy: {rejection_acc:.2f}%")
+
+     print("\n" + "-" * 70)
+     print("By function")
+     print("-" * 70)
+
+     for func, stats in sorted(results["by_function"].items()):
+         func_total = stats["total"]
+         func_correct = stats["func_correct"] / func_total * 100 if func_total > 0 else 0
+         func_arg_score = stats["arg_score"] / func_total * 100 if func_total > 0 else 0
+         func_exact = stats["exact_match"] / func_total * 100 if func_total > 0 else 0
+
+         print(f"{func:15} | samples: {func_total:3} | func acc: {func_correct:6.2f}% | "
+               f"arg score: {func_arg_score:6.2f}% | exact: {func_exact:6.2f}%")
+
+     if results["errors"]:
+         print("\n" + "-" * 70)
+         print("Error samples")
+         print("-" * 70)
+
+         for err in results["errors"][:5]:
+             print(f"\nID: {err['id']} | category: {err['category']}")
+             print(f"Input: {err['input']}")
+             print(f"Expected: {err['expected_func']} | Actual: {err['actual_func']}")
+             print(f"Error: {err['error']}")
+
+     print("\n" + "=" * 70)
+
+
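Every percentage in the report is the same guarded division repeated inline. As a sketch, the pattern reduces to one hypothetical helper (not in the source, which writes the expression out each time):

```python
def pct(numer: int, denom: int) -> float:
    """Percentage with a zero-division guard, as used throughout print_report."""
    return numer / denom * 100 if denom > 0 else 0.0

print(pct(50, 200))  # 25.0
print(pct(0, 0))     # 0.0 (empty category, no crash)
```

The guard matters for the rejection metric in particular, since a benchmark with no rejection samples would otherwise divide by zero.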
+ def main():
+     parser = argparse.ArgumentParser(description="FunctionGemma evaluation (v2)")
+     parser.add_argument("--model_path", type=str, required=True, help="Model path")
+     parser.add_argument("--lora_path", type=str, default=None, help="LoRA adapter path")
+     parser.add_argument("--benchmark_path", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Benchmark dataset path")
+     parser.add_argument("--output_path", type=str, default=None, help="Output path (defaults to results/ with timestamp)")
+     parser.add_argument("--chain", type=str, default="solana", help="Chain name")
+     parser.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization")
+     parser.add_argument("--load_in_4bit", action="store_true", help="Enable 4-bit quantization")
+     parser.add_argument("--verbose", action="store_true", help="Verbose logging")
+     parser.add_argument("--num_workers", type=int, default=4, help="Number of worker threads (default 4)")
+     args = parser.parse_args()
+
+     # Load model
+     model, tokenizer = load_model(
+         args.model_path,
+         lora_path=args.lora_path,
+         load_in_8bit=args.load_in_8bit,
+         load_in_4bit=args.load_in_4bit,
+     )
+
+     # Load benchmark
+     benchmark_path = Path(args.benchmark_path)
+     logger.info(f"Loading benchmark: {benchmark_path}")
+     with open(benchmark_path, 'r', encoding='utf-8') as f:
+         benchmark = json.load(f)
+
+     logger.info(f"Benchmark samples: {len(benchmark)}")
+
+     # Evaluate
+     logger.info("Starting evaluation...")
+     results = evaluate_benchmark(
+         model, tokenizer, benchmark,
+         chain=args.chain,
+         verbose=args.verbose,
+         num_workers=args.num_workers,
+     )
+
+     # Print report
+     print_report(results)
+
+     # Save results
+     output_path = Path(args.output_path) if args.output_path else DEFAULT_RESULTS_DIR / f"evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(results, f, ensure_ascii=False, indent=2)
+     logger.info(f"Evaluation saved to: {output_path}")
+
+
+ if __name__ == "__main__":
+     main()
src/generate_benchmark.py ADDED
@@ -0,0 +1,459 @@
+ #!/usr/bin/env python3
+ """
+ Generate the FunctionGemma evaluation benchmark.
+
+ Creates 100 high-quality samples to assess function-calling accuracy across:
+ - SEARCH_TOKEN calls
+ - EXECUTE_SWAP calls
+ - Incomplete requests (should ask back)
+ - Irrelevant requests (should refuse)
+ """
+
+ import json
+ import random
+ import argparse
+ from pathlib import Path
+ from typing import Dict, List, Any, Optional
+
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
+
+ # Token info
+ TOKENS = {
+     "SOL": {"ca": "So11111111111111111111111111111111111111112", "chain": "solana"},
+     "USDC": {"ca": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", "chain": "solana"},
+     "JUP": {"ca": "JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN", "chain": "solana"},
+     "RAY": {"ca": "4k3Dyjzvzp8eMZWUXbBCjEvwSkkk59S5iCNLY3QrkX6R", "chain": "solana"},
+     "BONK": {"ca": "DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263", "chain": "solana"},
+     "WIF": {"ca": "EKpQGSJtjMFqKZ9KQanSqYXRcF8fBopzLHYxdM65zcjm", "chain": "solana"},
+     "ETH": {"ca": "7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs", "chain": "solana"},
+     "BTC": {"ca": "9n4nbM75f5Ui33ZbPYXn59EwSgE8CGsHtAeTH5YFeJ9E", "chain": "solana"},
+     "POPCAT": {"ca": "7GCihgDB8fe6KNjn2MYtkzZcRjQy3t9GHdC8uHYmW2hr", "chain": "solana"},
+     "TRUMP": {"ca": "6p6xgHyF7AeE6TZkSmFsko444wqoP15icUSqi2jfGiPN", "chain": "solana"},
+ }
+
+ CHAINS = ["solana", "ethereum", "bsc", "base"]
+
+ # Tool definitions
+ TOOLS = [
+     {
+         "type": "function",
+         "function": {
+             "name": "SEARCH_TOKEN",
+             "description": "search token onchain",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "symbol": {"type": ["string", "null"], "description": "Symbol of the token"},
+                     "address": {"type": ["string", "null"], "description": "Contract address of the token"},
+                     "chain": {"type": "string", "enum": ["solana", "ethereum", "bsc", "base"], "description": "supported chains"},
+                     "keyword": {"type": ["string", "null"], "description": "keyword to search for the token"}
+                 },
+                 "required": []
+             }
+         }
+     },
+     {
+         "type": "function",
+         "function": {
+             "name": "EXECUTE_SWAP",
+             "description": "Swap tokens on the Solana blockchain. When the user specifies 'buy <token>', the default input token is SOL. When the user specifies 'sell <token>', the default output token is SOL.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "inputTokenSymbol": {"type": ["string", "null"], "description": "Symbol of the token to sell."},
+                     "inputTokenCA": {"type": ["string", "null"], "description": "Contract address of the token to sell."},
+                     "outputTokenCA": {"type": ["string", "null"], "description": "Contract address of the token to buy."},
+                     "inputTokenAmount": {"type": ["string", "null"], "description": "Exact amount of the input token to swap."},
+                     "inputTokenPercentage": {"type": ["number", "null"], "description": "Percentage of the input token balance to swap."},
+                     "outputTokenAmount": {"type": ["string", "null"], "description": "Expected amount of the output token to receive."}
+                 },
+                 "required": ["inputTokenCA", "outputTokenCA", "inputTokenAmount", "inputTokenPercentage"]
+             }
+         }
+     }
+ ]
+
+
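One useful sanity check on generated samples is that every expected argument key is actually declared in the tool schema above. This is a hypothetical standalone check, not part of the source script (the property set below is copied from the SEARCH_TOKEN definition):

```python
# Declared properties of the SEARCH_TOKEN tool defined above.
SEARCH_TOKEN_PROPS = {"symbol", "address", "chain", "keyword"}

def args_match_schema(args: dict, props: set) -> bool:
    """True when every argument key is a declared schema property."""
    return set(args) <= props

print(args_match_schema({"symbol": "BONK", "chain": "solana"}, SEARCH_TOKEN_PROPS))  # True
print(args_match_schema({"token": "BONK"}, SEARCH_TOKEN_PROPS))                      # False
```

A check like this would catch a benchmark case written against a stale schema before it silently deflates the argument-accuracy metric.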
78
+ def create_benchmark_item(
79
+ user_input: str,
80
+ expected_function: Optional[str],
81
+ expected_args: Optional[Dict] = None,
82
+ category: str = "function_call",
83
+ description: str = ""
84
+ ) -> Dict:
85
+ """Create one benchmark sample."""
86
+ return {
87
+ "id": None, # assigned later
88
+ "category": category,
89
+ "description": description,
90
+ "input": {
91
+ "messages": [
92
+ {"role": "developer", "content": "You are a model that can do function calling with the following functions"},
93
+ {"role": "user", "content": user_input}
94
+ ],
95
+ "tools": TOOLS
96
+ },
97
+ "expected": {
98
+ "function_name": expected_function,
99
+ "arguments": expected_args
100
+ }
101
+ }
102
+
103
+
104
+ def generate_search_token_benchmarks() -> List[Dict]:
105
+ """Generate SEARCH_TOKEN cases."""
106
+ benchmarks = []
107
+
108
+     # 1) search by symbol (English)
+     test_cases = [
+         ("Search for BONK token", "BONK", "solana", None, None),
+         ("Find WIF on solana", "WIF", "solana", None, None),
+         ("Look up JUP token", "JUP", "solana", None, None),
+         ("Search ETH on ethereum", "ETH", "ethereum", None, None),
+         ("Find USDC token on base", "USDC", "base", None, None),
+     ]
+
+     for query, symbol, chain, address, keyword in test_cases:
+         expected_args = {"symbol": symbol, "chain": chain}
+         if address:
+             expected_args["address"] = address
+         if keyword:
+             expected_args["keyword"] = keyword
+         benchmarks.append(create_benchmark_item(
+             query, "SEARCH_TOKEN", expected_args,
+             "search_by_symbol", f"Search {symbol} by symbol"
+         ))
+
+     # 2) search by symbol (Chinese)
+     cn_cases = [
+         ("帮我搜索 BONK 代币", "BONK", "solana"),
+         ("查一下 WIF 这个币", "WIF", "solana"),
+         ("找一下 JUP 代币信息", "JUP", "solana"),
+         ("搜索 RAY 代币", "RAY", "solana"),
+         ("查询 POPCAT 代币", "POPCAT", "solana"),
+     ]
+
+     for query, symbol, chain in cn_cases:
+         benchmarks.append(create_benchmark_item(
+             query, "SEARCH_TOKEN", {"symbol": symbol, "chain": chain},
+             "search_by_symbol_cn", f"Search {symbol} by symbol (Chinese)"
+         ))
+
+     # 3) search by address
+     for token, info in list(TOKENS.items())[:5]:
+         query = f"Search token at address {info['ca']}"
+         benchmarks.append(create_benchmark_item(
+             query, "SEARCH_TOKEN", {"address": info['ca'], "chain": info['chain']},
+             "search_by_address", f"Search {token} by address"
+         ))
+
+     # 4) search by keyword
+     keyword_cases = [
+         ("Search for dog themed tokens", "dog", "solana"),
+         ("Find meme coins", "meme", "solana"),
+         ("Look for cat tokens on base", "cat", "base"),
+     ]
+
+     for query, keyword, chain in keyword_cases:
+         benchmarks.append(create_benchmark_item(
+             query, "SEARCH_TOKEN", {"keyword": keyword, "chain": chain},
+             "search_by_keyword", f"Search by keyword: {keyword}"
+         ))
+
+     return benchmarks
+
+
+ def generate_execute_swap_benchmarks() -> List[Dict]:
+     """Generate EXECUTE_SWAP cases."""
+     benchmarks = []
+
+     # 1) buy token (fixed amount)
+     buy_cases = [
+         ("Buy 1 SOL worth of BONK", "SOL", "BONK", "1", None),
+         ("Purchase 5 SOL of WIF", "SOL", "WIF", "5", None),
+         ("Buy 10 USDC worth of JUP", "USDC", "JUP", "10", None),
+         ("I want to buy 2 SOL of RAY", "SOL", "RAY", "2", None),
+         ("Get me 0.5 SOL of POPCAT", "SOL", "POPCAT", "0.5", None),
+     ]
+
+     for query, input_token, output_token, amount, percentage in buy_cases:
+         input_ca = TOKENS[input_token]["ca"]
+         output_ca = TOKENS[output_token]["ca"]
+         benchmarks.append(create_benchmark_item(
+             query, "EXECUTE_SWAP",
+             {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
+             "buy_with_amount", f"Buy {output_token} with {amount} {input_token}"
+         ))
+
+     # 2) buy token (percentage)
+     buy_pct_cases = [
+         ("Buy BONK with 50% of my SOL", "SOL", "BONK", None, 0.5),
+         ("Use 30% of my USDC to buy WIF", "USDC", "WIF", None, 0.3),
+         ("Spend 100% of my SOL on JUP", "SOL", "JUP", None, 1.0),
+         ("Put 25% of my ETH into RAY", "ETH", "RAY", None, 0.25),
+         ("Use half of my BTC to get BONK", "BTC", "BONK", None, 0.5),
+     ]
+
+     for query, input_token, output_token, amount, percentage in buy_pct_cases:
+         input_ca = TOKENS[input_token]["ca"]
+         output_ca = TOKENS[output_token]["ca"]
+         benchmarks.append(create_benchmark_item(
+             query, "EXECUTE_SWAP",
+             {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
+             "buy_with_percentage", f"Buy {output_token} with {int(percentage*100)}% {input_token}"
+         ))
+
+     # 3) sell token (fixed amount)
+     sell_cases = [
+         ("Sell 1000 BONK", "BONK", "SOL", "1000", None),
+         ("Sell 500 WIF for SOL", "WIF", "SOL", "500", None),
+         ("Convert 100 JUP to SOL", "JUP", "SOL", "100", None),
+         ("Dump 2000 RAY", "RAY", "SOL", "2000", None),
+         ("Sell 50 USDC", "USDC", "SOL", "50", None),
+     ]
+
+     for query, input_token, output_token, amount, percentage in sell_cases:
+         input_ca = TOKENS[input_token]["ca"]
+         output_ca = TOKENS[output_token]["ca"]
+         benchmarks.append(create_benchmark_item(
+             query, "EXECUTE_SWAP",
+             {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
+             "sell_with_amount", f"Sell {amount} {input_token}"
+         ))
+
+     # 4) sell token (percentage)
+     sell_pct_cases = [
+         ("Sell 50% of my BONK", "BONK", "SOL", None, 0.5),
+         ("Dump all my WIF", "WIF", "SOL", None, 1.0),
+         ("Sell 30% of my JUP holdings", "JUP", "SOL", None, 0.3),
+         ("Get rid of 75% of my RAY", "RAY", "SOL", None, 0.75),
+         ("Sell a quarter of my USDC", "USDC", "SOL", None, 0.25),
+     ]
+
+     for query, input_token, output_token, amount, percentage in sell_pct_cases:
+         input_ca = TOKENS[input_token]["ca"]
+         output_ca = TOKENS[output_token]["ca"]
+         benchmarks.append(create_benchmark_item(
+             query, "EXECUTE_SWAP",
+             {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
+             "sell_with_percentage", f"Sell {int(percentage*100)}% {input_token}"
+         ))
+
+     # 5) Chinese buy/sell requests
+     cn_swap_cases = [
+         ("用 1 个 SOL 买 BONK", "SOL", "BONK", "1", None),
+         ("把 50% 的 USDC 换成 WIF", "USDC", "WIF", None, 0.5),
+         ("卖掉 1000 个 BONK", "BONK", "SOL", "1000", None),
+         ("把所有 JUP 都卖了", "JUP", "SOL", None, 1.0),
+         ("用 2 SOL 购买 RAY", "SOL", "RAY", "2", None),
+         ("出售 30% 的 WIF", "WIF", "SOL", None, 0.3),
+         ("买入 5 SOL 的 POPCAT", "SOL", "POPCAT", "5", None),
+         ("清仓 ETH", "ETH", "SOL", None, 1.0),
+     ]
+
+     for query, input_token, output_token, amount, percentage in cn_swap_cases:
+         input_ca = TOKENS[input_token]["ca"]
+         output_ca = TOKENS[output_token]["ca"]
+         benchmarks.append(create_benchmark_item(
+             query, "EXECUTE_SWAP",
+             {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
+             "swap_chinese", "Swap request in Chinese"
+         ))
+
+     # 6) swap between tokens
+     swap_cases = [
+         ("Swap 100 USDC for BONK", "USDC", "BONK", "100", None),
+         ("Exchange 50 JUP for WIF", "JUP", "WIF", "50", None),
+         ("Convert all my ETH to USDC", "ETH", "USDC", None, 1.0),
+     ]
+
+     for query, input_token, output_token, amount, percentage in swap_cases:
+         input_ca = TOKENS[input_token]["ca"]
+         output_ca = TOKENS[output_token]["ca"]
+         benchmarks.append(create_benchmark_item(
+             query, "EXECUTE_SWAP",
+             {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": amount, "inputTokenPercentage": percentage},
+             "token_to_token", f"Swap {input_token} to {output_token}"
+         ))
+
+     return benchmarks
+
+
+ def generate_incomplete_benchmarks() -> List[Dict]:
+     """Generate incomplete requests (should ask clarification)."""
+     benchmarks = []
+
+     incomplete_cases = [
+         ("I want to buy some tokens", "incomplete_no_token", "Missing token name"),
+         ("Sell my holdings", "incomplete_no_token", "Missing which token to sell"),
+         ("Search for a token", "incomplete_no_info", "Missing token info"),
+         ("Buy something", "incomplete_vague", "Too vague"),
+         ("我想买币", "incomplete_cn", "Missing token (Chinese)"),
+         ("帮我卖掉", "incomplete_cn", "Missing token and amount (Chinese)"),
+         ("Swap tokens", "incomplete_swap", "Missing swap details"),
+         ("I want to trade", "incomplete_trade", "Missing trade details"),
+     ]
+
+     for query, category, description in incomplete_cases:
+         benchmarks.append(create_benchmark_item(
+             query, None, None, category, description
+         ))
+
+     return benchmarks
+
+
+ def generate_irrelevant_benchmarks() -> List[Dict]:
+     """Generate irrelevant requests (should not call any function)."""
+     benchmarks = []
+
+     irrelevant_cases = [
+         ("What's the weather today?", "irrelevant_weather", "Weather query"),
+         ("Tell me a joke", "irrelevant_joke", "Joke request"),
+         ("What time is it?", "irrelevant_time", "Time query"),
+         ("Who is the president?", "irrelevant_general", "General knowledge"),
+         ("今天天气怎么样?", "irrelevant_cn", "Weather (Chinese)"),
+         ("给我讲个笑话", "irrelevant_cn", "Joke (Chinese)"),
+         ("Hello, how are you?", "irrelevant_greeting", "Greeting"),
+         ("What is Bitcoin?", "irrelevant_info", "Info request (no action)"),
+     ]
+
+     for query, category, description in irrelevant_cases:
+         benchmarks.append(create_benchmark_item(
+             query, None, None, category, description
+         ))
+
+     return benchmarks
+
+
+ def generate_benchmark_dataset(output_path: str = str(DEFAULT_BENCHMARK_PATH)):
+     """Generate the full benchmark dataset."""
+
+     print("=" * 60)
+     print("Generating FunctionGemma benchmark dataset")
+     print("=" * 60)
+
+     # Collect all cases
+     all_benchmarks = []
+
+     # SEARCH_TOKEN cases
+     search_benchmarks = generate_search_token_benchmarks()
+     print(f"SEARCH_TOKEN cases: {len(search_benchmarks)}")
+     all_benchmarks.extend(search_benchmarks)
+
+     # EXECUTE_SWAP cases
+     swap_benchmarks = generate_execute_swap_benchmarks()
+     print(f"EXECUTE_SWAP cases: {len(swap_benchmarks)}")
+     all_benchmarks.extend(swap_benchmarks)
+
+     # Incomplete requests
+     incomplete_benchmarks = generate_incomplete_benchmarks()
+     print(f"Incomplete request cases: {len(incomplete_benchmarks)}")
+     all_benchmarks.extend(incomplete_benchmarks)
+
+     # Irrelevant requests
+     irrelevant_benchmarks = generate_irrelevant_benchmarks()
+     print(f"Irrelevant request cases: {len(irrelevant_benchmarks)}")
+     all_benchmarks.extend(irrelevant_benchmarks)
+
+     # Pad to 100 if needed
+     while len(all_benchmarks) < 100:
+         # Add a few variants
+         extra_cases = [
+             ("Buy 3 SOL of TRUMP", "SOL", "TRUMP", "3", None, "EXECUTE_SWAP"),
+             ("Search for TRUMP token", "TRUMP", "solana", None, None, "SEARCH_TOKEN"),
+         ]
+         for case in extra_cases:
+             if len(all_benchmarks) >= 100:
+                 break
+             if case[5] == "EXECUTE_SWAP":
+                 input_ca = TOKENS[case[1]]["ca"]
+                 output_ca = TOKENS[case[2]]["ca"]
+                 all_benchmarks.append(create_benchmark_item(
+                     case[0], "EXECUTE_SWAP",
+                     {"inputTokenCA": input_ca, "outputTokenCA": output_ca, "inputTokenAmount": case[3], "inputTokenPercentage": case[4]},
+                     "extra", "Extra test case"
+                 ))
+             else:
+                 all_benchmarks.append(create_benchmark_item(
+                     case[0], "SEARCH_TOKEN",
+                     {"symbol": case[1], "chain": case[2]},
+                     "extra", "Extra test case"
+                 ))
+
+     # Limit to 100
+     all_benchmarks = all_benchmarks[:100]
+
+     # Assign ids
+     for i, item in enumerate(all_benchmarks):
+         item["id"] = i + 1
+
+     # Shuffle
+     random.seed(42)
+     random.shuffle(all_benchmarks)
+
+     # Re-assign ids
+     for i, item in enumerate(all_benchmarks):
+         item["id"] = i + 1
+
+     print(f"\nTotal: {len(all_benchmarks)} cases")
+
+     # Category stats
+     categories = {}
+     for item in all_benchmarks:
+         cat = item["category"]
+         categories[cat] = categories.get(cat, 0) + 1
+
+     print("\nCategory distribution:")
+     for cat, count in sorted(categories.items()):
+         print(f"  - {cat}: {count}")
+
+     # Function stats
+     func_counts = {"SEARCH_TOKEN": 0, "EXECUTE_SWAP": 0, "None": 0}
+     for item in all_benchmarks:
+         func = item["expected"]["function_name"]
+         if func:
+             func_counts[func] = func_counts.get(func, 0) + 1
+         else:
+             func_counts["None"] += 1
+
+     print("\nFunction distribution:")
+     for func, count in func_counts.items():
+         print(f"  - {func}: {count}")
+
+     # Save
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(all_benchmarks, f, ensure_ascii=False, indent=2)
+
+     print(f"\nBenchmark saved to: {output_path}")
+
+     # Show examples
+     print("\n" + "=" * 60)
+     print("Examples:")
+     print("=" * 60)
+
+     for i, item in enumerate(all_benchmarks[:3]):
+         print(f"\n--- Example {i+1} ---")
+         print(f"ID: {item['id']}")
+         print(f"Category: {item['category']}")
+         print(f"Input: {item['input']['messages'][1]['content']}")
+         print(f"Expected function: {item['expected']['function_name']}")
+         if item['expected']['arguments']:
+             print(f"Expected args: {json.dumps(item['expected']['arguments'], ensure_ascii=False)}")
+
+     return all_benchmarks
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Generate FunctionGemma benchmark dataset")
+     parser.add_argument("--output", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Output file path")
+     args = parser.parse_args()
+
+     output_path = Path(args.output)
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     generate_benchmark_dataset(str(output_path))
+
+
+ if __name__ == "__main__":
+     main()
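The padding step in `generate_benchmark_dataset` above tops the case list up to exactly 100 items by cycling through two extra variants. A minimal sketch of that top-up logic, with a hypothetical starting count of 65 generated cases (the real count depends on the generators above):

```python
# Sketch of the pad-to-100 loop: keep appending the extra variants
# until the benchmark list reaches exactly 100 items.
cases = ["generated_case"] * 65           # hypothetical pre-padding count
extras = ["extra_swap", "extra_search"]   # stand-ins for the two extra cases

while len(cases) < 100:
    for extra in extras:
        if len(cases) >= 100:
            break                          # stop mid-cycle once full
        cases.append(extra)

cases = cases[:100]                        # defensive cap, as in the script
print(len(cases))
```

Because the inner `break` fires as soon as the list is full, the loop lands on exactly 100 regardless of whether the shortfall is odd or even.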
src/prepare_dataset.py ADDED
@@ -0,0 +1,199 @@
+ #!/usr/bin/env python3
+ """
+ Data preprocessing script.
+
+ Convert the generated dataset into a format directly consumable by SFTTrainer.
+ FunctionGemma expects a specific chat template structure.
+
+ Usage:
+     python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json
+ """
+
+ import json
+ import argparse
+ from pathlib import Path
+ from typing import List, Dict, Any
+
+
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ DEFAULT_INPUT = PROJECT_ROOT / "data" / "training_data.json"
+ DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "prepared_dataset.json"
+
+
+ def convert_tool_calls_to_text(tool_calls: List[Dict]) -> str:
+     """Convert tool_calls into plain text (FunctionGemma format)."""
+     if not tool_calls:
+         return ""
+
+     result_parts = []
+     for tc in tool_calls:
+         func = tc.get("function", {})
+         name = func.get("name", "")
+         args = func.get("arguments", {})
+
+         # FunctionGemma format: functionName(arguments)
+         args_str = json.dumps(args, ensure_ascii=False)
+         result_parts.append(f"{name}({args_str})")
+
+     return "\n".join(result_parts)
+
+
+ def convert_messages_for_sft(messages: List[Dict], tools: List[Dict] = None) -> List[Dict]:
+     """
+     Convert message format for SFTTrainer.
+
+     Input:
+     [
+         {"role": "developer", "content": "..."},
+         {"role": "user", "content": "..."},
+         {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
+     ]
+
+     Output:
+     [
+         {"role": "system", "content": "..."},    # developer -> system
+         {"role": "user", "content": "..."},
+         {"role": "assistant", "content": "..."}  # tool_calls flattened to text
+     ]
+     """
+     converted = []
+
+     # Build tools description
+     tools_description = ""
+     if tools:
+         tools_desc_parts = []
+         for tool in tools:
+             if tool.get("type") == "function":
+                 func = tool.get("function", {})
+                 name = func.get("name", "")
+                 desc = func.get("description", "")
+                 params = func.get("parameters", {})
+                 tools_desc_parts.append(f"- {name}: {desc}")
+         if tools_desc_parts:
+             tools_description = "\n\nAvailable tools:\n" + "\n".join(tools_desc_parts)
+
+     for msg in messages:
+         role = msg.get("role", "")
+
+         if role == "developer":
+             # developer -> system
+             content = msg.get("content", "")
+             if tools_description:
+                 content = content + tools_description
+             converted.append({
+                 "role": "system",
+                 "content": content
+             })
+
+         elif role == "user":
+             converted.append({
+                 "role": "user",
+                 "content": msg.get("content", "")
+             })
+
+         elif role == "assistant":
+             if "tool_calls" in msg:
+                 # Convert tool_calls to text
+                 tool_calls_text = convert_tool_calls_to_text(msg["tool_calls"])
+                 converted.append({
+                     "role": "assistant",
+                     "content": tool_calls_text
+                 })
+             else:
+                 converted.append({
+                     "role": "assistant",
+                     "content": msg.get("content", "")
+                 })
+
+         elif role == "tool":
+             # Tool response
+             converted.append({
+                 "role": "tool",
+                 "content": msg.get("content", "")
+             })
+
+     return converted
+
+
+ def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages"):
+     """
+     Prepare dataset.
+
+     format_type:
+     - "messages": output {"messages": [...]}
+     - "text": output {"text": "..."} (flattened text)
+     """
+     print(f"Loading dataset: {input_path}")
+
+     with open(input_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     print(f"Raw samples: {len(data)}")
+
+     prepared_data = []
+
+     for i, item in enumerate(data):
+         messages = item.get("messages", [])
+         tools = item.get("tools", [])
+
+         # Convert messages
+         converted_messages = convert_messages_for_sft(messages, tools)
+
+         if format_type == "messages":
+             prepared_data.append({
+                 "messages": converted_messages
+             })
+         elif format_type == "text":
+             # Convert to plain text
+             text_parts = []
+             for msg in converted_messages:
+                 role = msg["role"]
+                 content = msg["content"]
+                 if role == "system":
+                     text_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
+                 elif role == "user":
+                     text_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
+                 elif role == "assistant":
+                     text_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
+
+             prepared_data.append({
+                 "text": "\n".join(text_parts)
+             })
+
+     print(f"Processed samples: {len(prepared_data)}")
+
+     # Save
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(prepared_data, f, ensure_ascii=False, indent=2)
+
+     print(f"Saved to: {output_path}")
+
+     # Show example
+     print("\n" + "=" * 60)
+     print("Example:")
+     print("=" * 60)
+
+     if format_type == "messages":
+         example = prepared_data[0]
+         for msg in example["messages"]:
+             print(f"\n[{msg['role']}]")
+             print(msg["content"][:200] + "..." if len(msg["content"]) > 200 else msg["content"])
+     else:
+         print(prepared_data[0]["text"][:500] + "...")
+
+     return prepared_data
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Dataset preparation")
+     parser.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path")
+     parser.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path")
+     parser.add_argument("--format", type=str, choices=["messages", "text"], default="messages", help="Output format")
+
+     args = parser.parse_args()
+
+     prepare_dataset(args.input, args.output, args.format)
+
+
+ if __name__ == "__main__":
+     main()
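The flattening step in `src/prepare_dataset.py` turns a structured `tool_calls` entry into the `functionName({...})` text form. A small self-contained sketch of that conversion, using a `SEARCH_TOKEN` call as the example input (the function body mirrors `convert_tool_calls_to_text` above):

```python
import json

def convert_tool_calls_to_text(tool_calls):
    # Flatten each tool call to FunctionGemma-style `name(json_args)` text.
    result_parts = []
    for tc in tool_calls:
        func = tc.get("function", {})
        args_str = json.dumps(func.get("arguments", {}), ensure_ascii=False)
        result_parts.append(f"{func.get('name', '')}({args_str})")
    return "\n".join(result_parts)

call = [{"type": "function",
         "function": {"name": "SEARCH_TOKEN",
                      "arguments": {"symbol": "BONK", "chain": "solana"}}}]
print(convert_tool_calls_to_text(call))
# SEARCH_TOKEN({"symbol": "BONK", "chain": "solana"})
```

The assistant turn in the prepared dataset then carries this string as plain `content`, so the trainer sees the call as ordinary text.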
src/train.py ADDED
@@ -0,0 +1,559 @@
+ #!/usr/bin/env python3
+ """
+ FunctionGemma SFT fine-tuning script.
+
+ Runs TRL SFTTrainer for FunctionGemma with two modes:
+ 1) LoRA (recommended): faster, lower memory, less overfit
+ 2) Full-parameter: higher cost, maximal capacity
+
+ Usage:
+     # LoRA (default)
+     python -m src.train \
+         --model_path /path/to/model \
+         --dataset_path ./data/training_data.json \
+         --bf16
+
+     # Full-parameter
+     python -m src.train \
+         --model_path /path/to/model \
+         --dataset_path ./data/training_data.json \
+         --no-use-lora \
+         --bf16
+ """
+
+ import os
+ import json
+ import argparse
+ import logging
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ from datasets import Dataset, load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TrainingArguments,
+     BitsAndBytesConfig,
+ )
+ from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
+ from trl import SFTTrainer, SFTConfig
+
+ # Paths and logging
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json"
+ DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs"
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def parse_args():
+     """Parse CLI arguments."""
+     parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")
+
+     # Model
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="google/functiongemma-270m-it",
+         help="Model path or HF model id"
+     )
+     parser.add_argument(
+         "--tokenizer_path",
+         type=str,
+         default=None,
+         help="Tokenizer path (defaults to model_path)"
+     )
+
+     # Dataset
+     parser.add_argument(
+         "--dataset_path",
+         type=str,
+         default=str(DEFAULT_DATA_PATH),
+         help="Training dataset path"
+     )
+     parser.add_argument(
+         "--val_split",
+         type=float,
+         default=0.1,
+         help="Validation split ratio"
+     )
+
+     # Output
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default=str(DEFAULT_OUTPUT_DIR),
+         help="Root output directory"
+     )
+     parser.add_argument(
+         "--run_name",
+         type=str,
+         default=None,
+         help="Run name for logging and saving"
+     )
+
+     # Fine-tuning mode
+     parser.add_argument(
+         "--use_lora",
+         action="store_true",
+         default=True,
+         help="Enable LoRA (recommended). Add --no-use-lora for full-parameter finetune"
+     )
+     parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune")
+
+     # LoRA (only when use_lora=True)
+     parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
+     parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
+     parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
+     parser.add_argument(
+         "--target_modules",
+         type=str,
+         nargs="+",
+         default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+         help="Target modules for LoRA"
+     )
+
+     # Training (aligned with FunctionGemma guidance)
+     parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
+     parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
+     parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
+     parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size")
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps")
+     parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
+     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
+     parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)")
+     parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
+     parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)")
+
+     # Precision & optimization
+     parser.add_argument("--bf16", action="store_true", help="Use BF16")
+     parser.add_argument("--fp16", action="store_true", help="Use FP16")
+     parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)")
+     parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant")
+     parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
+     parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")
+
+     # Logging & saving
+     parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
+     parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps")
+     parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps")
+     parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")
+
+     # Misc
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+     parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
+     parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub")
+     parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")
+
+     return parser.parse_args()
+
+
+ def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1):
+     """Load and normalize dataset structure for SFT."""
+     logger.info(f"Loading dataset: {dataset_path}")
+
+     # Load JSON dataset
+     with open(dataset_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     logger.info(f"Dataset size: {len(data)} samples")
+
+     # Normalize nested structures:
+     # if an item has input.messages/tools, lift them to top-level
+     processed_data = []
+     for idx, item in enumerate(data):
+         if 'input' in item and 'messages' in item['input']:
+             # Deep copy messages to avoid mutating original
+             messages = json.loads(json.dumps(item['input']['messages']))
+
+             # Fix tool_calls formatting if present
+             for msg in messages:
+                 if 'tool_calls' in msg and msg['tool_calls']:
+                     for tc in msg['tool_calls']:
+                         if 'function' in tc and 'arguments' in tc['function']:
+                             args = tc['function']['arguments']
+                             # ensure arguments is a string
+                             if not isinstance(args, str):
+                                 tc['function']['arguments'] = json.dumps(args)
+
+             # Convert expected field into assistant response if present
+             if 'expected' in item and item['expected']:
+                 expected = item['expected']
+                 # If last message is not assistant, append one
+                 if messages[-1]['role'] != 'assistant':
+                     # Decide between function call or refusal
+                     function_name = expected.get('function_name')
+                     arguments = expected.get('arguments')
+                     response = expected.get('response', '')
+
+                     if function_name is not None and arguments is not None:
+                         # Case 1: function call -> add tool_calls
+                         arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
+
+                         assistant_msg = {
+                             "role": "assistant",
+                             "content": None,
+                             "tool_calls": [{
+                                 "id": f"call_{hash(function_name + arguments_str) % 1000000}",  # generate unique id
+                                 "type": "function",
+                                 "function": {
+                                     "name": function_name,
+                                     "arguments": arguments_str
+                                 }
+                             }]
+                         }
+                         messages.append(assistant_msg)
+                         logger.debug(f"Added assistant tool_calls: {function_name}")
+                     elif function_name is None and arguments is None and response:
+                         # Case 2: refusal -> plain text response
+                         assistant_msg = {
+                             "role": "assistant",
+                             "content": response
+                         }
+                         messages.append(assistant_msg)
+                         logger.debug(f"Added assistant refusal response: {response[:50]}")
+                     else:
+                         logger.warning(f"Unknown expected format: {expected}")
+
+             processed_item = {
+                 'messages': messages
+             }
+
+             # include tools if present
+             if 'tools' in item['input']:
+                 processed_item['tools'] = item['input']['tools']
+
+             # preserve id
+             if 'id' in item:
+                 processed_item['id'] = item['id']
+
+             # Final check: tool_calls arguments must be strings
+             for msg in processed_item['messages']:
+                 if 'tool_calls' in msg and msg['tool_calls']:
+                     for tc in msg['tool_calls']:
+                         if 'function' in tc and 'arguments' in tc['function']:
+                             if not isinstance(tc['function']['arguments'], str):
+                                 logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}")
+                                 tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
+
+             processed_data.append(processed_item)
+
+         elif 'messages' in item:
+             # Already proper format, just normalize tool_calls
+             messages = json.loads(json.dumps(item['messages']))
+             for msg in messages:
+                 if 'tool_calls' in msg and msg['tool_calls']:
+                     for tc in msg['tool_calls']:
+                         if 'function' in tc and 'arguments' in tc['function']:
+                             if not isinstance(tc['function']['arguments'], str):
+                                 tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
+             item_copy = dict(item)
+             item_copy['messages'] = messages
+             processed_data.append(item_copy)
+         else:
+             logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")
+
+     logger.info(f"Processed dataset size: {len(processed_data)}")
+
+     # Validate format
+     tool_calls_count = 0
+     for item in processed_data:
+         for msg in item['messages']:
+             if 'tool_calls' in msg and msg['tool_calls']:
+                 tool_calls_count += 1
+                 for tc in msg['tool_calls']:
+                     if 'function' in tc and 'arguments' in tc['function']:
+                         if not isinstance(tc['function']['arguments'], str):
+                             logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}")
+     logger.info(f"Messages containing tool_calls: {tool_calls_count}")
+
+     # Convert to Hugging Face Dataset
+     dataset = Dataset.from_list(processed_data)
+
+     # Split train/val
+     if val_split > 0:
+         dataset = dataset.train_test_split(test_size=val_split, seed=42)
+         train_dataset = dataset['train']
+         eval_dataset = dataset['test']
+         logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
+     else:
+         train_dataset = dataset
+         eval_dataset = None
+         logger.info(f"Train: {len(train_dataset)}, no eval split")
+
+     return train_dataset, eval_dataset
+
+
+ def get_quantization_config(use_4bit: bool, use_8bit: bool):
294
+ """Build quantization config if requested."""
295
+ if use_4bit:
296
+ logger.info("Using 4-bit quantization (QLoRA)")
297
+ return BitsAndBytesConfig(
298
+ load_in_4bit=True,
299
+ bnb_4bit_quant_type="nf4",
300
+ bnb_4bit_compute_dtype=torch.bfloat16,
301
+ bnb_4bit_use_double_quant=True,
302
+ )
303
+ elif use_8bit:
304
+ logger.info("Using 8-bit quantization")
305
+ return BitsAndBytesConfig(
306
+ load_in_8bit=True,
307
+ )
308
+ return None
309
+
310
+
311
+ def load_model_and_tokenizer(args):
+     """Load model and tokenizer."""
+     logger.info(f"Loading model: {args.model_path}")
+
+     tokenizer_path = args.tokenizer_path or args.model_path
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         tokenizer_path,
+         trust_remote_code=True,
+         padding_side="right",
+     )
+
+     # Ensure pad token exists
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     # Quantization config
+     quantization_config = get_quantization_config(args.use_4bit, args.use_8bit)
+
+     # Model kwargs
+     model_kwargs = {
+         "trust_remote_code": True,
+         "device_map": "auto",
+     }
+
+     if quantization_config:
+         model_kwargs["quantization_config"] = quantization_config
+
+     # Precision
+     if args.bf16 and not (args.use_4bit or args.use_8bit):
+         model_kwargs["torch_dtype"] = torch.bfloat16
+     elif args.fp16 and not (args.use_4bit or args.use_8bit):
+         model_kwargs["torch_dtype"] = torch.float16
+
+     # Flash Attention
+     if args.use_flash_attention:
+         model_kwargs["attn_implementation"] = "flash_attention_2"
+         logger.info("Using Flash Attention 2")
+
+     # Load model
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_path,
+         **model_kwargs
+     )
+
+     # Prepare for k-bit training when quantized
+     if args.use_4bit or args.use_8bit:
+         model = prepare_model_for_kbit_training(model)
+
+     # Gradient checkpointing
+     if args.gradient_checkpointing:
+         model.gradient_checkpointing_enable()
+         logger.info("Enabled gradient checkpointing")
+
+     logger.info(f"Model parameters: {model.num_parameters():,}")
+
+     return model, tokenizer
+
+
+ def get_lora_config(args):
+     """Build LoRA config."""
+     logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
+     logger.info(f"Target modules: {args.target_modules}")
+
+     return LoraConfig(
+         r=args.lora_r,
+         lora_alpha=args.lora_alpha,
+         lora_dropout=args.lora_dropout,
+         target_modules=args.target_modules,
+         bias="none",
+         task_type=TaskType.CAUSAL_LM,
+     )
+
+
+ def formatting_func(example):
+     """
+     Format function: pass data through for SFTTrainer.
+
+     Dataset format:
+     {
+         "messages": [
+             {"role": "developer", "content": "..."},
+             {"role": "user", "content": "..."},
+             {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
+         ],
+         "tools": [...]
+     }
+     """
+     # Return as-is; SFTTrainer applies chat template
+     return example
+
+
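To make the docstring's record layout concrete, here is one hypothetical record in that shape. The `get_weather` tool and the inner structure of `tool_calls` are illustrative assumptions, not taken from the real dataset:

```python
import json

# One hypothetical training record in the shape the docstring describes
example = {
    "messages": [
        {"role": "developer", "content": "You can call tools to answer."},
        {"role": "user", "content": "What's the weather in Paris?"},
        {"role": "assistant", "tool_calls": [
            # tool_calls payload is illustrative; the real schema may differ
            {"name": "get_weather", "arguments": {"city": "Paris"}}
        ]},
    ],
    "tools": [
        {"name": "get_weather", "parameters": {"city": {"type": "string"}}}
    ],
}

# Sanity checks mirroring the documented schema
assert {m["role"] for m in example["messages"]} <= {"developer", "user", "assistant"}
last = example["messages"][-1]
assert ("content" in last) or ("tool_calls" in last)
print(json.dumps(example)[:60])
```

Records like this are passed through unchanged; the chat template is applied later by SFTTrainer.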
+ def main():
+     args = parse_args()
+
+     # Set run name
+     if args.run_name is None:
+         args.run_name = f"functiongemma-lora-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+     # Create output directory
+     output_dir = os.path.join(args.output_dir, args.run_name)
+     os.makedirs(output_dir, exist_ok=True)
+
+     logger.info("=" * 60)
+     logger.info("FunctionGemma SFT LoRA training")
+     logger.info("=" * 60)
+     logger.info(f"Output dir: {output_dir}")
+
+     # Save config
+     config_path = os.path.join(output_dir, "training_config.json")
+     with open(config_path, 'w') as f:
+         json.dump(vars(args), f, indent=2)
+     logger.info(f"Config saved to: {config_path}")
+
+     # Load dataset
+     train_dataset, eval_dataset = load_and_prepare_dataset(
+         args.dataset_path,
+         args.val_split
+     )
+
+     # Load model + tokenizer
+     model, tokenizer = load_model_and_tokenizer(args)
+
+     # Build LoRA config if enabled
+     if args.use_lora:
+         logger.info("=" * 60)
+         logger.info("LoRA fine-tuning mode")
+         logger.info("=" * 60)
+         lora_config = get_lora_config(args)
+     else:
+         logger.info("=" * 60)
+         logger.info("Full-parameter fine-tuning mode")
+         logger.info("Warning: full fine-tuning needs more memory and time!")
+         logger.info("=" * 60)
+         lora_config = None
+
+     # SFTTrainer config
+     training_args = SFTConfig(
+         output_dir=output_dir,
+         run_name=args.run_name,
+
+         # Sequence length / packing
+         max_length=args.max_seq_length,
+         packing=False,
+
+         # Training
+         num_train_epochs=args.num_train_epochs,
+         max_steps=args.max_steps,
+         per_device_train_batch_size=args.per_device_train_batch_size,
+         per_device_eval_batch_size=args.per_device_eval_batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+
+         # Optimizer
+         learning_rate=args.learning_rate,
+         weight_decay=args.weight_decay,
+         warmup_ratio=args.warmup_ratio,
+         lr_scheduler_type=args.lr_scheduler_type,
+         optim="adamw_torch_fused",
+
+         # Precision
+         bf16=args.bf16,
+         fp16=args.fp16,
+
+         # Logging / saving
+         logging_steps=args.logging_steps,
+         save_steps=args.save_steps,
+         eval_steps=args.eval_steps if eval_dataset else None,
+         eval_strategy="steps" if eval_dataset else "no",
+         save_total_limit=args.save_total_limit,
+         load_best_model_at_end=True if eval_dataset else False,
+
+         # Misc
+         seed=args.seed,
+         report_to=["tensorboard"],
+
+         # Hub
+         push_to_hub=args.push_to_hub,
+         hub_model_id=args.hub_model_id,
+
+         # Gradient checkpointing
+         gradient_checkpointing=args.gradient_checkpointing,
+         gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
+     )
+
+     # Create SFTTrainer
+     # Dataset should include 'messages' and 'tools'; SFTTrainer applies chat template automatically
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         processing_class=tokenizer,  # newer TRL uses processing_class instead of tokenizer
+         peft_config=lora_config,
+     )
+
+     # Parameter stats
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0
+
+     logger.info("=" * 60)
+     logger.info("Model parameter stats:")
+     logger.info(f"  Total params: {total_params:,}")
+     logger.info(f"  Trainable params: {trainable_params:,}")
+     logger.info(f"  Trainable ratio: {trainable_percentage:.2f}%")
+     logger.info(f"  Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
+     logger.info("=" * 60)
+
+     # Train
+     logger.info("Start training...")
+
+     if args.resume_from_checkpoint:
+         trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
+     else:
+         trainer.train()
+
+     # Save final model
+     logger.info("Saving final model...")
+     final_model_path = os.path.join(output_dir, "final_model")
+     trainer.save_model(final_model_path)
+     tokenizer.save_pretrained(final_model_path)
+
+     logger.info("=" * 60)
+     logger.info("Training done.")
+     logger.info(f"Model saved at: {final_model_path}")
+
+     if args.use_lora:
+         # LoRA: also save adapter
+         lora_path = os.path.join(output_dir, "lora_adapter")
+         model.save_pretrained(lora_path)
+         tokenizer.save_pretrained(lora_path)
+         logger.info(f"LoRA adapter saved to: {lora_path}")
+         logger.info("")
+         logger.info("Usage:")
+         logger.info(f"  1. LoRA adapter: {lora_path}")
+         logger.info("  2. Merge adapters with your base model before inference")
+     else:
+         # Full fine-tune: final_model is ready to use
+         logger.info("")
+         logger.info("Usage:")
+         logger.info(f"  Use model directly from: {final_model_path}")
+
+     logger.info("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()