File size: 16,432 Bytes
ef553ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
"""
Vibe Coding Module for MiniMind Max2
Fill-in-the-Middle (FIM) and intelligent code completion.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import json
import re
import random


@dataclass
class CodeCompletionConfig:
    """Configuration for code completion and FIM.

    Bundles the FIM sentinel-token strings, FIM training sampling rates,
    per-segment context budgets, and language/quality switches used by the
    tokenizer, dataset, and model classes in this module.
    """
    # FIM tokens: sentinel strings that mark the prefix/middle/suffix
    # segments of a Fill-in-the-Middle example.
    fim_prefix_token: str = "<fim_prefix>"
    fim_middle_token: str = "<fim_middle>"
    fim_suffix_token: str = "<fim_suffix>"
    fim_pad_token: str = "<fim_pad>"

    # Code tokens: wrappers for whole code documents.
    code_start_token: str = "<code>"
    code_end_token: str = "</code>"

    # FIM training settings
    fim_rate: float = 0.5  # Probability of using FIM vs standard LM
    fim_spm_rate: float = 0.5  # Suffix-Prefix-Middle vs Prefix-Suffix-Middle

    # Context settings: maximum token budget for each FIM segment.
    # NOTE(review): these budgets are not enforced anywhere in this module
    # yet — presumably consumed by the serving layer; confirm.
    max_prefix_tokens: int = 4096
    max_suffix_tokens: int = 2048
    max_middle_tokens: int = 1024

    # Language support
    supported_languages: List[str] = field(default_factory=lambda: [
        "python", "javascript", "typescript", "rust", "go", "java", "cpp", "c"
    ])

    # Code quality
    enforce_syntax: bool = True  # NOTE(review): flag is never read in this module
    use_tree_sitter: bool = False  # For syntax-aware completion


class FIMTokenizer:
    """Build Fill-in-the-Middle (FIM) strings from raw source code."""

    def __init__(self, config: CodeCompletionConfig):
        self.config = config

    @staticmethod
    def _next_line_boundary(code: str, position: int) -> int:
        """Index of the first '\\n' at or after *position*, or len(code)."""
        found = code.find('\n', position)
        return found if found != -1 else len(code)

    def create_fim_example(
        self,
        code: str,
        split_point: Optional[int] = None,
        mode: str = "PSM",  # PSM or SPM
    ) -> Tuple[str, str]:
        """
        Turn *code* into one FIM training pair.

        Args:
            code: Full code string.
            split_point: Character index where the middle span starts;
                a random point in the middle half of *code* when None.
            mode: PSM (Prefix-Suffix-Middle) or SPM (Suffix-Prefix-Middle).

        Returns:
            Tuple of (fim_input, target_middle).
        """
        if split_point is None:
            # Choose a random spot somewhere in the middle half of the file.
            split_point = random.randint(
                len(code) // 4,
                3 * len(code) // 4,
            )

        # Snap the start of the middle span forward to a line boundary.
        middle_start = self._next_line_boundary(code, split_point)

        # Middle covers a random 50-500 char window, clamped to the end of
        # the file and snapped forward to a line boundary as well.
        window_end = min(
            middle_start + random.randint(50, 500),
            len(code),
        )
        middle_end = self._next_line_boundary(code, window_end)

        prefix = code[:middle_start]
        middle = code[middle_start:middle_end]
        suffix = code[middle_end:]

        cfg = self.config

        if mode == "PSM":
            # Prefix-Suffix-Middle ordering.
            parts = [cfg.fim_prefix_token, prefix,
                     cfg.fim_suffix_token, suffix,
                     cfg.fim_middle_token]
        else:
            # Suffix-Prefix-Middle ordering.
            parts = [cfg.fim_suffix_token, suffix,
                     cfg.fim_prefix_token, prefix,
                     cfg.fim_middle_token]

        return "".join(parts), middle

    def format_completion_prompt(
        self,
        prefix: str,
        suffix: str = "",
        language: str = "python",
    ) -> str:
        """Format an inference prompt: FIM when a suffix is given, plain otherwise."""
        if not suffix:
            # No suffix -> ordinary left-to-right completion.
            return prefix

        cfg = self.config
        return "".join([
            cfg.fim_prefix_token, prefix,
            cfg.fim_suffix_token, suffix,
            cfg.fim_middle_token,
        ])


class CodeProcessor:
    """Process code for training and inference."""

    # Language-specific regex patterns for lightweight code analysis.
    LANGUAGE_PATTERNS = {
        "python": {
            "comment": r"#.*$",
            "docstring": r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'',
            "function": r"def\s+(\w+)\s*\(",
            "class": r"class\s+(\w+)\s*[:\(]",
        },
        "javascript": {
            "comment": r"//.*$|/\*[\s\S]*?\*/",
            "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
            "class": r"class\s+(\w+)",
        },
        "typescript": {
            "comment": r"//.*$|/\*[\s\S]*?\*/",
            "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
            "class": r"class\s+(\w+)",
            "interface": r"interface\s+(\w+)",
        },
        "rust": {
            "comment": r"//.*$|/\*[\s\S]*?\*/",
            "function": r"fn\s+(\w+)",
            "struct": r"struct\s+(\w+)",
            "impl": r"impl\s+(\w+)",
        },
    }

    @classmethod
    def detect_language(cls, code: str, filename: Optional[str] = None) -> str:
        """Detect the programming language from code or filename.

        A recognized filename extension wins; otherwise a few cheap keyword
        heuristics are tried; "python" is the fallback.
        """
        if filename:
            ext_map = {
                ".py": "python",
                ".js": "javascript",
                ".ts": "typescript",
                ".tsx": "typescript",
                ".rs": "rust",
                ".go": "go",
                ".java": "java",
                ".cpp": "cpp",
                ".c": "c",
            }
            for ext, lang in ext_map.items():
                if filename.endswith(ext):
                    return lang

        # Heuristic detection from code content.
        if "def " in code and "import " in code:
            return "python"
        if "function " in code or "const " in code:
            return "javascript"
        if "fn " in code and "let " in code:
            return "rust"

        return "python"  # Default

    @classmethod
    def extract_context(
        cls,
        code: str,
        cursor_position: int,
        context_lines: int = 50,
    ) -> Tuple[str, str]:
        """Extract prefix and suffix text around a cursor position.

        The prefix ends exactly at the cursor (including the partially typed
        current line) and the suffix starts at the cursor; at most
        *context_lines* full lines are kept on each side.

        Fixes two defects in the earlier version: the cursor's own line was
        dropped entirely, and a cursor at/past the end of the text was mapped
        to line 0 instead of the last line.

        Args:
            code: Full source text.
            cursor_position: Character offset of the cursor within *code*.
            context_lines: Maximum number of context lines per side.

        Returns:
            Tuple of (prefix, suffix) strings.
        """
        lines = code.split('\n')

        # Locate the cursor's line and column; default to the very end of
        # the text when cursor_position is at or beyond len(code).
        cursor_line = len(lines) - 1
        cursor_col = len(lines[-1])
        current_pos = 0
        for i, line in enumerate(lines):
            # +1 accounts for the '\n' removed by split().
            if current_pos + len(line) + 1 > cursor_position:
                cursor_line = i
                cursor_col = max(0, cursor_position - current_pos)
                break
            current_pos += len(line) + 1

        # Window of whole lines around the cursor line.
        start_line = max(0, cursor_line - context_lines)
        end_line = min(len(lines), cursor_line + context_lines)

        current = lines[cursor_line]
        # Split the cursor line at the cursor column so no text is lost.
        prefix = '\n'.join(lines[start_line:cursor_line] + [current[:cursor_col]])
        suffix = '\n'.join([current[cursor_col:]] + lines[cursor_line + 1:end_line])

        return prefix, suffix


class FIMModule(nn.Module):
    """
    Fill-in-the-Middle module for code completion.
    Enables intelligent middle-of-file completion by injecting FIM role
    embeddings and a pooled prefix/suffix context signal, and by scoring
    overall completion quality.
    """

    def __init__(self, config: CodeCompletionConfig, hidden_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size

        # Embedding of the three FIM roles: 0=prefix, 1=middle, 2=suffix.
        self.fim_position_embed = nn.Embedding(3, hidden_size)

        # Fuses the pooled prefix and suffix vectors into one context vector.
        self.context_combiner = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # Scores the sequence-level completion quality in [0, 1].
        self.quality_predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.GELU(),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid(),
        )

        # Plain helper objects (not nn.Modules).
        self.tokenizer = FIMTokenizer(config)
        self.processor = CodeProcessor()

    @staticmethod
    def _masked_mean(states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean-pool *states* over the positions selected by *mask*.

        The denominator is clamped to 1 so an all-false mask yields zeros
        rather than NaN.
        """
        denom = mask.sum(1, keepdim=True).clamp(min=1)
        return (states * mask.unsqueeze(-1)).sum(1) / denom

    def forward(
        self,
        hidden_states: torch.Tensor,
        fim_positions: Optional[torch.Tensor] = None,
        prefix_mask: Optional[torch.Tensor] = None,
        suffix_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Process hidden states with FIM awareness.

        Args:
            hidden_states: [batch, seq_len, hidden_size]
            fim_positions: Role id per token (0=prefix, 1=middle, 2=suffix)
            prefix_mask: Boolean mask for prefix tokens
            suffix_mask: Boolean mask for suffix tokens

        Returns:
            Enhanced hidden states and a metrics dict with
            "completion_quality" ([batch, 1], sigmoid output).
        """
        seq_len = hidden_states.shape[1]

        # Inject FIM role embeddings when role ids are supplied.
        if fim_positions is not None:
            hidden_states = hidden_states + self.fim_position_embed(fim_positions)

        # Fuse pooled prefix/suffix context into the middle tokens.
        if prefix_mask is not None and suffix_mask is not None:
            pooled_prefix = self._masked_mean(hidden_states, prefix_mask)
            pooled_suffix = self._masked_mean(hidden_states, suffix_mask)
            fused = self.context_combiner(
                torch.cat([pooled_prefix, pooled_suffix], dim=-1)
            )

            # Middle tokens are exactly those in neither prefix nor suffix.
            middle_mask = ~(prefix_mask | suffix_mask)
            if middle_mask.any():
                broadcast = fused.unsqueeze(1).expand(-1, seq_len, -1)
                hidden_states = hidden_states + broadcast * middle_mask.unsqueeze(-1)

        metrics = {
            "completion_quality": self.quality_predictor(hidden_states.mean(1)),
        }

        return hidden_states, metrics


class VibeCoder:
    """
    High-level interface for "vibe coding" - intuitive code assistance.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        config: Optional[CodeCompletionConfig] = None,
        device: str = "cuda",
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or CodeCompletionConfig()
        self.device = device

        # Infer the model width; fall back to 1024 when the wrapped model
        # exposes no config object.
        if hasattr(model, 'config'):
            hidden_size = model.config.hidden_size
        else:
            hidden_size = 1024

        self.fim_module = FIMModule(self.config, hidden_size).to(device)
        self.fim_tokenizer = FIMTokenizer(self.config)

    def complete(
        self,
        prefix: str,
        suffix: str = "",
        max_tokens: int = 100,
        temperature: float = 0.2,
        stop_tokens: Optional[List[str]] = None,
    ) -> str:
        """
        Complete code given a prefix and an optional suffix.

        Args:
            prefix: Code before the cursor.
            suffix: Code after the cursor (enables FIM mode when non-empty).
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature (0 disables sampling).
            stop_tokens: Strings at which generation output is truncated.

        Returns:
            The generated code completion.
        """
        self.model.eval()

        # Build the (possibly FIM-formatted) prompt and tokenize it.
        prompt = self.fim_tokenizer.format_completion_prompt(prefix, suffix)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        prompt_len = input_ids.shape[1]

        with torch.no_grad():
            generated = self.model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                top_p=0.95,
            )

        # Keep only the newly generated tail.
        completion = self.tokenizer.decode(
            generated[0][prompt_len:],
            skip_special_tokens=True,
        )

        # Truncate at each stop sequence that appears in the output.
        if stop_tokens:
            for stop in stop_tokens:
                cut = completion.find(stop)
                if cut != -1:
                    completion = completion[:cut]

        return completion

    def complete_function(
        self,
        signature: str,
        context: str = "",
        language: str = "python",
    ) -> str:
        """Complete a function body given its signature (and optional context)."""
        if language == "python":
            prompt = f"{context}\n\n{signature}\n    "
        else:
            # Brace-delimited languages (javascript/typescript and everything
            # else) share the same opening-brace scaffold.
            prompt = f"{context}\n\n{signature} {{\n    "

        return self.complete(prompt, max_tokens=500)

    def explain_code(self, code: str, language: str = "python") -> str:
        """Generate a natural-language explanation for *code*."""
        prompt = f"# Explain the following {language} code:\n```{language}\n{code}\n```\n\n# Explanation:\n"
        return self.complete(prompt, max_tokens=300, temperature=0.3)

    def refactor(
        self,
        code: str,
        instruction: str = "Refactor this code to be cleaner and more efficient",
        language: str = "python",
    ) -> str:
        """Refactor *code* according to *instruction*."""
        prompt = f"""# Original code:
```{language}
{code}
```

# Task: {instruction}

# Refactored code:
```{language}
"""
        raw = self.complete(prompt, max_tokens=1000, temperature=0.2)

        # Drop everything from the closing code fence onwards.
        return raw.partition("```")[0]

    def fix_bug(self, code: str, error: str = "", language: str = "python") -> str:
        """Fix a bug in *code*, optionally guided by an *error* message."""
        prompt = f"""# Buggy code:
```{language}
{code}
```

# Error: {error if error else "Unknown bug"}

# Fixed code:
```{language}
"""
        raw = self.complete(prompt, max_tokens=1000, temperature=0.1)

        # Drop everything from the closing code fence onwards.
        return raw.partition("```")[0]


