fariasultana committed (verified)
Commit ef553ca · Parent: 823ea46

feat: Add capabilities/coding.py

Files changed (1):
  1. capabilities/coding.py +543 -0
capabilities/coding.py ADDED
@@ -0,0 +1,543 @@
+ """
+ Vibe Coding Module for MiniMind Max2
+ Fill-in-the-Middle (FIM) and intelligent code completion.
+ """
+
+ from dataclasses import dataclass, field
+ from typing import List, Optional, Dict, Any, Tuple
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ import json
+ import re
+ import random
+
+
+ @dataclass
+ class CodeCompletionConfig:
+     """Configuration for code completion and FIM."""
+     # FIM tokens
+     fim_prefix_token: str = "<fim_prefix>"
+     fim_middle_token: str = "<fim_middle>"
+     fim_suffix_token: str = "<fim_suffix>"
+     fim_pad_token: str = "<fim_pad>"
+
+     # Code tokens
+     code_start_token: str = "<code>"
+     code_end_token: str = "</code>"
+
+     # FIM training settings
+     fim_rate: float = 0.5  # Probability of using FIM vs. standard LM
+     fim_spm_rate: float = 0.5  # Suffix-Prefix-Middle vs. Prefix-Suffix-Middle
+
+     # Context settings
+     max_prefix_tokens: int = 4096
+     max_suffix_tokens: int = 2048
+     max_middle_tokens: int = 1024
+
+     # Language support
+     supported_languages: List[str] = field(default_factory=lambda: [
+         "python", "javascript", "typescript", "rust", "go", "java", "cpp", "c"
+     ])
+
+     # Code quality
+     enforce_syntax: bool = True
+     use_tree_sitter: bool = False  # For syntax-aware completion
+
+
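+ # Note: at training time the FIM tokens above are arranged as
+ #   PSM: <fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>{middle}
+ #   SPM: <fim_suffix>{suffix}<fim_prefix>{prefix}<fim_middle>{middle}
+ # (see FIMTokenizer.create_fim_example below).
+
+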
+ class FIMTokenizer:
+     """Handle Fill-in-the-Middle tokenization."""
+
+     def __init__(self, config: CodeCompletionConfig):
+         self.config = config
+
+     def create_fim_example(
+         self,
+         code: str,
+         split_point: Optional[int] = None,
+         mode: str = "PSM",  # PSM or SPM
+     ) -> Tuple[str, str]:
+         """
+         Create a FIM training example from code.
+
+         Args:
+             code: Full code string
+             split_point: Where to split (random if None)
+             mode: PSM (Prefix-Suffix-Middle) or SPM (Suffix-Prefix-Middle)
+
+         Returns:
+             Tuple of (fim_input, target_middle)
+         """
+         if split_point is None:
+             # Random split point in the middle half of the file
+             split_point = random.randint(
+                 len(code) // 4,
+                 3 * len(code) // 4,
+             )
+
+         # Snap the split point to the end of the current line
+         while split_point < len(code) and code[split_point] != '\n':
+             split_point += 1
+
+         # Determine the middle span
+         middle_start = split_point
+         middle_end = min(
+             middle_start + random.randint(50, 500),
+             len(code),
+         )
+
+         # Snap the end of the middle span to the end of its line
+         while middle_end < len(code) and code[middle_end] != '\n':
+             middle_end += 1
+
+         prefix = code[:middle_start]
+         middle = code[middle_start:middle_end]
+         suffix = code[middle_end:]
+
+         cfg = self.config
+
+         if mode == "PSM":
+             # Prefix-Suffix-Middle
+             fim_input = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}"
+         else:
+             # Suffix-Prefix-Middle
+             fim_input = f"{cfg.fim_suffix_token}{suffix}{cfg.fim_prefix_token}{prefix}{cfg.fim_middle_token}"
+
+         return fim_input, middle
+
+     def format_completion_prompt(
+         self,
+         prefix: str,
+         suffix: str = "",
+         language: str = "python",
+     ) -> str:
+         """Format a completion prompt."""
+         cfg = self.config
+
+         if suffix:
+             # FIM mode
+             prompt = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}"
+         else:
+             # Standard completion
+             prompt = prefix
+
+         return prompt
+
+
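+ # Illustrative sketch (not part of the original commit): how a PSM example
+ # comes out of FIMTokenizer. The toy snippet and the helper name are made up.
+ def _example_fim_prompt() -> Tuple[str, str]:
+     tok = FIMTokenizer(CodeCompletionConfig())
+     code = "def add(a, b):\n    return a + b\n\nprint(add(1, 2))\n"
+     fim_input, middle = tok.create_fim_example(code, split_point=15, mode="PSM")
+     # fim_input == "<fim_prefix>" + prefix + "<fim_suffix>" + suffix + "<fim_middle>",
+     # and the model is trained to generate `middle` right after <fim_middle>.
+     return fim_input, middle
+
+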
+ class CodeProcessor:
+     """Process code for training and inference."""
+
+     # Language-specific patterns
+     LANGUAGE_PATTERNS = {
+         "python": {
+             "comment": r"#.*$",
+             "docstring": r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'',
+             "function": r"def\s+(\w+)\s*\(",
+             "class": r"class\s+(\w+)\s*[:\(]",
+         },
+         "javascript": {
+             "comment": r"//.*$|/\*[\s\S]*?\*/",
+             "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
+             "class": r"class\s+(\w+)",
+         },
+         "typescript": {
+             "comment": r"//.*$|/\*[\s\S]*?\*/",
+             "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
+             "class": r"class\s+(\w+)",
+             "interface": r"interface\s+(\w+)",
+         },
+         "rust": {
+             "comment": r"//.*$|/\*[\s\S]*?\*/",
+             "function": r"fn\s+(\w+)",
+             "struct": r"struct\s+(\w+)",
+             "impl": r"impl\s+(\w+)",
+         },
+     }
+
+     @classmethod
+     def detect_language(cls, code: str, filename: Optional[str] = None) -> str:
+         """Detect the programming language from code or filename."""
+         if filename:
+             ext_map = {
+                 ".py": "python",
+                 ".js": "javascript",
+                 ".ts": "typescript",
+                 ".tsx": "typescript",
+                 ".rs": "rust",
+                 ".go": "go",
+                 ".java": "java",
+                 ".cpp": "cpp",
+                 ".c": "c",
+             }
+             for ext, lang in ext_map.items():
+                 if filename.endswith(ext):
+                     return lang
+
+         # Heuristic detection
+         if "def " in code and "import " in code:
+             return "python"
+         if "function " in code or "const " in code:
+             return "javascript"
+         if "fn " in code and "let " in code:
+             return "rust"
+
+         return "python"  # Default
+
+     @classmethod
+     def extract_context(
+         cls,
+         code: str,
+         cursor_position: int,
+         context_lines: int = 50,
+     ) -> Tuple[str, str]:
+         """Extract prefix and suffix around the cursor position."""
+         lines = code.split('\n')
+
+         # Locate the line and column of the cursor; default to the end of
+         # the file if cursor_position is past the last character.
+         current_pos = 0
+         cursor_line = len(lines) - 1
+         cursor_col = len(lines[-1])
+         for i, line in enumerate(lines):
+             if current_pos + len(line) + 1 > cursor_position:
+                 cursor_line = i
+                 cursor_col = cursor_position - current_pos
+                 break
+             current_pos += len(line) + 1
+
+         # Window of surrounding lines
+         start_line = max(0, cursor_line - context_lines)
+         end_line = min(len(lines), cursor_line + context_lines)
+
+         # Keep the cursor line itself: text before the cursor belongs to
+         # the prefix, text after it to the suffix.
+         prefix_lines = lines[start_line:cursor_line] + [lines[cursor_line][:cursor_col]]
+         suffix_lines = [lines[cursor_line][cursor_col:]] + lines[cursor_line + 1:end_line]
+
+         prefix = '\n'.join(prefix_lines)
+         suffix = '\n'.join(suffix_lines)
+
+         return prefix, suffix
+
+
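+ # Illustrative sketch (not from the original commit): pairing extract_context
+ # with format_completion_prompt to build a FIM prompt around a cursor. The
+ # helper name and toy source are made up.
+ def _example_cursor_context() -> str:
+     source = "import os\n\ndef main():\n    pa\n\nmain()\n"
+     cursor = source.index("    pa") + len("    pa")  # cursor just after "pa"
+     prefix, suffix = CodeProcessor.extract_context(source, cursor)
+     return FIMTokenizer(CodeCompletionConfig()).format_completion_prompt(prefix, suffix)
+
+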
+ class FIMModule(nn.Module):
+     """
+     Fill-in-the-Middle module for code completion.
+     Enables intelligent middle-of-file completion.
+     """
+
+     def __init__(self, config: CodeCompletionConfig, hidden_size: int):
+         super().__init__()
+         self.config = config
+         self.hidden_size = hidden_size
+
+         # FIM position embeddings (0=prefix, 1=middle, 2=suffix)
+         self.fim_position_embed = nn.Embedding(3, hidden_size)
+
+         # Context combiner
+         self.context_combiner = nn.Sequential(
+             nn.Linear(hidden_size * 2, hidden_size),
+             nn.GELU(),
+             nn.Linear(hidden_size, hidden_size),
+         )
+
+         # Completion quality predictor
+         self.quality_predictor = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size // 4),
+             nn.GELU(),
+             nn.Linear(hidden_size // 4, 1),
+             nn.Sigmoid(),
+         )
+
+         # Tokenizer helpers
+         self.tokenizer = FIMTokenizer(config)
+         self.processor = CodeProcessor()
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         fim_positions: Optional[torch.Tensor] = None,
+         prefix_mask: Optional[torch.Tensor] = None,
+         suffix_mask: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Process hidden states with FIM awareness.
+
+         Args:
+             hidden_states: [batch, seq_len, hidden_size]
+             fim_positions: Position type for each token (0=prefix, 1=middle, 2=suffix)
+             prefix_mask: Mask for prefix tokens
+             suffix_mask: Mask for suffix tokens
+
+         Returns:
+             Enhanced hidden states and metrics
+         """
+         batch_size, seq_len, _ = hidden_states.shape
+
+         # Add FIM position embeddings
+         if fim_positions is not None:
+             pos_embed = self.fim_position_embed(fim_positions)
+             hidden_states = hidden_states + pos_embed
+
+         # Combine context from prefix and suffix
+         if prefix_mask is not None and suffix_mask is not None:
+             # Cast so the pooling and the bitwise ops below work for both
+             # bool and 0/1 float masks
+             prefix_f = prefix_mask.float()
+             suffix_f = suffix_mask.float()
+
+             # Mean-pool the prefix and suffix representations
+             prefix_repr = (hidden_states * prefix_f.unsqueeze(-1)).sum(1) / prefix_f.sum(1, keepdim=True).clamp(min=1)
+             suffix_repr = (hidden_states * suffix_f.unsqueeze(-1)).sum(1) / suffix_f.sum(1, keepdim=True).clamp(min=1)
+
+             # Combine
+             context = self.context_combiner(torch.cat([prefix_repr, suffix_repr], dim=-1))
+
+             # Add the combined context to the middle tokens only
+             middle_mask = ~(prefix_mask.bool() | suffix_mask.bool())
+             if middle_mask.any():
+                 context_expanded = context.unsqueeze(1).expand(-1, seq_len, -1)
+                 hidden_states = hidden_states + context_expanded * middle_mask.unsqueeze(-1).float()
+
+         # Quality prediction
+         quality = self.quality_predictor(hidden_states.mean(1))
+
+         metrics = {
+             "completion_quality": quality,
+         }
+
+         return hidden_states, metrics
+
+
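+ # Minimal shape-level smoke test for FIMModule (illustrative; the helper and
+ # its dummy tensors are not part of the original file).
+ def _example_fim_module() -> None:
+     module = FIMModule(CodeCompletionConfig(), hidden_size=64)
+     hidden = torch.randn(2, 10, 64)
+     prefix_mask = torch.zeros(2, 10, dtype=torch.bool)
+     prefix_mask[:, :4] = True   # first 4 tokens are prefix
+     suffix_mask = torch.zeros(2, 10, dtype=torch.bool)
+     suffix_mask[:, -3:] = True  # last 3 tokens are suffix
+     out, metrics = module(hidden, prefix_mask=prefix_mask, suffix_mask=suffix_mask)
+     assert out.shape == (2, 10, 64)
+     assert metrics["completion_quality"].shape == (2, 1)
+
+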
+ class VibeCoder:
+     """
+     High-level interface for "vibe coding" - intuitive code assistance.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         tokenizer,
+         config: Optional[CodeCompletionConfig] = None,
+         device: str = "cuda",
+     ):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.config = config or CodeCompletionConfig()
+         self.device = device
+
+         # Get the hidden size from the model config, falling back to 1024
+         if hasattr(model, 'config'):
+             hidden_size = model.config.hidden_size
+         else:
+             hidden_size = 1024
+
+         self.fim_module = FIMModule(self.config, hidden_size).to(device)
+         self.fim_tokenizer = FIMTokenizer(self.config)
+
+     def complete(
+         self,
+         prefix: str,
+         suffix: str = "",
+         max_tokens: int = 100,
+         temperature: float = 0.2,
+         stop_tokens: Optional[List[str]] = None,
+     ) -> str:
+         """
+         Complete code given a prefix and an optional suffix.
+
+         Args:
+             prefix: Code before the cursor
+             suffix: Code after the cursor (for FIM)
+             max_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             stop_tokens: Strings at which to truncate the generation
+
+         Returns:
+             Generated code completion
+         """
+         self.model.eval()
+
+         # Format prompt
+         prompt = self.fim_tokenizer.format_completion_prompt(prefix, suffix)
+
+         # Tokenize
+         input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+
+         # Generate
+         with torch.no_grad():
+             generated = self.model.generate(
+                 input_ids,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 do_sample=temperature > 0,
+                 top_p=0.95,
+             )
+
+         # Decode only the newly generated tokens
+         completion = self.tokenizer.decode(
+             generated[0][input_ids.shape[1]:],
+             skip_special_tokens=True,
+         )
+
+         # Truncate at the first stop token
+         if stop_tokens:
+             for stop in stop_tokens:
+                 if stop in completion:
+                     completion = completion[:completion.index(stop)]
+
+         return completion
+
+     def complete_function(
+         self,
+         signature: str,
+         context: str = "",
+         language: str = "python",
+     ) -> str:
+         """Complete a function given its signature."""
+         if language == "python":
+             prompt = f"{context}\n\n{signature}\n    "
+         else:
+             # Brace-delimited languages (JavaScript, TypeScript, etc.)
+             prompt = f"{context}\n\n{signature} {{\n    "
+
+         return self.complete(prompt, max_tokens=500)
+
+     def explain_code(self, code: str, language: str = "python") -> str:
+         """Generate an explanation for code."""
+         prompt = f"# Explain the following {language} code:\n```{language}\n{code}\n```\n\n# Explanation:\n"
+         return self.complete(prompt, max_tokens=300, temperature=0.3)
+
+     def refactor(
+         self,
+         code: str,
+         instruction: str = "Refactor this code to be cleaner and more efficient",
+         language: str = "python",
+     ) -> str:
+         """Refactor code based on an instruction."""
+         prompt = f"""# Original code:
+ ```{language}
+ {code}
+ ```
+
+ # Task: {instruction}
+
+ # Refactored code:
+ ```{language}
+ """
+         completion = self.complete(prompt, max_tokens=1000, temperature=0.2)
+
+         # Strip anything after the closing code fence
+         if "```" in completion:
+             completion = completion[:completion.index("```")]
+
+         return completion
+
+     def fix_bug(self, code: str, error: str = "", language: str = "python") -> str:
+         """Fix a bug in code."""
+         prompt = f"""# Buggy code:
+ ```{language}
+ {code}
+ ```
+
+ # Error: {error if error else "Unknown bug"}
+
+ # Fixed code:
+ ```{language}
+ """
+         completion = self.complete(prompt, max_tokens=1000, temperature=0.1)
+
+         if "```" in completion:
+             completion = completion[:completion.index("```")]
+
+         return completion
+
+
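+ # Usage sketch (illustrative; assumes a Hugging Face-style causal LM and
+ # tokenizer, which this file does not provide):
+ #
+ #     from transformers import AutoModelForCausalLM, AutoTokenizer
+ #     model = AutoModelForCausalLM.from_pretrained("<checkpoint>")
+ #     tokenizer = AutoTokenizer.from_pretrained("<checkpoint>")
+ #     coder = VibeCoder(model, tokenizer, device="cuda")
+ #     body = coder.complete(
+ #         prefix="def fib(n):\n    ",
+ #         suffix="\n    return a\n",
+ #         stop_tokens=["\ndef "],
+ #     )
+
+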
+ class CodeDataset(Dataset):
+     """Dataset for code training with FIM."""
+
+     def __init__(
+         self,
+         data_path: str,
+         tokenizer,
+         config: CodeCompletionConfig,
+         max_length: int = 2048,
+     ):
+         self.tokenizer = tokenizer
+         self.config = config
+         self.max_length = max_length
+         self.fim_tokenizer = FIMTokenizer(config)
+
+         self.examples = []
+         with open(data_path, 'r', encoding='utf-8') as f:
+             for line in f:
+                 if line.strip():
+                     self.examples.append(json.loads(line))
+
+     def __len__(self) -> int:
+         return len(self.examples)
+
+     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+         example = self.examples[idx]
+         code = example.get("code", example.get("content", ""))
+         language = example.get("language", "python")
+
+         # Decide FIM vs. standard LM
+         use_fim = random.random() < self.config.fim_rate
+
+         if use_fim and len(code) > 100:
+             # Create a FIM example
+             mode = "SPM" if random.random() < self.config.fim_spm_rate else "PSM"
+             fim_input, target = self.fim_tokenizer.create_fim_example(code, mode=mode)
+             text = fim_input + target
+         else:
+             # Standard LM
+             text = code
+
+         # Tokenize
+         encodings = self.tokenizer(
+             text,
+             max_length=self.max_length,
+             truncation=True,
+             padding="max_length",
+             return_tensors="pt",
+         )
+
+         input_ids = encodings["input_ids"].squeeze(0)
+         attention_mask = encodings["attention_mask"].squeeze(0)
+
+         # Ignore padding positions in the loss
+         labels = input_ids.clone()
+         labels[attention_mask == 0] = -100
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "labels": labels,
+         }
+
+
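+ # Wiring sketch (illustrative; "code.jsonl" and `tokenizer` are assumptions,
+ # not shipped with this file):
+ #
+ #     dataset = CodeDataset("code.jsonl", tokenizer, CodeCompletionConfig())
+ #     loader = DataLoader(dataset, batch_size=8, shuffle=True)
+ #     batch = next(iter(loader))  # input_ids / attention_mask / labels
+
+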
+ def prepare_code_dataset(
+     raw_data_path: str,
+     output_path: str,
+     languages: Optional[List[str]] = None,
+ ) -> int:
+     """Prepare a code dataset for training."""
+     languages = languages or ["python", "javascript", "typescript", "rust"]
+     processed = 0
+
+     with open(raw_data_path, 'r', encoding='utf-8') as fin, \
+          open(output_path, 'w', encoding='utf-8') as fout:
+
+         for line in fin:
+             if not line.strip():
+                 continue
+
+             data = json.loads(line)
+
+             # Extract code and language
+             code = data.get("code", data.get("content", ""))
+             language = data.get("language", "")
+
+             # Filter by language
+             if languages and language not in languages:
+                 continue
+
+             # Filter by size (basic quality heuristic)
+             if len(code) < 50 or len(code) > 100000:
+                 continue
+
+             processed_example = {
+                 "code": code,
+                 "language": language,
+             }
+
+             fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n")
+             processed += 1
+
+     return processed
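+
+
+ # End-to-end sketch (illustrative; the file names are assumptions):
+ #
+ #     kept = prepare_code_dataset("raw_dump.jsonl", "code.jsonl",
+ #                                 languages=["python", "rust"])
+ #     print(f"Kept {kept} examples")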