# File size: 5,195 Bytes
# b2b9f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os, logging, torch, transformers
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from transformers import TrainingArguments, Trainer, TrainerCallback, set_seed
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from model.architecture import CodeLLM, CodeLLMConfig
from model.tokenizer import get_gpt2_tokenizer_for_code, load_tokenizer
from data.dataset import TheStackStreamDataset, CodeCollator

# Route INFO-and-above records through the root logger's default handler and
# expose a module-scoped logger for this script.
logging.basicConfig(level="INFO")
logger: logging.Logger = logging.getLogger(__name__)

@dataclass
class TrainConfig:
    """All knobs consumed by ``train()``.

    Fields are grouped by concern below. Their declaration order is also the
    positional order of the generated ``__init__``, so it must not change.
    """

    # ---- model & tokenizer ------------------------------------------------
    model_config: CodeLLMConfig = field(default_factory=CodeLLMConfig)
    # Optional path to a saved tokenizer. train() falls back to the GPT-2 based
    # code tokenizer when this is None or the path does not exist.
    tokenizer_path: Optional[str] = None

    # ---- data -------------------------------------------------------------
    languages: list = field(
        default_factory=lambda: ["python", "javascript", "typescript", "rust"]
    )
    max_length: int = 2048  # max sequence length handed to dataset and collator
    fim_rate: float = 0.5   # forwarded to the dataset as its fill-in-the-middle rate

    # ---- optimisation -----------------------------------------------------
    output_dir: str = "./checkpoints"
    num_train_steps: int = 100_000
    per_device_batch_size: int = 4
    # 4 sequences x 8 accumulation steps = 32 sequences per device per optimizer step.
    gradient_accumulation_steps: int = 8
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    max_grad_norm: float = 1.0
    warmup_steps: int = 2000
    lr_scheduler_type: str = "cosine"

    # ---- precision & memory -----------------------------------------------
    bf16: bool = True
    fp16: bool = False
    gradient_checkpointing: bool = True

    # ---- runtime, logging & checkpointing ---------------------------------
    dataloader_num_workers: int = 4
    logging_steps: int = 50
    save_steps: int = 1000

    # ---- Hugging Face Hub -------------------------------------------------
    # NOTE(review): pushing is on by default and the repo id below is a
    # personal namespace; replace it with your own HF username/repo.
    push_to_hub: bool = True
    hub_model_id: str = "devoppro/codellm-125m"

    # ---- reproducibility --------------------------------------------------
    seed: int = 42

class CodeLLMForTrainer(torch.nn.Module):
    """Adapt a ``CodeLLM`` to the interface ``transformers.Trainer`` expects.

    The wrapped model is called with ``input_ids``/``labels``/``attention_mask``.
    Its ``"loss"`` and ``"logits"`` entries are repackaged as a
    ``CausalLMOutputWithPast``. A minimal ``config`` object is exposed because
    Trainer and its callbacks (e.g. the TensorBoard callback) read
    ``model.config``.
    """

    # Recent transformers releases assume that a ``forward`` accepting
    # ``**kwargs`` consumes ``num_items_in_batch`` and normalises the loss
    # itself. In that case Trainer skips dividing the loss by
    # ``gradient_accumulation_steps``. The wrapped model computes its own loss
    # and never sees ``num_items_in_batch``, so opt out explicitly. Older
    # releases ignore this attribute.
    accepts_loss_kwargs = False

    class _TrainerConfig:
        """Stand-in for ``PretrainedConfig`` with the attributes/methods Trainer touches."""

        is_encoder_decoder = False
        model_type = "codellm"

        def to_dict(self) -> dict:
            """Return the config as a plain dict."""
            return {
                "model_type": self.model_type,
                "is_encoder_decoder": self.is_encoder_decoder,
            }

        def to_json_string(self, use_diff: bool = True) -> str:
            """Serialise to JSON; ``TensorBoardCallback.on_train_begin`` calls this."""
            import json

            return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def __init__(self, model):
        super().__init__()
        self.model = model
        # One stable config object, so attribute writes by callers persist.
        self._trainer_config = self._TrainerConfig()

    def forward(self, input_ids=None, labels=None, attention_mask=None, **kwargs):
        """Run the wrapped model and return a HF causal-LM output.

        Extra keyword arguments (e.g. additional batch keys) are accepted and
        ignored.
        """
        out = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        return transformers.modeling_outputs.CausalLMOutputWithPast(
            loss=out["loss"], logits=out["logits"],
        )

    def gradient_checkpointing_enable(self, **kwargs):
        """Set ``use_checkpoint = True`` on every block in ``model.transformer.h``.

        ``**kwargs`` (e.g. ``gradient_checkpointing_kwargs`` from Trainer) are
        accepted for API compatibility and ignored.
        """
        for block in self.model.transformer.h:
            block.use_checkpoint = True

    @property
    def config(self):
        """Minimal config object exposing ``is_encoder_decoder``, ``model_type`` and JSON export."""
        return self._trainer_config