class CodeDataset(Dataset):
    """Dataset for code training with FIM.

    Loads JSONL examples (one object per line, with a "code" or "content"
    field) and, per item, randomly chooses between standard left-to-right
    language modeling and a FIM transformation at ``config.fim_rate``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        config: CodeCompletionConfig,
        max_length: int = 2048,
    ):
        """
        Args:
            data_path: Path to a JSONL file of code examples.
            tokenizer: Callable tokenizer (HF-style: returns a dict with
                "input_ids" / "attention_mask" — presumably; confirm).
            config: Completion/FIM configuration.
            max_length: Tokenized sequence length (padded/truncated to this).
        """
        self.tokenizer = tokenizer
        self.config = config
        self.max_length = max_length
        self.fim_tokenizer = FIMTokenizer(config)

        # One JSON object per non-blank line.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.examples = [json.loads(line) for line in f if line.strip()]

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]
        code = example.get("code", example.get("content", ""))

        # Decide FIM vs standard LM; FIM only for reasonably long snippets.
        use_fim = random.random() < self.config.fim_rate

        if use_fim and len(code) > 100:
            # Create a FIM example (input + target middle concatenated).
            mode = "SPM" if random.random() < self.config.fim_spm_rate else "PSM"
            fim_input, target = self.fim_tokenizer.create_fim_example(code, mode=mode)
            text = fim_input + target
        else:
            # Standard left-to-right language modeling.
            text = code

        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        # Causal LM: labels mirror input_ids.
        # NOTE(review): padding positions are not masked to -100 here, so
        # the loss includes pad tokens — confirm whether the trainer masks
        # them via the attention mask.
        input_ids = encodings["input_ids"].squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": encodings["attention_mask"].squeeze(0),
            "labels": input_ids,
        }


def prepare_code_dataset(
    raw_data_path: str,
    output_path: str,
    languages: Optional[List[str]] = None,
    min_code_length: int = 50,
    max_code_length: int = 100000,
) -> int:
    """Filter a raw JSONL code dump into a training-ready JSONL file.

    Keeps only examples in the requested languages whose code length falls
    within [min_code_length, max_code_length], and writes them with just the
    "code" and "language" fields.

    Args:
        raw_data_path: Input JSONL; each line needs a "code" (or "content")
            field and a "language" field.
        output_path: Destination JSONL path (overwritten).
        languages: Languages to keep; defaults to
            ["python", "javascript", "typescript", "rust"].
        min_code_length: Drop snippets shorter than this many characters.
        max_code_length: Drop snippets longer than this many characters.

    Returns:
        Number of examples written.
    """
    languages = languages or ["python", "javascript", "typescript", "rust"]
    processed = 0

    with open(raw_data_path, 'r', encoding='utf-8') as fin, \
         open(output_path, 'w', encoding='utf-8') as fout:

        for line in fin:
            if not line.strip():
                continue

            data = json.loads(line)

            # Extract code and language.
            code = data.get("code", data.get("content", ""))
            language = data.get("language", "")

            # Keep only requested languages ("languages" is always non-empty
            # after the default assignment above).
            if language not in languages:
                continue

            # Basic size-based quality gate.
            if len(code) < min_code_length or len(code) > max_code_length:
                continue

            processed_example = {
                "code": code,
                "language": language,
            }

            fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n")
            processed += 1

    return processed