File size: 5,222 Bytes
ee7c543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
SHIVIK-Code Model Implementation

This is a modified version of SHIVIK-M4 with:
- Extended context length (32K) via YaRN RoPE scaling
- Tool calling capabilities
- Fill-in-the-Middle support
"""

import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, List, Union


class ShivikCodeConfig(LlamaConfig):
    """Configuration for SHIVIK-Code model.

    Extends LlamaConfig with a larger vocabulary (room for tool tokens),
    a 32K context window via YaRN RoPE scaling, and the IDs of the
    special tool-delimiter tokens used during tool-calling generation.
    """

    model_type = "shivik_code"

    def __init__(
        self,
        vocab_size=128279,  # Extended for tool tokens
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=32768,  # Extended context
        rope_theta=500000.0,
        rope_scaling=None,
        tool_call_start_id=None,
        tool_call_end_id=None,
        tool_result_start_id=None,
        tool_result_end_id=None,
        **kwargs
    ):
        # Default YaRN scaling for 32K context (8x over the original 4K).
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            **kwargs
        )

        # Tool token IDs (set after the tokenizer is loaded). Accepted as
        # explicit keyword parameters so they survive save_pretrained /
        # from_pretrained round-trips: from_pretrained restores serialized
        # attributes via **kwargs, and the previous unconditional
        # `self.tool_call_start_id = None` after super().__init__()
        # clobbered any restored values.
        self.tool_call_start_id = tool_call_start_id
        self.tool_call_end_id = tool_call_end_id
        self.tool_result_start_id = tool_result_start_id
        self.tool_result_end_id = tool_result_end_id


class ShivikCodeForCausalLM(LlamaForCausalLM):
    """
    SHIVIK-Code: an agentic coding model.

    A LlamaForCausalLM variant that adds:
    - Tool calling support
    - Extended context via YaRN
    - FIM capability
    """

    config_class = ShivikCodeConfig

    def __init__(self, config: ShivikCodeConfig):
        super().__init__(config)

        # Cache the special tool-token IDs from the config under short keys
        # so hot-path checks avoid repeated config attribute lookups.
        self.tool_tokens = {
            "call_start": config.tool_call_start_id,
            "call_end": config.tool_call_end_id,
            "result_start": config.tool_result_start_id,
            "result_end": config.tool_result_end_id,
        }

    def is_tool_call(self, token_id: int) -> bool:
        """Return True if *token_id* is a tool-call delimiter token."""
        markers = (
            self.tool_tokens["call_start"],
            self.tool_tokens["call_end"],
        )
        return token_id in markers

    def generate_with_tools(
        self,
        input_ids: torch.Tensor,
        tool_executor,  # Callable that executes tools
        max_new_tokens: int = 512,
        max_tool_calls: int = 10,
        **generate_kwargs
    ):
        """
        Generate with automatic tool execution.

        Args:
            input_ids: Input token IDs
            tool_executor: Function that takes tool call JSON and returns result
            max_new_tokens: Max tokens per generation step
            max_tool_calls: Max number of tool calls allowed

        Returns:
            Full generated sequence including tool results
        """
        current_ids = input_ids

        # Alternate generate -> execute-tool -> append-result, at most
        # max_tool_calls times before giving up and returning what we have.
        for _ in range(max_tool_calls):
            # Generate until a tool call is emitted or EOS is reached.
            step_output = self.generate(
                current_ids,
                max_new_tokens=max_new_tokens,
                stop_strings=["</tool_call>"],
                **generate_kwargs
            )
            generated = step_output[0]

            # No tool call in this round: generation is complete.
            if not self._contains_tool_call(generated):
                return generated

            # Extract the call, run it, and splice the result back into
            # the context so the next round can condition on it.
            tool_result = tool_executor(self._extract_tool_call(generated))
            result_tokens = self._format_tool_result(tool_result)
            current_ids = torch.cat([generated, result_tokens], dim=-1)

        # Tool-call budget exhausted.
        return current_ids

    def _contains_tool_call(self, token_ids: torch.Tensor) -> bool:
        """Check if sequence contains a tool call (tokenizer-specific stub)."""
        # Implementation depends on tokenizer
        pass

    def _extract_tool_call(self, token_ids: torch.Tensor) -> dict:
        """Extract tool call JSON from sequence (tokenizer-specific stub)."""
        # Implementation depends on tokenizer
        pass

    def _format_tool_result(self, result: str) -> torch.Tensor:
        """Format tool result as tokens (tokenizer-specific stub)."""
        # Implementation depends on tokenizer
        pass


# Register for auto classes so AutoConfig.from_pretrained /
# AutoModelForCausalLM.from_pretrained resolve model_type "shivik_code"
# to the classes defined above. The import sits here (not at the top of
# the file) because registration can only happen after both classes exist.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("shivik_code", ShivikCodeConfig)
AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)