|
|
""" |
|
|
SHIVIK-Code Model Implementation |
|
|
|
|
|
This is a modified version of SHIVIK-M4 with: |
|
|
- Extended context length (32K) via YaRN RoPE scaling |
|
|
- Tool calling capabilities |
|
|
- Fill-in-the-Middle support |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import LlamaForCausalLM, LlamaConfig |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from typing import Optional, Tuple, List, Union |
|
|
|
|
|
|
|
|
class ShivikCodeConfig(LlamaConfig):
    """Configuration for the SHIVIK-Code model.

    Extends ``LlamaConfig`` with:

    - YaRN RoPE-scaling defaults that stretch the native 4K context to a
      32K window (factor 8.0).
    - Special-token ids delimiting tool calls and tool results in
      generated text. They default to ``None`` (meaning "not registered
      with the tokenizer yet") and are accepted as explicit parameters so
      that values supplied by callers — or restored from a serialized
      config — are honored instead of being clobbered back to ``None``
      after ``super().__init__`` has stored them.
    """

    model_type = "shivik_code"

    def __init__(
        self,
        vocab_size=128279,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=32768,
        rope_theta=500000.0,
        rope_scaling=None,
        tool_call_start_id=None,
        tool_call_end_id=None,
        tool_result_start_id=None,
        tool_result_end_id=None,
        **kwargs,
    ):
        # Default to YaRN scaling: 4096-token native context extended 8x
        # to match max_position_embeddings=32768. Callers may override by
        # passing an explicit rope_scaling dict.
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            **kwargs,
        )

        # Assign from the explicit parameters (not hard-coded None) so
        # user-supplied / deserialized values survive construction.
        self.tool_call_start_id = tool_call_start_id
        self.tool_call_end_id = tool_call_end_id
        self.tool_result_start_id = tool_result_start_id
        self.tool_result_end_id = tool_result_end_id
|
|
|
|
|
|
|
|
class ShivikCodeForCausalLM(LlamaForCausalLM):
    """
    SHIVIK-Code: an agentic coding model.

    Extends LlamaForCausalLM with:
    - Tool calling support
    - Extended context via YaRN
    - FIM capability

    The tokenizer-dependent helpers (``_contains_tool_call``,
    ``_extract_tool_call``, ``_format_tool_result``) are abstract here and
    raise ``NotImplementedError``: a subclass (or a monkey-patched
    instance) must supply them before ``generate_with_tools`` is usable.
    Raising loudly replaces the original silent behavior, where the stub
    returned ``None`` and the tool loop quietly never ran.
    """

    config_class = ShivikCodeConfig

    def __init__(self, config: ShivikCodeConfig):
        super().__init__(config)

        # Cache the special-token ids from the config. Any of these may
        # be None if the tokenizer has not registered the tool tokens.
        self.tool_tokens = {
            "call_start": config.tool_call_start_id,
            "call_end": config.tool_call_end_id,
            "result_start": config.tool_result_start_id,
            "result_end": config.tool_result_end_id,
        }

    def is_tool_call(self, token_id: int) -> bool:
        """Return True if *token_id* is a tool-call delimiter token.

        Guards against unset ids: when the tool tokens are not registered
        (both ids are None), no token — including None — is treated as a
        tool-call token.
        """
        if token_id is None:
            return False
        return token_id in (
            self.tool_tokens["call_start"],
            self.tool_tokens["call_end"],
        )

    def generate_with_tools(
        self,
        input_ids: torch.Tensor,
        tool_executor,
        max_new_tokens: int = 512,
        max_tool_calls: int = 10,
        **generate_kwargs,
    ):
        """
        Generate with automatic tool execution.

        Repeatedly generates until either a response with no tool call is
        produced (returned immediately) or *max_tool_calls* rounds have
        run (the accumulated sequence is returned as-is).

        Args:
            input_ids: Input token IDs.
            tool_executor: Callable taking the extracted tool-call dict
                and returning the tool result string.
            max_new_tokens: Max tokens per generation step.
            max_tool_calls: Max number of tool calls allowed.

        Returns:
            Full generated sequence including any appended tool results.
        """
        current_ids = input_ids
        tool_call_count = 0

        while tool_call_count < max_tool_calls:
            # Stop generation at the tool-call close tag so we can run
            # the tool before the model continues.
            outputs = self.generate(
                current_ids,
                max_new_tokens=max_new_tokens,
                stop_strings=["</tool_call>"],
                **generate_kwargs,
            )

            generated = outputs[0]

            if self._contains_tool_call(generated):
                # Execute the tool and splice its result back into the
                # context so the next round can condition on it.
                tool_call = self._extract_tool_call(generated)
                tool_result = tool_executor(tool_call)

                result_tokens = self._format_tool_result(tool_result)
                current_ids = torch.cat([generated, result_tokens], dim=-1)
                tool_call_count += 1
            else:
                # No tool requested: this is the final answer.
                return generated

        # Tool-call budget exhausted; return what we have.
        return current_ids

    def _contains_tool_call(self, token_ids: torch.Tensor) -> bool:
        """Check if sequence contains a tool call.

        Tokenizer-dependent; must be provided by a subclass.
        """
        raise NotImplementedError(
            "_contains_tool_call must be implemented (tokenizer-dependent)"
        )

    def _extract_tool_call(self, token_ids: torch.Tensor) -> dict:
        """Extract tool call JSON from sequence.

        Tokenizer-dependent; must be provided by a subclass.
        """
        raise NotImplementedError(
            "_extract_tool_call must be implemented (tokenizer-dependent)"
        )

    def _format_tool_result(self, result: str) -> torch.Tensor:
        """Format tool result as tokens.

        Tokenizer-dependent; must be provided by a subclass.
        """
        raise NotImplementedError(
            "_format_tool_result must be implemented (tokenizer-dependent)"
        )
|
|
|
|
|
|
|
|
|
|
|
# Registration must run after both classes are defined, hence the
# late import of the Auto* factories here rather than at the top.
from transformers import AutoConfig, AutoModelForCausalLM

# Map model_type "shivik_code" to our classes so that
# AutoConfig/AutoModelForCausalLM.from_pretrained resolve checkpoints
# saved with this config without any manual class lookup.
AutoConfig.register("shivik_code", ShivikCodeConfig)
AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)
|
|
|