fariasultana commited on
Commit
823ea46
·
verified ·
1 Parent(s): 826f659

feat: Add capabilities/agentic.py

Browse files
Files changed (1) hide show
  1. capabilities/agentic.py +471 -0
capabilities/agentic.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agentic Capabilities Module for MiniMind Max2
3
+ Function calling, tool use, and agent behaviors.
4
+ """
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import List, Optional, Dict, Any, Callable, Union
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import json
12
+ import re
13
+ from enum import Enum
14
+
15
+
16
+ class ToolType(Enum):
17
+ """Types of tools/functions."""
18
+ FUNCTION = "function"
19
+ API = "api"
20
+ CODE_EXEC = "code_execution"
21
+ RETRIEVAL = "retrieval"
22
+ BROWSER = "browser"
23
+
24
+
25
+ @dataclass
26
+ class FunctionCallingConfig:
27
+ """Configuration for function calling."""
28
+ # Special tokens
29
+ tool_call_start: str = "<tool_call>"
30
+ tool_call_end: str = "</tool_call>"
31
+ tool_result_start: str = "<tool_result>"
32
+ tool_result_end: str = "</tool_result>"
33
+
34
+ # Behavior
35
+ max_tool_calls: int = 5
36
+ parallel_tool_calls: bool = True
37
+ strict_json: bool = True
38
+
39
+ # Training
40
+ function_calling_weight: float = 1.0
41
+ schema_embedding_dim: int = 256
42
+
43
+
44
+ @dataclass
45
+ class ToolDefinition:
46
+ """Definition of a callable tool."""
47
+ name: str
48
+ description: str
49
+ parameters: Dict[str, Any]
50
+ required: List[str] = field(default_factory=list)
51
+ tool_type: ToolType = ToolType.FUNCTION
52
+
53
+ def to_schema(self) -> Dict[str, Any]:
54
+ """Convert to JSON schema format."""
55
+ return {
56
+ "type": "function",
57
+ "function": {
58
+ "name": self.name,
59
+ "description": self.description,
60
+ "parameters": {
61
+ "type": "object",
62
+ "properties": self.parameters,
63
+ "required": self.required,
64
+ },
65
+ },
66
+ }
67
+
68
+ def to_prompt(self) -> str:
69
+ """Convert to prompt format for training."""
70
+ params_str = ", ".join([
71
+ f"{k}: {v.get('type', 'any')}"
72
+ for k, v in self.parameters.items()
73
+ ])
74
+ return f"{self.name}({params_str}) - {self.description}"
75
+
76
+
77
+ class ToolRegistry:
78
+ """Registry for managing available tools."""
79
+
80
+ def __init__(self):
81
+ self.tools: Dict[str, ToolDefinition] = {}
82
+ self.handlers: Dict[str, Callable] = {}
83
+
84
+ def register(
85
+ self,
86
+ name: str,
87
+ description: str,
88
+ parameters: Dict[str, Any],
89
+ required: Optional[List[str]] = None,
90
+ handler: Optional[Callable] = None,
91
+ tool_type: ToolType = ToolType.FUNCTION,
92
+ ) -> None:
93
+ """Register a new tool."""
94
+ self.tools[name] = ToolDefinition(
95
+ name=name,
96
+ description=description,
97
+ parameters=parameters,
98
+ required=required or [],
99
+ tool_type=tool_type,
100
+ )
101
+ if handler:
102
+ self.handlers[name] = handler
103
+
104
+ def get_tool(self, name: str) -> Optional[ToolDefinition]:
105
+ """Get tool definition by name."""
106
+ return self.tools.get(name)
107
+
108
+ def execute(self, name: str, **kwargs) -> Any:
109
+ """Execute a registered tool."""
110
+ if name not in self.handlers:
111
+ raise ValueError(f"No handler registered for tool: {name}")
112
+ return self.handlers[name](**kwargs)
113
+
114
+ def get_all_schemas(self) -> List[Dict[str, Any]]:
115
+ """Get all tool schemas."""
116
+ return [tool.to_schema() for tool in self.tools.values()]
117
+
118
+ def get_tools_prompt(self) -> str:
119
+ """Generate prompt describing all tools."""
120
+ tools_desc = "\n".join([
121
+ f"- {tool.to_prompt()}"
122
+ for tool in self.tools.values()
123
+ ])
124
+ return f"Available tools:\n{tools_desc}"
125
+
126
+
127
+ class ToolCallParser:
128
+ """Parse and validate tool calls from model output."""
129
+
130
+ def __init__(self, config: FunctionCallingConfig):
131
+ self.config = config
132
+
133
+ def extract_tool_calls(self, text: str) -> List[Dict[str, Any]]:
134
+ """Extract tool calls from model output."""
135
+ pattern = rf"{re.escape(self.config.tool_call_start)}(.*?){re.escape(self.config.tool_call_end)}"
136
+ matches = re.findall(pattern, text, re.DOTALL)
137
+
138
+ calls = []
139
+ for match in matches:
140
+ try:
141
+ call = json.loads(match.strip())
142
+ calls.append(call)
143
+ except json.JSONDecodeError:
144
+ # Try to parse as function call format
145
+ parsed = self._parse_function_format(match.strip())
146
+ if parsed:
147
+ calls.append(parsed)
148
+
149
+ return calls
150
+
151
+ def _parse_function_format(self, text: str) -> Optional[Dict[str, Any]]:
152
+ """Parse function(arg1=val1, arg2=val2) format."""
153
+ match = re.match(r"(\w+)\((.*)\)", text, re.DOTALL)
154
+ if not match:
155
+ return None
156
+
157
+ name = match.group(1)
158
+ args_str = match.group(2)
159
+
160
+ # Parse arguments
161
+ args = {}
162
+ for arg_match in re.finditer(r"(\w+)\s*=\s*([^,]+)", args_str):
163
+ key = arg_match.group(1)
164
+ value = arg_match.group(2).strip()
165
+
166
+ # Try to parse as JSON
167
+ try:
168
+ args[key] = json.loads(value)
169
+ except:
170
+ args[key] = value.strip('"\'')
171
+
172
+ return {"name": name, "arguments": args}
173
+
174
+ def format_tool_call(self, name: str, arguments: Dict[str, Any]) -> str:
175
+ """Format a tool call for output."""
176
+ call = {"name": name, "arguments": arguments}
177
+ return f"{self.config.tool_call_start}{json.dumps(call)}{self.config.tool_call_end}"
178
+
179
+ def format_tool_result(self, result: Any) -> str:
180
+ """Format a tool result for input."""
181
+ if isinstance(result, (dict, list)):
182
+ result_str = json.dumps(result)
183
+ else:
184
+ result_str = str(result)
185
+ return f"{self.config.tool_result_start}{result_str}{self.config.tool_result_end}"
186
+
187
+
188
+ class SchemaEncoder(nn.Module):
189
+ """Encode tool schemas for the model."""
190
+
191
+ def __init__(self, config: FunctionCallingConfig, hidden_size: int):
192
+ super().__init__()
193
+ self.config = config
194
+
195
+ # Simple schema encoder
196
+ self.encoder = nn.Sequential(
197
+ nn.Linear(config.schema_embedding_dim, hidden_size),
198
+ nn.GELU(),
199
+ nn.Linear(hidden_size, hidden_size),
200
+ )
201
+
202
+ # Schema embedding lookup (trainable)
203
+ self.schema_embeddings = nn.Embedding(1000, config.schema_embedding_dim)
204
+
205
+ def forward(self, schema_ids: torch.Tensor) -> torch.Tensor:
206
+ """Encode schema IDs to hidden representations."""
207
+ embeddings = self.schema_embeddings(schema_ids)
208
+ return self.encoder(embeddings)
209
+
210
+
211
+ class AgenticModule(nn.Module):
212
+ """
213
+ Agentic capabilities module for MiniMind Max2.
214
+ Handles function calling, tool use, and agent behaviors.
215
+ """
216
+
217
+ def __init__(
218
+ self,
219
+ config: FunctionCallingConfig,
220
+ hidden_size: int,
221
+ vocab_size: int,
222
+ ):
223
+ super().__init__()
224
+ self.config = config
225
+ self.hidden_size = hidden_size
226
+
227
+ # Tool call prediction head
228
+ self.tool_call_head = nn.Sequential(
229
+ nn.Linear(hidden_size, hidden_size // 2),
230
+ nn.GELU(),
231
+ nn.Linear(hidden_size // 2, 2), # [no_tool, use_tool]
232
+ )
233
+
234
+ # Tool selection head
235
+ self.tool_selector = nn.Sequential(
236
+ nn.Linear(hidden_size, hidden_size // 2),
237
+ nn.GELU(),
238
+ nn.Linear(hidden_size // 2, 100), # Max 100 tools
239
+ )
240
+
241
+ # Argument generation enhancement
242
+ self.arg_enhancer = nn.Linear(hidden_size, hidden_size)
243
+
244
+ # Schema encoder
245
+ self.schema_encoder = SchemaEncoder(config, hidden_size)
246
+
247
+ # Parser
248
+ self.parser = ToolCallParser(config)
249
+
250
+ # Registry
251
+ self.registry = ToolRegistry()
252
+
253
+ def should_call_tool(self, hidden_states: torch.Tensor) -> torch.Tensor:
254
+ """Predict whether to call a tool at each position."""
255
+ return F.softmax(self.tool_call_head(hidden_states), dim=-1)
256
+
257
+ def select_tool(
258
+ self,
259
+ hidden_states: torch.Tensor,
260
+ available_tools: Optional[List[str]] = None,
261
+ ) -> torch.Tensor:
262
+ """Select which tool to call."""
263
+ logits = self.tool_selector(hidden_states)
264
+
265
+ if available_tools is not None:
266
+ # Mask unavailable tools
267
+ num_tools = len(available_tools)
268
+ mask = torch.ones_like(logits) * float("-inf")
269
+ mask[..., :num_tools] = 0
270
+ logits = logits + mask
271
+
272
+ return F.softmax(logits, dim=-1)
273
+
274
+ def forward(
275
+ self,
276
+ hidden_states: torch.Tensor,
277
+ tool_labels: Optional[torch.Tensor] = None,
278
+ tool_ids: Optional[torch.Tensor] = None,
279
+ ) -> Dict[str, torch.Tensor]:
280
+ """
281
+ Process hidden states for agentic capabilities.
282
+
283
+ Returns:
284
+ Dictionary with tool predictions and losses
285
+ """
286
+ batch_size, seq_len, _ = hidden_states.shape
287
+
288
+ # Tool call predictions
289
+ tool_call_probs = self.should_call_tool(hidden_states)
290
+ tool_select_probs = self.select_tool(hidden_states)
291
+
292
+ # Enhanced hidden states for argument generation
293
+ enhanced = self.arg_enhancer(hidden_states)
294
+
295
+ outputs = {
296
+ "tool_call_probs": tool_call_probs,
297
+ "tool_select_probs": tool_select_probs,
298
+ "enhanced_hidden_states": enhanced,
299
+ }
300
+
301
+ # Compute losses if labels provided
302
+ if tool_labels is not None:
303
+ tool_call_loss = F.cross_entropy(
304
+ tool_call_probs.view(-1, 2),
305
+ tool_labels.view(-1),
306
+ ignore_index=-100,
307
+ )
308
+ outputs["tool_call_loss"] = tool_call_loss
309
+
310
+ if tool_ids is not None:
311
+ tool_select_loss = F.cross_entropy(
312
+ tool_select_probs.view(-1, tool_select_probs.shape[-1]),
313
+ tool_ids.view(-1),
314
+ ignore_index=-100,
315
+ )
316
+ outputs["tool_select_loss"] = tool_select_loss
317
+
318
+ return outputs
319
+
320
+ def generate_tool_call(
321
+ self,
322
+ model: nn.Module,
323
+ input_ids: torch.Tensor,
324
+ tools: List[ToolDefinition],
325
+ max_new_tokens: int = 100,
326
+ ) -> str:
327
+ """Generate a tool call from the model."""
328
+ # Add tools to prompt context
329
+ tools_prompt = "\n".join([t.to_prompt() for t in tools])
330
+
331
+ # Generate with tool awareness
332
+ # In practice, would modify generation to include tool tokens
333
+ generated = model.generate(
334
+ input_ids,
335
+ max_new_tokens=max_new_tokens,
336
+ )
337
+
338
+ # Extract any tool calls
339
+ output_text = "placeholder_output" # Would decode generated tokens
340
+ tool_calls = self.parser.extract_tool_calls(output_text)
341
+
342
+ return tool_calls
343
+
344
+
345
+ class AgenticTrainer:
346
+ """Trainer for agentic capabilities."""
347
+
348
+ def __init__(
349
+ self,
350
+ model: nn.Module,
351
+ agentic_module: AgenticModule,
352
+ config: FunctionCallingConfig,
353
+ learning_rate: float = 1e-5,
354
+ device: str = "cuda",
355
+ ):
356
+ self.model = model
357
+ self.agentic = agentic_module
358
+ self.config = config
359
+ self.device = device
360
+
361
+ # Only train agentic module
362
+ self.optimizer = torch.optim.AdamW(
363
+ agentic_module.parameters(),
364
+ lr=learning_rate,
365
+ )
366
+
367
+ def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
368
+ """Single training step."""
369
+ self.agentic.train()
370
+
371
+ input_ids = batch["input_ids"].to(self.device)
372
+ attention_mask = batch["attention_mask"].to(self.device)
373
+ tool_labels = batch.get("tool_labels")
374
+ tool_ids = batch.get("tool_ids")
375
+
376
+ if tool_labels is not None:
377
+ tool_labels = tool_labels.to(self.device)
378
+ if tool_ids is not None:
379
+ tool_ids = tool_ids.to(self.device)
380
+
381
+ # Get hidden states from frozen model
382
+ with torch.no_grad():
383
+ if hasattr(self.model, 'model'):
384
+ hidden_states, _, _ = self.model.model(input_ids, attention_mask)
385
+ else:
386
+ hidden_states = self.model.embed_tokens(input_ids)
387
+
388
+ # Agentic forward
389
+ outputs = self.agentic(hidden_states, tool_labels, tool_ids)
390
+
391
+ # Total loss
392
+ loss = torch.tensor(0.0, device=self.device)
393
+ if "tool_call_loss" in outputs:
394
+ loss = loss + outputs["tool_call_loss"]
395
+ if "tool_select_loss" in outputs:
396
+ loss = loss + outputs["tool_select_loss"]
397
+
398
+ # Backward
399
+ self.optimizer.zero_grad()
400
+ loss.backward()
401
+ self.optimizer.step()
402
+
403
+ return {
404
+ "loss": loss.item(),
405
+ "tool_call_loss": outputs.get("tool_call_loss", torch.tensor(0.0)).item(),
406
+ "tool_select_loss": outputs.get("tool_select_loss", torch.tensor(0.0)).item(),
407
+ }
408
+
409
+
410
+ # Pre-defined common tools
411
+ DEFAULT_TOOLS = [
412
+ ToolDefinition(
413
+ name="search",
414
+ description="Search the web for information",
415
+ parameters={
416
+ "query": {"type": "string", "description": "Search query"},
417
+ },
418
+ required=["query"],
419
+ ),
420
+ ToolDefinition(
421
+ name="calculate",
422
+ description="Perform mathematical calculations",
423
+ parameters={
424
+ "expression": {"type": "string", "description": "Math expression to evaluate"},
425
+ },
426
+ required=["expression"],
427
+ ),
428
+ ToolDefinition(
429
+ name="get_weather",
430
+ description="Get current weather for a location",
431
+ parameters={
432
+ "location": {"type": "string", "description": "City name or coordinates"},
433
+ },
434
+ required=["location"],
435
+ ),
436
+ ToolDefinition(
437
+ name="run_code",
438
+ description="Execute Python code",
439
+ parameters={
440
+ "code": {"type": "string", "description": "Python code to execute"},
441
+ "language": {"type": "string", "description": "Programming language", "default": "python"},
442
+ },
443
+ required=["code"],
444
+ tool_type=ToolType.CODE_EXEC,
445
+ ),
446
+ ToolDefinition(
447
+ name="read_file",
448
+ description="Read contents of a file",
449
+ parameters={
450
+ "path": {"type": "string", "description": "File path"},
451
+ },
452
+ required=["path"],
453
+ ),
454
+ ToolDefinition(
455
+ name="write_file",
456
+ description="Write contents to a file",
457
+ parameters={
458
+ "path": {"type": "string", "description": "File path"},
459
+ "content": {"type": "string", "description": "Content to write"},
460
+ },
461
+ required=["path", "content"],
462
+ ),
463
+ ]
464
+
465
+
466
+ def create_agentic_registry() -> ToolRegistry:
467
+ """Create a registry with default tools."""
468
+ registry = ToolRegistry()
469
+ for tool in DEFAULT_TOOLS:
470
+ registry.tools[tool.name] = tool
471
+ return registry