""" Vibe Coding Module for MiniMind Max2 Fill-in-the-Middle (FIM) and intelligent code completion. """ from dataclasses import dataclass, field from typing import List, Optional, Dict, Any, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import json import re import random @dataclass class CodeCompletionConfig: """Configuration for code completion and FIM.""" # FIM tokens fim_prefix_token: str = "" fim_middle_token: str = "" fim_suffix_token: str = "" fim_pad_token: str = "" # Code tokens code_start_token: str = "" code_end_token: str = "" # FIM training settings fim_rate: float = 0.5 # Probability of using FIM vs standard LM fim_spm_rate: float = 0.5 # Suffix-Prefix-Middle vs Prefix-Suffix-Middle # Context settings max_prefix_tokens: int = 4096 max_suffix_tokens: int = 2048 max_middle_tokens: int = 1024 # Language support supported_languages: List[str] = field(default_factory=lambda: [ "python", "javascript", "typescript", "rust", "go", "java", "cpp", "c" ]) # Code quality enforce_syntax: bool = True use_tree_sitter: bool = False # For syntax-aware completion class FIMTokenizer: """Handle Fill-in-the-Middle tokenization.""" def __init__(self, config: CodeCompletionConfig): self.config = config def create_fim_example( self, code: str, split_point: Optional[int] = None, mode: str = "PSM", # PSM or SPM ) -> Tuple[str, str]: """ Create a FIM training example from code. Args: code: Full code string split_point: Where to split (random if None) mode: PSM (Prefix-Suffix-Middle) or SPM (Suffix-Prefix-Middle) Returns: Tuple of (fim_input, target_middle) """ if split_point is None: # Random split point split_point = random.randint( len(code) // 4, 3 * len(code) // 4, ) # Find a good split point (end of line) while split_point < len(code) and code[split_point] != '\n': split_point += 1 # Determine middle span middle_start = split_point middle_end = min( middle_start + random.randint(50, 500), len(code), ) # Find end of middle span (end of line) while middle_end < len(code) and code[middle_end] != '\n': middle_end += 1 prefix = code[:middle_start] middle = code[middle_start:middle_end] suffix = code[middle_end:] cfg = self.config if mode == "PSM": # Prefix-Suffix-Middle fim_input = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}" else: # Suffix-Prefix-Middle fim_input = f"{cfg.fim_suffix_token}{suffix}{cfg.fim_prefix_token}{prefix}{cfg.fim_middle_token}" return fim_input, middle def format_completion_prompt( self, prefix: str, suffix: str = "", language: str = "python", ) -> str: """Format a completion prompt.""" cfg = self.config if suffix: # FIM mode prompt = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}" else: # Standard completion prompt = prefix return prompt class CodeProcessor: """Process code for training and inference.""" # Language-specific patterns LANGUAGE_PATTERNS = { "python": { "comment": r"#.*$", "docstring": r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'', "function": r"def\s+(\w+)\s*\(", "class": r"class\s+(\w+)\s*[:\(]", }, "javascript": { "comment": r"//.*$|/\*[\s\S]*?\*/", "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>", "class": r"class\s+(\w+)", }, "typescript": { "comment": r"//.*$|/\*[\s\S]*?\*/", "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>", "class": r"class\s+(\w+)", "interface": r"interface\s+(\w+)", }, "rust": { "comment": r"//.*$|/\*[\s\S]*?\*/", "function": r"fn\s+(\w+)", "struct": r"struct\s+(\w+)", "impl": r"impl\s+(\w+)", }, } @classmethod def detect_language(cls, code: str, filename: Optional[str] = None) -> str: """Detect programming language from code or filename.""" if filename: ext_map = { ".py": "python", ".js": "javascript", ".ts": "typescript", ".tsx": "typescript", ".rs": "rust", ".go": "go", ".java": "java", ".cpp": "cpp", ".c": "c", } for ext, lang in ext_map.items(): if filename.endswith(ext): return lang # Heuristic detection if "def " in code and "import " in code: return "python" if "function " in code or "const " in code: return "javascript" if "fn " in code and "let " in code: return "rust" return "python" # Default @classmethod def extract_context( cls, code: str, cursor_position: int, context_lines: int = 50, ) -> Tuple[str, str]: """Extract prefix and suffix around cursor position.""" lines = code.split('\n') # Find line number for cursor current_pos = 0 cursor_line = 0 for i, line in enumerate(lines): if current_pos + len(line) + 1 > cursor_position: cursor_line = i break current_pos += len(line) + 1 # Get context lines start_line = max(0, cursor_line - context_lines) end_line = min(len(lines), cursor_line + context_lines) prefix_lines = lines[start_line:cursor_line] suffix_lines = lines[cursor_line + 1:end_line] prefix = '\n'.join(prefix_lines) suffix = '\n'.join(suffix_lines) return prefix, suffix class FIMModule(nn.Module): """ Fill-in-the-Middle module for code completion. Enables intelligent middle-of-file completion. """ def __init__(self, config: CodeCompletionConfig, hidden_size: int): super().__init__() self.config = config self.hidden_size = hidden_size # FIM position embeddings self.fim_position_embed = nn.Embedding(3, hidden_size) # prefix, middle, suffix # Context combiner self.context_combiner = nn.Sequential( nn.Linear(hidden_size * 2, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size), ) # Completion quality predictor self.quality_predictor = nn.Sequential( nn.Linear(hidden_size, hidden_size // 4), nn.GELU(), nn.Linear(hidden_size // 4, 1), nn.Sigmoid(), ) # Tokenizer helper self.tokenizer = FIMTokenizer(config) self.processor = CodeProcessor() def forward( self, hidden_states: torch.Tensor, fim_positions: Optional[torch.Tensor] = None, prefix_mask: Optional[torch.Tensor] = None, suffix_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Process hidden states with FIM awareness. Args: hidden_states: [batch, seq_len, hidden_size] fim_positions: Position type for each token (0=prefix, 1=middle, 2=suffix) prefix_mask: Mask for prefix tokens suffix_mask: Mask for suffix tokens Returns: Enhanced hidden states and metrics """ batch_size, seq_len, _ = hidden_states.shape # Add FIM position embeddings if fim_positions is not None: pos_embed = self.fim_position_embed(fim_positions) hidden_states = hidden_states + pos_embed # Combine context from prefix and suffix if prefix_mask is not None and suffix_mask is not None: # Average pool prefix and suffix representations prefix_repr = (hidden_states * prefix_mask.unsqueeze(-1)).sum(1) / prefix_mask.sum(1, keepdim=True).clamp(min=1) suffix_repr = (hidden_states * suffix_mask.unsqueeze(-1)).sum(1) / suffix_mask.sum(1, keepdim=True).clamp(min=1) # Combine context = self.context_combiner(torch.cat([prefix_repr, suffix_repr], dim=-1)) # Add context to middle tokens middle_mask = ~(prefix_mask | suffix_mask) if middle_mask.any(): context_expanded = context.unsqueeze(1).expand(-1, seq_len, -1) hidden_states = hidden_states + context_expanded * middle_mask.unsqueeze(-1) # Quality prediction quality = self.quality_predictor(hidden_states.mean(1)) metrics = { "completion_quality": quality, } return hidden_states, metrics class VibeCoder: """ High-level interface for "vibe coding" - intuitive code assistance. """ def __init__( self, model: nn.Module, tokenizer, config: Optional[CodeCompletionConfig] = None, device: str = "cuda", ): self.model = model self.tokenizer = tokenizer self.config = config or CodeCompletionConfig() self.device = device # Get hidden size if hasattr(model, 'config'): hidden_size = model.config.hidden_size else: hidden_size = 1024 self.fim_module = FIMModule(self.config, hidden_size).to(device) self.fim_tokenizer = FIMTokenizer(self.config) def complete( self, prefix: str, suffix: str = "", max_tokens: int = 100, temperature: float = 0.2, stop_tokens: Optional[List[str]] = None, ) -> str: """ Complete code given prefix and optional suffix. Args: prefix: Code before cursor suffix: Code after cursor (for FIM) max_tokens: Maximum tokens to generate temperature: Sampling temperature stop_tokens: Tokens to stop generation Returns: Generated code completion """ self.model.eval() # Format prompt prompt = self.fim_tokenizer.format_completion_prompt(prefix, suffix) # Tokenize input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) # Generate with torch.no_grad(): generated = self.model.generate( input_ids, max_new_tokens=max_tokens, temperature=temperature, do_sample=temperature > 0, top_p=0.95, ) # Decode completion = self.tokenizer.decode( generated[0][input_ids.shape[1]:], skip_special_tokens=True, ) # Stop at stop tokens if stop_tokens: for stop in stop_tokens: if stop in completion: completion = completion[:completion.index(stop)] return completion def complete_function( self, signature: str, context: str = "", language: str = "python", ) -> str: """Complete a function given its signature.""" if language == "python": prompt = f"{context}\n\n{signature}\n " elif language in ["javascript", "typescript"]: prompt = f"{context}\n\n{signature} {{\n " else: prompt = f"{context}\n\n{signature} {{\n " return self.complete(prompt, max_tokens=500) def explain_code(self, code: str, language: str = "python") -> str: """Generate explanation for code.""" prompt = f"# Explain the following {language} code:\n```{language}\n{code}\n```\n\n# Explanation:\n" return self.complete(prompt, max_tokens=300, temperature=0.3) def refactor( self, code: str, instruction: str = "Refactor this code to be cleaner and more efficient", language: str = "python", ) -> str: """Refactor code based on instruction.""" prompt = f"""# Original code: ```{language} {code} ``` # Task: {instruction} # Refactored code: ```{language} """ completion = self.complete(prompt, max_tokens=1000, temperature=0.2) # Clean up if "```" in completion: completion = completion[:completion.index("```")] return completion def fix_bug(self, code: str, error: str = "", language: str = "python") -> str: """Fix a bug in code.""" prompt = f"""# Buggy code: ```{language} {code} ``` # Error: {error if error else "Unknown bug"} # Fixed code: ```{language} """ completion = self.complete(prompt, max_tokens=1000, temperature=0.1) if "```" in completion: completion = completion[:completion.index("```")] return completion class CodeDataset(Dataset): """Dataset for code training with FIM.""" def __init__( self, data_path: str, tokenizer, config: CodeCompletionConfig, max_length: int = 2048, ): self.tokenizer = tokenizer self.config = config self.max_length = max_length self.fim_tokenizer = FIMTokenizer(config) self.examples = [] with open(data_path, 'r', encoding='utf-8') as f: for line in f: if line.strip(): self.examples.append(json.loads(line)) def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: example = self.examples[idx] code = example.get("code", example.get("content", "")) language = example.get("language", "python") # Decide FIM vs standard LM use_fim = random.random() < self.config.fim_rate if use_fim and len(code) > 100: # Create FIM example mode = "SPM" if random.random() < self.config.fim_spm_rate else "PSM" fim_input, target = self.fim_tokenizer.create_fim_example(code, mode=mode) text = fim_input + target else: # Standard LM text = code # Tokenize encodings = self.tokenizer( text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt", ) return { "input_ids": encodings["input_ids"].squeeze(0), "attention_mask": encodings["attention_mask"].squeeze(0), "labels": encodings["input_ids"].squeeze(0), } def prepare_code_dataset( raw_data_path: str, output_path: str, languages: Optional[List[str]] = None, ) -> int: """Prepare code dataset for training.""" languages = languages or ["python", "javascript", "typescript", "rust"] processed = 0 with open(raw_data_path, 'r', encoding='utf-8') as fin, \ open(output_path, 'w', encoding='utf-8') as fout: for line in fin: if not line.strip(): continue data = json.loads(line) # Extract code and language code = data.get("code", data.get("content", "")) language = data.get("language", "") # Filter by language if languages and language not in languages: continue # Filter by quality (basic heuristics) if len(code) < 50 or len(code) > 100000: continue processed_example = { "code": code, "language": language, } fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n") processed += 1 return processed