File size: 5,222 Bytes
ee7c543 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
"""
SHIVIK-Code Model Implementation
This is a modified version of SHIVIK-M4 with:
- Extended context length (32K) via YaRN RoPE scaling
- Tool calling capabilities
- Fill-in-the-Middle support
"""
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, List, Union
class ShivikCodeConfig(LlamaConfig):
    """Configuration for SHIVIK-Code model.

    Extends LlamaConfig with an extended 32K context window (via YaRN RoPE
    scaling by default) and four special-token IDs used for tool calling.
    """
    model_type = "shivik_code"

    def __init__(
        self,
        vocab_size=128279,  # Extended for tool tokens
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=32768,  # Extended context
        rope_theta=500000.0,
        rope_scaling=None,
        tool_call_start_id=None,
        tool_call_end_id=None,
        tool_result_start_id=None,
        tool_result_end_id=None,
        **kwargs
    ):
        """Build the config.

        Args:
            rope_scaling: RoPE scaling dict; defaults to YaRN with factor 8
                over an original 4096-token window (4096 * 8 = 32768).
            tool_call_start_id / tool_call_end_id / tool_result_start_id /
            tool_result_end_id: Special-token IDs delimiting tool calls and
                tool results. Default to None; set once the tokenizer is
                loaded.
        """
        # Default YaRN scaling for 32K context
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            **kwargs
        )
        # Tool token IDs are explicit parameters (not bare post-init
        # assignments) so they survive save_pretrained/from_pretrained:
        # the previous code unconditionally reset them to None AFTER
        # super().__init__ had restored them from the config JSON kwargs,
        # losing the saved values on every reload.
        self.tool_call_start_id = tool_call_start_id
        self.tool_call_end_id = tool_call_end_id
        self.tool_result_start_id = tool_result_start_id
        self.tool_result_end_id = tool_result_end_id
class ShivikCodeForCausalLM(LlamaForCausalLM):
    """
    SHIVIK-Code: An agentic coding model.
    Extends LlamaForCausalLM with:
    - Tool calling support
    - Extended context via YaRN
    - FIM capability
    """
    config_class = ShivikCodeConfig

    def __init__(self, config: ShivikCodeConfig):
        super().__init__(config)
        # Cache the tool-delimiter token IDs from the config for quick
        # access during generation. These may be None until the tokenizer
        # is loaded and the config's IDs are filled in.
        self.tool_tokens = {
            "call_start": config.tool_call_start_id,
            "call_end": config.tool_call_end_id,
            "result_start": config.tool_result_start_id,
            "result_end": config.tool_result_end_id,
        }

    def is_tool_call(self, token_id: int) -> bool:
        """Check if token is a tool call delimiter token.

        Returns False when the tool token IDs are unconfigured (still None);
        the previous membership test `token_id in [None, None]` returned a
        spurious True for token_id=None in that state.
        """
        markers = (self.tool_tokens["call_start"], self.tool_tokens["call_end"])
        return token_id is not None and token_id in markers

    def generate_with_tools(
        self,
        input_ids: torch.Tensor,
        tool_executor,  # Callable that executes tools
        max_new_tokens: int = 512,
        max_tool_calls: int = 10,
        **generate_kwargs
    ):
        """
        Generate with automatic tool execution.

        Repeatedly generates until the model stops emitting tool calls or
        `max_tool_calls` executions have occurred; each tool result is
        appended to the running context before the next generation round.

        Args:
            input_ids: Input token IDs
            tool_executor: Function that takes tool call JSON and returns result
            max_new_tokens: Max tokens per generation step (not overall)
            max_tool_calls: Max number of tool calls allowed

        Returns:
            Full generated sequence including tool results. If the tool-call
            budget is exhausted the accumulated sequence is returned as-is.

        Raises:
            NotImplementedError: if the tokenizer-dependent helpers
                (`_contains_tool_call` etc.) have not been implemented.
        """
        current_ids = input_ids
        tool_call_count = 0
        while tool_call_count < max_tool_calls:
            # Generate until tool call or EOS.
            # NOTE(review): `stop_strings` requires a `tokenizer=` kwarg to
            # reach generate() via generate_kwargs — confirm callers pass it.
            outputs = self.generate(
                current_ids,
                max_new_tokens=max_new_tokens,
                stop_strings=["</tool_call>"],
                **generate_kwargs
            )
            generated = outputs[0]
            if not self._contains_tool_call(generated):
                # No tool call: the model finished on its own.
                return generated
            # Extract and execute the tool, then splice its result back
            # into the context for the next round.
            tool_call = self._extract_tool_call(generated)
            tool_result = tool_executor(tool_call)
            result_tokens = self._format_tool_result(tool_result)
            # NOTE(review): `generated` is 1-D after indexing outputs[0];
            # confirm the next self.generate call accepts a 1-D context.
            current_ids = torch.cat([generated, result_tokens], dim=-1)
            tool_call_count += 1
        # Tool-call budget exhausted: return what we accumulated so far.
        return current_ids

    def _contains_tool_call(self, token_ids: torch.Tensor) -> bool:
        """Check if sequence contains a tool call.

        Tokenizer-dependent; must be overridden. The previous silent `pass`
        returned None, so tool calls were never detected and
        generate_with_tools degraded to plain generation without warning.
        """
        raise NotImplementedError(
            "_contains_tool_call must be implemented for the model's tokenizer"
        )

    def _extract_tool_call(self, token_ids: torch.Tensor) -> dict:
        """Extract tool call JSON from sequence as a dict.

        Tokenizer-dependent; must be overridden. The previous silent `pass`
        would have fed None into tool_executor.
        """
        raise NotImplementedError(
            "_extract_tool_call must be implemented for the model's tokenizer"
        )

    def _format_tool_result(self, result: str) -> torch.Tensor:
        """Tokenize a tool result so it can be appended to the context.

        Tokenizer-dependent; must be overridden. The previous silent `pass`
        would have passed None to torch.cat.
        """
        raise NotImplementedError(
            "_format_tool_result must be implemented for the model's tokenizer"
        )
# Register for auto classes so AutoConfig.from_pretrained /
# AutoModelForCausalLM.from_pretrained can resolve the custom
# "shivik_code" model_type without callers importing this module's
# classes directly. The config must be registered before the model
# mapping, since AutoModelForCausalLM.register keys on the config class.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("shivik_code", ShivikCodeConfig)
AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)
|