class GenerateSampleCallback(TrainerCallback):
    """Print model completions for a fixed list of prompts during training.

    Sampling runs on evaluation and on every checkpoint save. Without an
    evaluation schedule ``on_evaluate`` never fires, so ``on_save`` is what
    guarantees samples appear. Only the main process generates and prints.
    The model's train/eval mode is restored afterwards, so dropout keeps
    working for the rest of training.
    """

    def __init__(self, model, tokenizer, prompts, max_new_tokens: int = 128, temperature: float = 0.8):
        self.model = model
        self.tokenizer = tokenizer
        self.prompts = prompts
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def _print_samples(self, state) -> None:
        """Generate and print one completion per prompt (main process only)."""
        if not state.is_world_process_zero:
            return
        was_training = self.model.training
        self.model.eval()
        device = next(self.model.parameters()).device
        try:
            print("\n" + "=" * 60)
            with torch.no_grad():
                for prompt in self.prompts:
                    ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
                    out = self.model.generate(
                        ids, max_new_tokens=self.max_new_tokens, temperature=self.temperature,
                    )
                    # Decode only the tokens after the prompt. The previous approach
                    # sliced the decoded text by len(prompt), which misaligns when
                    # skip_special_tokens drops prompt tokens such as "<|python|>".
                    # Like that approach, this assumes generate() returns
                    # prompt + continuation.
                    completion = self.tokenizer.decode(
                        out[0][ids.shape[-1]:], skip_special_tokens=True,
                    )
                    print(f"\n[PROMPT] {prompt}\n[OUTPUT] {completion}")
            print("=" * 60 + "\n")
        finally:
            # Put the model back into training mode if it was there before.
            if was_training:
                self.model.train()

    def on_evaluate(self, args, state, control, **kwargs):
        """Print samples whenever the Trainer evaluates."""
        self._print_samples(state)

    def on_save(self, args, state, control, **kwargs):
        """Print samples after every checkpoint save."""
        self._print_samples(state)

def _resolve_tokenizer(cfg: TrainConfig):
    """Load the tokenizer at ``cfg.tokenizer_path``, else fall back to the GPT-2 code tokenizer."""
    if cfg.tokenizer_path:
        if Path(cfg.tokenizer_path).exists():
            return load_tokenizer(cfg.tokenizer_path)
        # A path was configured but is missing. Keep the fallback, but make it
        # visible instead of silently training with a different vocabulary.
        logger.warning(
            "Tokenizer path %s does not exist; falling back to the GPT-2 code tokenizer.",
            cfg.tokenizer_path,
        )
    return get_gpt2_tokenizer_for_code()


def _build_training_arguments(cfg: TrainConfig) -> TrainingArguments:
    """Translate a ``TrainConfig`` into Hugging Face ``TrainingArguments``."""
    return TrainingArguments(
        output_dir=cfg.output_dir,
        max_steps=cfg.num_train_steps,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        max_grad_norm=cfg.max_grad_norm,
        warmup_steps=cfg.warmup_steps,
        lr_scheduler_type=cfg.lr_scheduler_type,
        bf16=cfg.bf16, fp16=cfg.fp16,
        # NOTE(review): with a streaming IterableDataset, every worker iterates
        # its own copy unless TheStackStreamDataset shards by worker id.
        # Confirm it does, or samples will be duplicated num_workers times.
        dataloader_num_workers=cfg.dataloader_num_workers,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        save_total_limit=3,
        push_to_hub=cfg.push_to_hub,
        hub_model_id=cfg.hub_model_id if cfg.push_to_hub else None,
        report_to=["tensorboard"],
        # The collator's keys go straight to the wrapper's forward.
        remove_unused_columns=False,
        prediction_loss_only=True,
        optim="adamw_torch_fused",
    )


def _save_final_artifacts(trainer, model_core, tokenizer, output_dir: str) -> None:
    """Write the raw model ``state_dict`` and the tokenizer to ``<output_dir>/final``.

    Only the main process writes, so multi-process launches do not race on
    the same files.
    """
    if not trainer.is_world_process_zero():
        return
    final_dir = Path(output_dir) / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model_core.state_dict(), final_dir / "pytorch_model.bin")
    tokenizer.save_pretrained(final_dir)
    logger.info("Saved final model and tokenizer to %s", final_dir)


def train(cfg: TrainConfig) -> None:
    """Train a ``CodeLLM`` from scratch on the streaming dataset described by ``cfg``.

    Steps: seed everything, resolve the tokenizer, size the model vocabulary
    to it, build dataset/collator/Trainer, train, and save the final weights
    and tokenizer under ``<output_dir>/final``. The result is pushed to the
    Hub when ``cfg.push_to_hub`` is set.
    """
    set_seed(cfg.seed)
    tokenizer = _resolve_tokenizer(cfg)

    # Match the model's vocabulary size to the tokenizer (including added tokens).
    cfg.model_config.vocab_size = len(tokenizer)
    model_core = CodeLLM(cfg.model_config)
    model = CodeLLMForTrainer(model_core)
    if cfg.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    train_dataset = TheStackStreamDataset(
        tokenizer=tokenizer, max_length=cfg.max_length,
        languages=cfg.languages, fim_rate=cfg.fim_rate,
    )
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    collator = CodeCollator(pad_token_id=pad_id, max_length=cfg.max_length)

    sample_prompts = [
        "<|python|>def fibonacci(n):",
        "<|javascript|>async function fetchData(url) {",
    ]
    trainer = Trainer(
        model=model, args=_build_training_arguments(cfg),
        train_dataset=train_dataset, data_collator=collator,
        callbacks=[GenerateSampleCallback(model_core, tokenizer, sample_prompts)],
    )
    trainer.train()

    _save_final_artifacts(trainer, model_core, tokenizer, cfg.output_dir)
    if cfg.push_to_hub:
        # Trainer.push_to_hub is itself a no-op on non-main processes.
        trainer.push_to_hub()

if __name__ == "__main__":
    # Script entry point: train with every TrainConfig default.
    default_cfg = TrainConfig()
    train(default_cfg)