Shivik-Code-1B / modeling_shivik_code.py
"""
SHIVIK-Code Model Implementation
This is a modified version of SHIVIK-M4 with:
- Extended context length (32K) via YaRN RoPE scaling
- Tool calling capabilities
- Fill-in-the-Middle support
"""
import torch
from transformers import LlamaForCausalLM, LlamaConfig
class ShivikCodeConfig(LlamaConfig):
"""Configuration for SHIVIK-Code model."""
model_type = "shivik_code"
def __init__(
self,
vocab_size=128279, # Extended for tool tokens
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=16,
num_attention_heads=32,
num_key_value_heads=8,
max_position_embeddings=32768, # Extended context
rope_theta=500000.0,
rope_scaling=None,
**kwargs
):
        # Default YaRN scaling for 32K context:
        # factor 8.0 x 4096 original positions = 32768 extended positions
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
**kwargs
)
        # Tool token IDs (set after the tokenizer is loaded). getattr keeps
        # values restored from a saved config instead of clobbering them with
        # None, since PretrainedConfig.__init__ sets unknown kwargs as attributes.
        self.tool_call_start_id = getattr(self, "tool_call_start_id", None)
        self.tool_call_end_id = getattr(self, "tool_call_end_id", None)
        self.tool_result_start_id = getattr(self, "tool_result_start_id", None)
        self.tool_result_end_id = getattr(self, "tool_result_end_id", None)
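
# A minimal sketch of wiring the tool token IDs once a tokenizer exists.
# "</tool_call>" matches the stop string used in generate_with_tools below;
# the other three tag strings are assumptions and must match the special
# tokens actually added to your tokenizer.
def attach_tool_token_ids(config: ShivikCodeConfig, tokenizer) -> ShivikCodeConfig:
    """Populate the config's tool token IDs from a loaded tokenizer."""
    config.tool_call_start_id = tokenizer.convert_tokens_to_ids("<tool_call>")
    config.tool_call_end_id = tokenizer.convert_tokens_to_ids("</tool_call>")
    config.tool_result_start_id = tokenizer.convert_tokens_to_ids("<tool_result>")
    config.tool_result_end_id = tokenizer.convert_tokens_to_ids("</tool_result>")
    return config
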
class ShivikCodeForCausalLM(LlamaForCausalLM):
"""
SHIVIK-Code: An agentic coding model.
Extends LlamaForCausalLM with:
- Tool calling support
- Extended context via YaRN
- FIM capability
"""
config_class = ShivikCodeConfig
    def __init__(self, config: ShivikCodeConfig):
        super().__init__(config)
        # Cache tool token IDs for easy access; getattr tolerates configs
        # created without these fields.
        self.tool_tokens = {
            "call_start": getattr(config, "tool_call_start_id", None),
            "call_end": getattr(config, "tool_call_end_id", None),
            "result_start": getattr(config, "tool_result_start_id", None),
            "result_end": getattr(config, "tool_result_end_id", None),
        }
def is_tool_call(self, token_id: int) -> bool:
"""Check if token is a tool call token."""
return token_id in [
self.tool_tokens["call_start"],
self.tool_tokens["call_end"],
]
def generate_with_tools(
self,
input_ids: torch.Tensor,
tool_executor, # Callable that executes tools
max_new_tokens: int = 512,
max_tool_calls: int = 10,
**generate_kwargs
):
"""
Generate with automatic tool execution.
Args:
input_ids: Input token IDs
tool_executor: Function that takes tool call JSON and returns result
max_new_tokens: Max tokens per generation step
max_tool_calls: Max number of tool calls allowed
Returns:
Full generated sequence including tool results
"""
current_ids = input_ids
tool_call_count = 0
while tool_call_count < max_tool_calls:
# Generate until tool call or EOS
outputs = self.generate(
current_ids,
max_new_tokens=max_new_tokens,
stop_strings=["</tool_call>"],
**generate_kwargs
)
generated = outputs[0]
# Check if we hit a tool call
if self._contains_tool_call(generated):
# Extract and execute tool
tool_call = self._extract_tool_call(generated)
tool_result = tool_executor(tool_call)
# Append tool result to context
result_tokens = self._format_tool_result(tool_result)
current_ids = torch.cat([generated, result_tokens], dim=-1)
tool_call_count += 1
else:
# No tool call, we're done
return generated
return current_ids
    def _contains_tool_call(self, token_ids: torch.Tensor) -> bool:
        """Check if the sequence contains a tool call.
        Assumes </tool_call> is a single special token whose ID is stored
        in self.tool_tokens["call_end"].
        """
        end_id = self.tool_tokens["call_end"]
        return end_id is not None and bool((token_ids == end_id).any())
    def _extract_tool_call(self, token_ids: torch.Tensor) -> dict:
        """Extract the tool-call JSON from the sequence.
        Tokenizer-dependent; see the text-level sketch below.
        """
        raise NotImplementedError("Requires the model's tokenizer.")
    def _format_tool_result(self, result: str) -> torch.Tensor:
        """Encode a tool result, wrapped in result tokens, as a 1-D tensor.
        Tokenizer-dependent; see the text-level sketch below.
        """
        raise NotImplementedError("Requires the model's tokenizer.")
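
# A text-level sketch of the tokenizer-dependent hooks above. The
# <tool_call>/<tool_result> tag strings and the JSON payload format are
# assumptions; adapt them to the tokens your tokenizer actually defines.
import json
import re
from typing import Optional

def extract_tool_call_from_text(text: str) -> Optional[dict]:
    """Return the JSON payload of the last <tool_call> block, if any.
    Callers would first decode: text = tokenizer.decode(token_ids).
    """
    matches = re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    return json.loads(matches[-1]) if matches else None

def format_tool_result_text(result: str) -> str:
    """Wrap a tool result in result tags; encode with the tokenizer after."""
    return f"<tool_result>{result}</tool_result>"
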
# Register for auto classes
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("shivik_code", ShivikCodeConfig)
AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)
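
# Example usage: a minimal sketch. The repo id is hypothetical, the tokenizer
# is assumed to carry the tool special tokens, and _extract_tool_call /
# _format_tool_result must be implemented first. tokenizer=... is forwarded
# to generate() because stop_strings requires it.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "theaicompany02/Shivik-Code-1B"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    def echo_tool(call: dict) -> str:
        # Toy executor: real code would dispatch on the tool name in `call`
        return f"ok: {call}"

    inputs = tokenizer("Write a function that sorts a list.", return_tensors="pt")
    output = model.generate_with_tools(
        inputs.input_ids,
        tool_executor=echo_tool,
        max_new_tokens=256,
        tokenizer=tokenizer,
    )
    print(tokenizer.decode(output, skip_special_tokens=False))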