""" SHIVIK-Code Model Implementation This is a modified version of SHIVIK-M4 with: - Extended context length (32K) via YaRN RoPE scaling - Tool calling capabilities - Fill-in-the-Middle support """ import torch import torch.nn as nn from transformers import LlamaForCausalLM, LlamaConfig from transformers.modeling_outputs import CausalLMOutputWithPast from typing import Optional, Tuple, List, Union class ShivikCodeConfig(LlamaConfig): """Configuration for SHIVIK-Code model.""" model_type = "shivik_code" def __init__( self, vocab_size=128279, # Extended for tool tokens hidden_size=2048, intermediate_size=8192, num_hidden_layers=16, num_attention_heads=32, num_key_value_heads=8, max_position_embeddings=32768, # Extended context rope_theta=500000.0, rope_scaling=None, **kwargs ): # Default YaRN scaling for 32K context if rope_scaling is None: rope_scaling = { "type": "yarn", "factor": 8.0, "original_max_position_embeddings": 4096, } super().__init__( vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, **kwargs ) # Tool token IDs (set after tokenizer is loaded) self.tool_call_start_id = None self.tool_call_end_id = None self.tool_result_start_id = None self.tool_result_end_id = None class ShivikCodeForCausalLM(LlamaForCausalLM): """ SHIVIK-Code: An agentic coding model. Extends LlamaForCausalLM with: - Tool calling support - Extended context via YaRN - FIM capability """ config_class = ShivikCodeConfig def __init__(self, config: ShivikCodeConfig): super().__init__(config) # Store tool token IDs for easy access self.tool_tokens = { "call_start": config.tool_call_start_id, "call_end": config.tool_call_end_id, "result_start": config.tool_result_start_id, "result_end": config.tool_result_end_id, } def is_tool_call(self, token_id: int) -> bool: """Check if token is a tool call token.""" return token_id in [ self.tool_tokens["call_start"], self.tool_tokens["call_end"], ] def generate_with_tools( self, input_ids: torch.Tensor, tool_executor, # Callable that executes tools max_new_tokens: int = 512, max_tool_calls: int = 10, **generate_kwargs ): """ Generate with automatic tool execution. Args: input_ids: Input token IDs tool_executor: Function that takes tool call JSON and returns result max_new_tokens: Max tokens per generation step max_tool_calls: Max number of tool calls allowed Returns: Full generated sequence including tool results """ current_ids = input_ids tool_call_count = 0 while tool_call_count < max_tool_calls: # Generate until tool call or EOS outputs = self.generate( current_ids, max_new_tokens=max_new_tokens, stop_strings=[""], **generate_kwargs ) generated = outputs[0] # Check if we hit a tool call if self._contains_tool_call(generated): # Extract and execute tool tool_call = self._extract_tool_call(generated) tool_result = tool_executor(tool_call) # Append tool result to context result_tokens = self._format_tool_result(tool_result) current_ids = torch.cat([generated, result_tokens], dim=-1) tool_call_count += 1 else: # No tool call, we're done return generated return current_ids def _contains_tool_call(self, token_ids: torch.Tensor) -> bool: """Check if sequence contains a tool call.""" # Implementation depends on tokenizer pass def _extract_tool_call(self, token_ids: torch.Tensor) -> dict: """Extract tool call JSON from sequence.""" # Implementation depends on tokenizer pass def _format_tool_result(self, result: str) -> torch.Tensor: """Format tool result as tokens.""" # Implementation depends on tokenizer pass # Register for auto classes from transformers import AutoConfig, AutoModelForCausalLM AutoConfig.register("shivik_code", ShivikCodeConfig) AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)