| """ |
| Production-Grade Indonesian Conversational Language Model |
| Trained from scratch with Chain-of-Thought reasoning capability |
| |
| Architecture: decoder-only transformer with grouped-query attention (GQA; MQA when num_key_value_heads=1), RoPE, SwiGLU, RMSNorm |
| Target: 15M-30M parameters, optimized for Google Colab Free tier |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import AutoTokenizer |
| import json |
| import math |
| import random |
| import numpy as np |
| from typing import Optional, Tuple, List, Dict |
| from dataclasses import dataclass |
| import warnings |
| import argparse |
| import os |
| warnings.filterwarnings('ignore') |
|
|
|
|
| |
| |
| |
|
|
| @dataclass |
| class ModelConfig: |
| """Model architecture configuration""" |
| vocab_size: int = 30000 |
| hidden_size: int = 384 |
| num_layers: int = 12 |
| num_attention_heads: int = 6 |
| num_key_value_heads: int = 2 |
| intermediate_size: int = 1024 |
| max_position_embeddings: int = 2048 |
| rms_norm_eps: float = 1e-6 |
| rope_theta: float = 10000.0 |
| attention_dropout: float = 0.1 |
| residual_dropout: float = 0.1 |
| initializer_range: float = 0.02 |
| use_cache: bool = False |
| pad_token_id: int = 0 |
| bos_token_id: int = 1 |
| eos_token_id: int = 2 |
| tie_word_embeddings: bool = True |
|
|
| def __post_init__(self): |
| assert self.hidden_size % self.num_attention_heads == 0 |
| assert self.num_attention_heads % self.num_key_value_heads == 0 |
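|
|
| def _demo_model_config(): |
|     """Illustrative sketch, not called by the pipeline: shows how the default |
|     ModelConfig fields combine under GQA. The asserts in __post_init__ |
|     guarantee both divisions below are exact.""" |
|     cfg = ModelConfig() |
|     head_dim = cfg.hidden_size // cfg.num_attention_heads           # 384 // 6 = 64 |
|     kv_groups = cfg.num_attention_heads // cfg.num_key_value_heads  # 6 // 2 = 3 |
|     print(f"head_dim={head_dim}, query heads per KV head={kv_groups}") |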
|
|
|
|
| @dataclass |
| class TrainingConfig: |
| """Training hyperparameters""" |
| dataset_path: str = "indonesian_cot_dataset.jsonl" |
| output_dir: str = "./indonesian_llm_checkpoints" |
|
|
| |
| num_epochs: int = 3 |
| batch_size: int = 4 |
| gradient_accumulation_steps: int = 12 |
| max_seq_length: int = 1024 |
|
|
| |
| learning_rate: float = 3e-4 |
| weight_decay: float = 0.01 |
| adam_beta1: float = 0.9 |
| adam_beta2: float = 0.95 |
| adam_epsilon: float = 1e-8 |
| max_grad_norm: float = 1.0 |
|
|
| |
| warmup_steps: int = 100 |
| lr_scheduler_type: str = "cosine" |
|
|
| |
| dropout: float = 0.1 |
|
|
| |
| use_fp16: bool = True |
|
|
| |
| seed: int = 42 |
|
|
| |
| logging_steps: int = 10 |
| eval_steps: int = 100 |
| save_steps: int = 500 |
|
|
| |
| curriculum_stages: Optional[List[int]] = None |
|
|
| |
| skip_curriculum_stages: int = 2 |
| |
| plateau_patience: int = 3 |
| |
| plateau_factor: float = 0.5 |
| |
| plateau_min_delta: float = 0.02 |
|
|
| |
| ewc_lambda: float = 0.0 |
| ewc_samples: int = 2000 |
|
|
| def __post_init__(self): |
| if self.curriculum_stages is None: |
| self.curriculum_stages = [256, 512, 1024] |
|
|
|
|
| |
| |
| |
|
|
| class RotaryEmbedding(nn.Module): |
| """Rotary Positional Embeddings""" |
|
|
| def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0): |
| super().__init__() |
| self.dim = dim |
| self.max_position_embeddings = max_position_embeddings |
| self.base = base |
|
|
| inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
| self._set_cos_sin_cache(max_position_embeddings) |
|
|
| def _set_cos_sin_cache(self, seq_len: int): |
| self.max_seq_len_cached = seq_len |
| t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq) |
| freqs = torch.outer(t, self.inv_freq) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| self.register_buffer("cos_cached", emb.cos(), persistent=False) |
| self.register_buffer("sin_cached", emb.sin(), persistent=False) |
|
|
| def forward(self, x: torch.Tensor, seq_len: int): |
| if seq_len > self.max_seq_len_cached: |
| self._set_cos_sin_cache(seq_len) |
|
|
| return ( |
| self.cos_cached[:seq_len].to(dtype=x.dtype), |
| self.sin_cached[:seq_len].to(dtype=x.dtype) |
| ) |
|
|
|
|
| def rotate_half(x: torch.Tensor) -> torch.Tensor: |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2:] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def apply_rotary_pos_emb(q, k, cos, sin): |
| q_embed = (q * cos) + (rotate_half(q) * sin) |
| k_embed = (k * cos) + (rotate_half(k) * sin) |
| return q_embed, k_embed |
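|
|
| def _demo_rope_relative_property(): |
|     """Sketch, illustrative only: RoPE encodes relative position. The rotated |
|     q.k dot product depends only on the offset i - j, not on the absolute |
|     positions; this is checked numerically below for a toy head_dim.""" |
|     torch.manual_seed(0) |
|     rope = RotaryEmbedding(dim=8, max_position_embeddings=32) |
|     q, k = torch.randn(8), torch.randn(8) |
|     cos, sin = rope(q, seq_len=32) |
|
|     def score(i: int, j: int) -> float: |
|         qi = q * cos[i] + rotate_half(q) * sin[i] |
|         kj = k * cos[j] + rotate_half(k) * sin[j] |
|         return torch.dot(qi, kj).item() |
|
|     # Same offset (2), different absolute positions -> same attention score. |
|     assert abs(score(5, 3) - score(25, 23)) < 1e-4 |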
|
|
|
|
| |
| |
| |
|
|
| class RMSNorm(nn.Module): |
| def __init__(self, hidden_size: int, eps: float = 1e-6): |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
|
|
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| input_dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(torch.float32) |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| return self.weight * hidden_states.to(input_dtype) |
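|
|
| def _demo_rmsnorm(): |
|     """Sketch: unlike LayerNorm, RMSNorm skips mean-centering. It only divides |
|     by the root-mean-square of the features and applies the learned gain |
|     (ones at init), so fresh outputs have unit RMS at every position.""" |
|     x = torch.randn(2, 4, 16) + 3.0   # deliberately non-zero mean |
|     y = RMSNorm(16)(x) |
|     rms = y.pow(2).mean(-1).sqrt() |
|     assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3) |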
|
|
|
|
| |
| |
| |
|
|
| class GroupedQueryAttention(nn.Module): |
| def __init__(self, config: ModelConfig): |
| super().__init__() |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.num_heads = config.num_attention_heads |
| self.num_key_value_heads = config.num_key_value_heads |
| self.head_dim = self.hidden_size // self.num_heads |
| self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
| self.max_position_embeddings = config.max_position_embeddings |
| self.rope_theta = config.rope_theta |
|
|
| assert self.hidden_size % self.num_heads == 0 |
|
|
| self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
| self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
| self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
| self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
| self.rotary_emb = RotaryEmbedding( |
| self.head_dim, |
| max_position_embeddings=self.max_position_embeddings, |
| base=self.rope_theta |
| ) |
| self.attention_dropout = nn.Dropout(config.attention_dropout) |
|
|
| def forward(self, hidden_states, attention_mask=None): |
| bsz, q_len, _ = hidden_states.size() |
|
|
| query_states = self.q_proj(hidden_states) |
| key_states = self.k_proj(hidden_states) |
| value_states = self.v_proj(hidden_states) |
|
|
| query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
| cos, sin = self.rotary_emb(value_states, seq_len=q_len) |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
| if self.num_key_value_groups > 1: |
| key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) |
| value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
| attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
| if attention_mask is not None: |
| attn_weights = attn_weights + attention_mask |
|
|
| attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
| attn_weights = self.attention_dropout(attn_weights) |
|
|
| attn_output = torch.matmul(attn_weights, value_states) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
| attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
| attn_output = self.o_proj(attn_output) |
|
|
| return attn_output |
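|
|
| def _demo_gqa_kv_savings(): |
|     """Sketch: with the default 6 query heads sharing 2 KV heads, k_proj and |
|     v_proj each hold num_key_value_heads / num_attention_heads = 1/3 of the |
|     parameters that full multi-head k/v projections of the same width use.""" |
|     cfg = ModelConfig() |
|     attn = GroupedQueryAttention(cfg) |
|     mha_kv = cfg.hidden_size * cfg.hidden_size   # one MHA projection matrix |
|     gqa_kv = attn.k_proj.weight.numel()          # one GQA projection matrix |
|     print(f"k_proj parameters: GQA {gqa_kv:,} vs MHA {mha_kv:,}") |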
|
|
|
|
| |
| |
| |
|
|
| class SwiGLUMLP(nn.Module): |
| def __init__(self, config: ModelConfig): |
| super().__init__() |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
| self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
| def forward(self, x): |
| return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) |
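|
|
| # SwiGLU recap (for reference): the forward above computes |
| #   W_down( SiLU(W_gate x) * (W_up x) ) |
| # i.e. the SiLU-gated linear unit used by LLaMA-style models. It needs three |
| # hidden x intermediate matrices instead of the classic MLP's two, which is |
| # one reason intermediate_size here is kept below the usual 4 x hidden_size. |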
|
|
|
|
| |
| |
| |
|
|
| class DecoderLayer(nn.Module): |
| def __init__(self, config: ModelConfig, layer_idx: int): |
| super().__init__() |
| self.layer_idx = layer_idx |
| self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.self_attn = GroupedQueryAttention(config) |
| self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.mlp = SwiGLUMLP(config) |
| self.residual_dropout = nn.Dropout(config.residual_dropout) |
|
|
| def forward(self, hidden_states, attention_mask=None): |
| residual = hidden_states |
| hidden_states = self.input_layernorm(hidden_states) |
| hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask) |
| hidden_states = self.residual_dropout(hidden_states) |
| hidden_states = residual + hidden_states |
|
|
| residual = hidden_states |
| hidden_states = self.post_attention_layernorm(hidden_states) |
| hidden_states = self.mlp(hidden_states) |
| hidden_states = self.residual_dropout(hidden_states) |
| hidden_states = residual + hidden_states |
|
|
| return hidden_states |
|
|
|
|
| |
| |
| |
|
|
| class IndonesianLLM(nn.Module): |
| def __init__(self, config: ModelConfig): |
| super().__init__() |
| self.config = config |
| self.padding_idx = config.pad_token_id |
| self.vocab_size = config.vocab_size |
|
|
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) |
| self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)]) |
| self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
| if config.tie_word_embeddings: |
| self.lm_head = None |
| else: |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
| self.apply(self._init_weights) |
|
|
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| if isinstance(module, nn.Linear): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
|
|
| def get_input_embeddings(self): |
| return self.embed_tokens |
|
|
| def _prepare_attention_mask(self, attention_mask, input_shape, dtype): |
| batch_size, seq_length = input_shape |
| causal_mask = torch.triu( |
| torch.ones((seq_length, seq_length), dtype=torch.bool, device=attention_mask.device), |
| diagonal=1 |
| ) |
| causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_length, seq_length) |
|
|
| if attention_mask is not None: |
| expanded_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length) |
| expanded_mask = expanded_mask.bool() |
| causal_mask = causal_mask | ~expanded_mask |
|
|
| causal_mask = torch.where(causal_mask, torch.finfo(dtype).min, 0.0) |
| return causal_mask |
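|
|     # Shape sketch (illustrative): for seq_length = 3 with the last position |
|     # padded, the additive bias built above is |
|     #   [[0, m, m], |
|     #    [0, 0, m], |
|     #    [0, 0, m]] |
|     # with m = torch.finfo(dtype).min; the upper triangle enforces causality |
|     # and the last column blocks attention to the padded token from every row. |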
|
|
| def forward(self, input_ids, attention_mask=None, labels=None): |
| batch_size, seq_length = input_ids.shape |
|
|
| if attention_mask is None: |
| attention_mask = torch.ones_like(input_ids) |
|
|
| hidden_states = self.embed_tokens(input_ids) |
| attention_mask = self._prepare_attention_mask( |
| attention_mask, (batch_size, seq_length), hidden_states.dtype |
| ) |
|
|
| for decoder_layer in self.layers: |
| hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask) |
|
|
| hidden_states = self.norm(hidden_states) |
|
|
| if self.lm_head is not None: |
| logits = self.lm_head(hidden_states) |
| else: |
| logits = F.linear(hidden_states, self.embed_tokens.weight) |
|
|
| loss = None |
| if labels is not None: |
| shift_logits = logits[..., :-1, :].contiguous() |
| shift_labels = labels[..., 1:].contiguous() |
| |
| |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1)) |
|
|
| return {"loss": loss, "logits": logits} |
|
|
| def count_parameters(self): |
| return sum(p.numel() for p in self.parameters() if p.requires_grad) |
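|
|
| def _demo_forward_shapes(): |
|     """Sketch: a tiny CPU-sized forward pass to sanity-check shapes and the |
|     shifted-label loss. All sizes below are illustrative, not the training |
|     configuration.""" |
|     cfg = ModelConfig(vocab_size=100, hidden_size=32, num_layers=2, |
|                       num_attention_heads=4, num_key_value_heads=2, |
|                       intermediate_size=64, max_position_embeddings=64) |
|     model = IndonesianLLM(cfg) |
|     ids = torch.randint(1, 100, (2, 10)) |
|     out = model(input_ids=ids, labels=ids) |
|     assert out['logits'].shape == (2, 10, 100)   # (batch, seq, vocab) |
|     assert out['loss'].requires_grad |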
|
|
|
|
| |
| |
| |
|
|
| class IndonesianCoTDataset(Dataset): |
|
|
| def __init__( |
| self, |
| file_path: str, |
| tokenizer, |
| max_length: int = 1024, |
| cot_token: str = "<cot>", |
| end_cot_token: str = "</cot>", |
| use_cot: bool = True, |
| cot_ratio: float = 0.7 |
| ): |
| self.tokenizer = tokenizer |
| self.max_length = max_length |
| self.cot_token = cot_token |
| self.end_cot_token = end_cot_token |
| self.use_cot = use_cot |
| self.cot_ratio = cot_ratio |
| self.samples = [] |
| self.skipped_count = 0 |
|
|
| self._load_data(file_path) |
|
|
| def _load_data(self, file_path: str): |
| print(f"Loading dataset from {file_path}...") |
|
|
| with open(file_path, 'r', encoding='utf-8') as f: |
| for line_num, line in enumerate(f, 1): |
| try: |
| if not line.strip(): |
| continue |
| data = json.loads(line) |
|
|
| if not all(key in data for key in ['input', 'cot', 'output']): |
| self.skipped_count += 1 |
| print(f"Warning: Line {line_num} missing required fields, skipping...") |
| continue |
|
|
| if not all(isinstance(data[key], str) for key in ['input', 'cot', 'output']): |
| self.skipped_count += 1 |
| print(f"Warning: Line {line_num} has invalid data types, skipping...") |
| continue |
|
|
| if not all(data[key].strip() for key in ['input', 'cot', 'output']): |
| self.skipped_count += 1 |
| print(f"Warning: Line {line_num} has empty fields, skipping...") |
| continue |
|
|
| self.samples.append(data) |
|
|
| except json.JSONDecodeError as e: |
| self.skipped_count += 1 |
| print(f"Warning: Line {line_num} is not valid JSON ({e}), skipping...") |
| continue |
| except Exception as e: |
| self.skipped_count += 1 |
| print(f"Warning: Line {line_num} caused error ({e}), skipping...") |
| continue |
|
|
| print(f"Loaded {len(self.samples)} valid samples") |
| print(f"Skipped {self.skipped_count} malformed rows") |
|
|
| if len(self.samples) == 0: |
| raise ValueError("No valid samples loaded from dataset!") |
|
|
| def __len__(self): |
| return len(self.samples) |
|
|
| def __getitem__(self, idx): |
| sample = self.samples[idx] |
|
|
| |
| if self.use_cot: |
| if random.random() < self.cot_ratio: |
| prompt = f"{sample['input']} <cot>" |
| completion = f" {sample['cot']} {self.end_cot_token} {sample['output']}" |
| else: |
| prompt = f"{sample['input']}" |
| completion = f" {sample['output']}" |
| else: |
| prompt = f"{sample['input']}" |
| completion = f" {sample['output']}" |
|
|
| |
| # Length of the prompt prefix inside the full encoding, used to mask prompt |
| # tokens out of the loss. Encode without special tokens so the trailing [SEP] |
| # a BERT-style tokenizer (the default indolem tokenizer) appends is not |
| # counted, which would shift the mask onto the first completion token; the |
| # +1 accounts for the leading [CLS]. |
| prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) |
| prompt_len = len(prompt_ids) + 1 |
|
|
| |
| full_ids = self.tokenizer.encode( |
| prompt + completion, |
| max_length=self.max_length, |
| truncation=True, |
| padding=False, |
| return_tensors=None, |
| add_special_tokens=True, |
| ) |
|
|
| |
| |
| labels = [-100] * prompt_len + full_ids[prompt_len:] |
| |
| labels = labels[:len(full_ids)] |
|
|
| return { |
| 'input_ids': torch.tensor(full_ids, dtype=torch.long), |
| 'labels': torch.tensor(labels, dtype=torch.long), |
| 'length': len(full_ids), |
| 'cot_length': len(self.tokenizer.encode( |
| sample['cot'], add_special_tokens=False |
| )) if self.use_cot else 0, |
| } |
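|
|
| def _demo_dataset_roundtrip(path="_demo_cot.jsonl"): |
|     """Sketch of the JSONL schema _load_data expects: one object per line |
|     with non-empty string fields 'input', 'cot', and 'output'. The filename |
|     and example text are illustrative; downloading the tokenizer makes this |
|     a manual smoke test, not part of the pipeline.""" |
|     row = {"input": "Berapa hasil dari 2+3?", |
|            "cot": "2 ditambah 3 sama dengan 5.", |
|            "output": "Hasilnya adalah 5."} |
|     with open(path, "w", encoding="utf-8") as f: |
|         f.write(json.dumps(row, ensure_ascii=False) + "\n") |
|     tok = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased") |
|     tok.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]}) |
|     item = IndonesianCoTDataset(path, tok, max_length=128)[0] |
|     # Prompt tokens are masked to -100 so the loss covers only the completion. |
|     assert (item['labels'][:3] == -100).all() |
|     os.remove(path) |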
|
|
|
|
| def collate_fn_with_packing(batch, pad_token_id=0): |
|     """Right-pad a batch to its longest sequence. |
|
|     Despite the name, no cross-sequence packing happens here: the sort below |
|     only orders sequences within this one batch, and every sequence is padded |
|     to the batch maximum. Labels are padded with -100 so the loss ignores |
|     padding; attention_mask zeros mark the padded positions. |
|     """ |
| batch = sorted(batch, key=lambda x: x['length'], reverse=True) |
| max_length = max(item['length'] for item in batch) |
|
|
| input_ids_batch = [] |
| attention_mask_batch = [] |
| labels_batch = [] |
|
|
| for item in batch: |
| input_ids = item['input_ids'] |
| labels = item['labels'] |
| length = item['length'] |
| pad_len = max_length - length |
|
|
| input_ids_padded = F.pad(input_ids, (0, pad_len), value=pad_token_id) |
| labels_padded = F.pad(labels, (0, pad_len), value=-100) |
| attention_mask = torch.cat([ |
| torch.ones(length, dtype=torch.long), |
| torch.zeros(pad_len, dtype=torch.long) |
| ]) |
|
|
| input_ids_batch.append(input_ids_padded) |
| attention_mask_batch.append(attention_mask) |
| labels_batch.append(labels_padded) |
|
|
| return { |
| 'input_ids': torch.stack(input_ids_batch), |
| 'attention_mask': torch.stack(attention_mask_batch), |
| 'labels': torch.stack(labels_batch), |
| } |
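|
|
| def _demo_collate(): |
|     """Sketch: the collator right-pads input_ids with pad_token_id, pads |
|     labels with -100 (ignored by CrossEntropyLoss), and zeros the attention |
|     mask over the padding.""" |
|     batch = [ |
|         {'input_ids': torch.tensor([5, 6, 7]), 'labels': torch.tensor([-100, 6, 7]), 'length': 3}, |
|         {'input_ids': torch.tensor([8]), 'labels': torch.tensor([8]), 'length': 1}, |
|     ] |
|     out = collate_fn_with_packing(batch, pad_token_id=0) |
|     assert out['input_ids'][1].tolist() == [8, 0, 0] |
|     assert out['labels'][1].tolist() == [8, -100, -100] |
|     assert out['attention_mask'][1].tolist() == [1, 0, 0] |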
|
|
|
|
| |
| |
| |
|
|
| def create_curriculum_datasets(dataset, stages=[256, 512, 1024], use_simple=False, skip_stages=0): |
| """ |
| Build per-stage datasets. |
| |
| FIX (simple curriculum): skipped stages now use `continue` so they are |
| never built — no tokenizer.encode() calls wasted on stages that get thrown |
| away. The slice at the bottom is moved INSIDE the else-branch only, so |
| the simple-curriculum list is never accidentally emptied. |
| """ |
| datasets = [] |
|
|
| if use_simple: |
| for i, max_len in enumerate(stages): |
| |
| if i < skip_stages: |
| print(f"[SKIP] Curriculum stage {max_len}: skipped (not built)") |
| continue |
|
|
| filtered_samples = [ |
| s for s in dataset.samples |
| if len(dataset.tokenizer.encode( |
| f"{s['input']} {dataset.cot_token} {s['cot']} {dataset.end_cot_token} {s['output']}" |
| )) <= max_len |
| ] |
| stage_dataset = _build_stage_dataset(dataset, filtered_samples, max_len, dataset.cot_ratio) |
| datasets.append(stage_dataset) |
| print(f"Curriculum stage {max_len}: {len(filtered_samples)} samples") |
|
|
| |
|
|
| else: |
| print("\n" + "="*80) |
| print("3-STAGE REASONING CURRICULUM") |
| if skip_stages > 0: |
| print(f" (Skipping first {skip_stages} stage(s) — continue-train mode)") |
| print("="*80) |
|
|
| stage_configs = [ |
| { |
| 'name': 'Stage 1: Basic Q&A (short, no reasoning)', |
| 'max_len': 384, |
| 'cot_ratio': 0.0, |
| 'filter': lambda s: len(dataset.tokenizer.encode( |
| f"{s['input']} {s['output']}" |
| )) <= 384 |
| }, |
| { |
| 'name': 'Stage 2: Learning Reasoning (medium, 50% CoT)', |
| 'max_len': 512, |
| 'cot_ratio': 0.5, |
| 'filter': lambda s: True |
| }, |
| { |
| 'name': 'Stage 3: Full Reasoning (all, 100% CoT)', |
| 'max_len': 1024, |
| 'cot_ratio': 1.0, |
| 'filter': lambda s: True |
| } |
| ] |
|
|
| for idx, stage_config in enumerate(stage_configs): |
| filtered_samples = [s for s in dataset.samples if stage_config['filter'](s)] |
| stage_dataset = _build_stage_dataset( |
| dataset, filtered_samples, |
| stage_config['max_len'], stage_config['cot_ratio'] |
| ) |
| datasets.append(stage_dataset) |
|
|
| skipped = idx < skip_stages |
| prefix = " [SKIP] " if skipped else " " |
| print(f"{prefix}{stage_config['name']}") |
| print(f" {'(skipped)' if skipped else ''} Samples: {len(filtered_samples)}") |
| print(f" {'(skipped)' if skipped else ''} Max length: {stage_config['max_len']}") |
| print(f" {'(skipped)' if skipped else ''} CoT ratio: {stage_config['cot_ratio']*100:.0f}%") |
|
|
| print("="*80 + "\n") |
|
|
| |
| # The 3-stage path above still builds the skipped stages (for logging), so |
| # drop them here; the simple path already excluded them via `continue`, so |
| # only slice in the non-simple case to avoid skipping twice. |
| if skip_stages > 0 and not use_simple: |
|     datasets = datasets[skip_stages:] |
|
|
| return datasets |
|
|
|
|
| def _build_stage_dataset(base_dataset, samples, max_len, cot_ratio): |
| """Helper: create a shallow-copy stage dataset from a list of samples.""" |
| stage = IndonesianCoTDataset.__new__(IndonesianCoTDataset) |
| stage.tokenizer = base_dataset.tokenizer |
| stage.max_length = max_len |
| stage.cot_token = base_dataset.cot_token |
| stage.end_cot_token = base_dataset.end_cot_token |
| stage.samples = samples |
| stage.skipped_count = 0 |
| stage.use_cot = base_dataset.use_cot |
| stage.cot_ratio = cot_ratio |
| return stage |
|
|
|
|
| |
| |
| |
|
|
| def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps): |
| def lr_lambda(current_step): |
| if current_step < num_warmup_steps: |
| return float(current_step) / float(max(1, num_warmup_steps)) |
| return max(0.0, float(num_training_steps - current_step) / |
| float(max(1, num_training_steps - num_warmup_steps))) |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
|
|
|
|
| def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps): |
| def lr_lambda(current_step): |
| if current_step < num_warmup_steps: |
| return float(current_step) / float(max(1, num_warmup_steps)) |
| progress = float(current_step - num_warmup_steps) / float( |
| max(1, num_training_steps - num_warmup_steps)) |
| return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
|
|
|
|
| def get_continue_schedule(optimizer, num_training_steps: int, min_fraction: float = 0.1): |
| """ |
| Schedule for --continue-train without saved optimizer state. |
| Starts at target LR immediately and decays gently via cosine to |
| min_fraction x LR by the end of training. |
| """ |
| def lr_lambda(step): |
| progress = float(step) / float(max(1, num_training_steps)) |
| cosine_val = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) |
| return min_fraction + (1.0 - min_fraction) * cosine_val |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
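|
|
| def _demo_lr_schedules(): |
|     """Sketch: prints a few LR values from the warmup+cosine schedule on a |
|     dummy one-parameter optimizer so the ramp and decay are visible.""" |
|     opt = torch.optim.AdamW([nn.Parameter(torch.zeros(1))], lr=3e-4) |
|     sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10, |
|                                             num_training_steps=100) |
|     for step in range(100): |
|         if step in (0, 5, 10, 50, 99): |
|             print(f"step {step:>3}: lr = {sched.get_last_lr()[0]:.2e}") |
|         opt.step() |
|         sched.step() |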
|
|
|
|
| class PlateauLRGuard: |
| """ |
| Wraps any LambdaLR scheduler and applies an extra multiplicative penalty |
| when perplexity has not improved for `patience` consecutive checks. |
| """ |
|
|
| def __init__(self, scheduler, patience=3, factor=0.5, min_delta=0.02): |
| self.scheduler = scheduler |
| self.patience = patience |
| self.factor = factor |
| self.min_delta = min_delta |
| self._best = float('inf') |
| self._no_improve = 0 |
| self._penalty = 1.0 |
|
|
| def step(self, perplexity: float): |
| relative_improvement = (self._best - perplexity) / max(self._best, 1e-8) |
|
|
| if relative_improvement > self.min_delta: |
| self._best = perplexity |
| self._no_improve = 0 |
| else: |
| self._no_improve += 1 |
|
|
| if self._no_improve >= self.patience: |
|     self._penalty *= self.factor |
|     self._no_improve = 0 |
|     new_lr = self.scheduler.get_last_lr()[0] * self.factor |
|     print(f"\n[PlateauLRGuard] No improvement for {self.patience} checks. " |
|           f"Reducing LR by {self.factor:.2f}x -> {new_lr:.2e}") |
|     # Scale base_lrs too: LambdaLR recomputes lr = base_lr * lambda(step) on |
|     # every step(), so touching param_groups alone would be undone at the |
|     # next scheduler step. |
|     self.scheduler.base_lrs = [b * self.factor for b in self.scheduler.base_lrs] |
|     for pg in self.scheduler.optimizer.param_groups: |
|         pg['lr'] *= self.factor |
|     return True |
|
|
| return False |
|
|
| def get_penalty(self): |
| return self._penalty |
|
|
|
|
| |
| |
| |
|
|
| def set_seed(seed: int): |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| torch.backends.cudnn.deterministic = True |
| torch.backends.cudnn.benchmark = False |
|
|
|
|
| |
| |
| |
|
|
| class EWC: |
| """ |
| Elastic Weight Consolidation — prevents catastrophic forgetting during finetuning. |
| """ |
|
|
| def __init__(self, model, dataloader, device, n_samples: int = 2000): |
| self.device = device |
| self.n_samples = n_samples |
|
|
| self.params = { |
| n: p.clone().detach() |
| for n, p in model.named_parameters() |
| if p.requires_grad |
| } |
|
|
| self.fisher = self._compute_fisher(model, dataloader) |
|
|
| def _compute_fisher(self, model, dataloader): |
| fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad} |
|
|
| model.eval() |
| seen = 0 |
|
|
| for batch in dataloader: |
| if seen >= self.n_samples: |
| break |
|
|
| input_ids = batch["input_ids"].to(self.device) |
| attention_mask = batch["attention_mask"].to(self.device) |
| labels = batch["labels"].to(self.device) |
|
|
| model.zero_grad() |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) |
| outputs["loss"].backward() |
|
|
| for n, p in model.named_parameters(): |
| if p.requires_grad and p.grad is not None: |
| fisher[n] += p.grad.detach().pow(2) |
|
|
| seen += input_ids.size(0) |
|
|
| for n in fisher: |
| fisher[n] /= max(1, seen) |
|
|
| model.train() |
| return fisher |
|
|
| def penalty(self, model) -> torch.Tensor: |
| loss = torch.tensor(0.0, device=self.device) |
| for n, p in model.named_parameters(): |
| if p.requires_grad and n in self.fisher: |
| loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum() |
| return loss * 0.5 |
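|
|
| # EWC recap (for reference): the regularizer added to the task loss is |
| #   L_total = L_task + lambda * penalty,  penalty = 0.5 * sum_i F_i * (theta_i - theta_i*)^2 |
| # where theta* are the weights snapshotted in __init__ and F_i is the diagonal |
| # Fisher information estimated from squared gradients on the old data. |
| # train_model() multiplies in lambda via config.ewc_lambda. |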
|
|
|
|
| |
| |
| |
|
|
| def train_model( |
| model: IndonesianLLM, |
| train_dataset: IndonesianCoTDataset, |
| config: TrainingConfig, |
| device: torch.device, |
| use_simple_curriculum: bool = False, |
| is_continue: bool = False, |
| skip_curriculum_stages: int = 0, |
| ewc: "EWC | None" = None, |
| ): |
| """Main training loop.""" |
|
|
| print("\n" + "="*80) |
| print("TRAINING CONFIGURATION" + (" [CONTINUE MODE]" if is_continue else "")) |
| print("="*80) |
| print(f"Model parameters: {model.count_parameters():,}") |
| print(f"Dataset size: {len(train_dataset)}") |
| print(f"Batch size: {config.batch_size}") |
| print(f"Gradient accumulation steps: {config.gradient_accumulation_steps}") |
| print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}") |
| print(f"Learning rate: {config.learning_rate}") |
| print(f"Max sequence length: {config.max_seq_length}") |
| print(f"Number of epochs: {config.num_epochs}") |
| print(f"Mixed precision: {config.use_fp16}") |
| if is_continue: |
| print(f"Skipping curriculum stages: {skip_curriculum_stages}") |
| print(f"Plateau patience: {config.plateau_patience}") |
| print(f"Plateau LR factor: {config.plateau_factor}") |
| if ewc is not None: |
| print(f"EWC lambda: {config.ewc_lambda} (anti-forgetting active)") |
| print("="*80 + "\n") |
|
|
| model.to(device) |
| model.train() |
|
|
| curriculum_datasets = create_curriculum_datasets( |
| train_dataset, |
| config.curriculum_stages, |
| use_simple=use_simple_curriculum, |
| skip_stages=skip_curriculum_stages, |
| ) |
|
|
| if len(curriculum_datasets) == 0: |
| print("ERROR: No curriculum stages to train on. Check --skip-stages value.") |
| return model |
|
|
| optimizer = torch.optim.AdamW( |
| model.parameters(), |
| lr=config.learning_rate, |
| betas=(config.adam_beta1, config.adam_beta2), |
| eps=config.adam_epsilon, |
| weight_decay=config.weight_decay |
| ) |
|
|
| |
| total_steps = 0 |
| for ds in curriculum_datasets: |
| steps_per_epoch = max(1, len(ds) // (config.batch_size * config.gradient_accumulation_steps)) |
| total_steps += steps_per_epoch * config.num_epochs |
|
|
| if total_steps == 0: |
| total_steps = 1 |
|
|
| |
| if is_continue: |
| scheduler = get_continue_schedule(optimizer, num_training_steps=total_steps) |
| plateau_guard = PlateauLRGuard( |
| scheduler, |
| patience=config.plateau_patience, |
| factor=config.plateau_factor, |
| min_delta=config.plateau_min_delta, |
| ) |
| print(f"[Scheduler] Continue-train: flat cosine decay from {config.learning_rate:.2e}") |
| else: |
| if config.lr_scheduler_type == "cosine": |
| scheduler = get_cosine_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=config.warmup_steps, |
| num_training_steps=total_steps |
| ) |
| else: |
| scheduler = get_linear_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=config.warmup_steps, |
| num_training_steps=total_steps |
| ) |
| plateau_guard = None |
|
|
| scaler = torch.cuda.amp.GradScaler() if config.use_fp16 and torch.cuda.is_available() else None |
|
|
| global_step = 0 |
| perplexity_history = [] |
|
|
| for stage_idx, stage_dataset in enumerate(curriculum_datasets): |
| actual_stage = stage_idx + skip_curriculum_stages |
| print(f"\n{'='*80}") |
| print(f"CURRICULUM STAGE {actual_stage + 1}/{len(curriculum_datasets) + skip_curriculum_stages} " |
| f"(running {stage_idx + 1} of {len(curriculum_datasets)})") |
| print(f"Max sequence length: {stage_dataset.max_length}") |
| print(f"Samples: {len(stage_dataset)}") |
| if hasattr(stage_dataset, 'cot_ratio'): |
| print(f"CoT ratio: {stage_dataset.cot_ratio:.0%}") |
| print(f"{'='*80}\n") |
|
|
| dataloader = DataLoader( |
| stage_dataset, |
| batch_size=config.batch_size, |
| shuffle=True, |
| collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx), |
| num_workers=0, |
| pin_memory=True if torch.cuda.is_available() else False |
| ) |
|
|
| for epoch in range(config.num_epochs): |
| print(f"\nEpoch {epoch + 1}/{config.num_epochs}") |
| epoch_loss = 0.0 |
| optimizer.zero_grad() |
|
|
| for step, batch in enumerate(dataloader): |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| labels = batch['labels'].to(device) |
|
|
| if scaler is not None: |
| with torch.cuda.amp.autocast(): |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) |
| task_loss = outputs['loss'] |
| if ewc is not None: |
| task_loss = task_loss + config.ewc_lambda * ewc.penalty(model) |
| loss = task_loss / config.gradient_accumulation_steps |
| scaler.scale(loss).backward() |
| else: |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) |
| task_loss = outputs['loss'] |
| if ewc is not None: |
| task_loss = task_loss + config.ewc_lambda * ewc.penalty(model) |
| loss = task_loss / config.gradient_accumulation_steps |
| loss.backward() |
|
|
| epoch_loss += loss.item() * config.gradient_accumulation_steps |
|
|
| if (step + 1) % config.gradient_accumulation_steps == 0: |
| if scaler is not None: |
| scaler.unscale_(optimizer) |
| torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) |
|
|
| if scaler is not None: |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| optimizer.step() |
|
|
| scheduler.step() |
| optimizer.zero_grad() |
| global_step += 1 |
|
|
| if global_step % config.logging_steps == 0: |
| avg_loss = epoch_loss / (step + 1) |
| current_lr = scheduler.get_last_lr()[0] |
| print(f"Step {global_step:>6} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}") |
|
|
| avg_epoch_loss = epoch_loss / max(1, len(dataloader)) |
| perplexity = math.exp(min(avg_epoch_loss, 20)) |
| perplexity_history.append(perplexity) |
|
|
| print(f"Epoch {epoch + 1} completed | Avg Loss: {avg_epoch_loss:.4f} " |
| f"| Perplexity: {perplexity:.2f}") |
|
|
| if plateau_guard is not None: |
| plateau_guard.step(perplexity) |
|
|
| print("\n" + "="*80) |
| print("TRAINING COMPLETED") |
| if perplexity_history: |
| print(f"Final perplexity: {perplexity_history[-1]:.2f}") |
| print(f"Best perplexity: {min(perplexity_history):.2f}") |
| print("="*80 + "\n") |
|
|
| return model |
|
|
|
|
| |
| |
| |
|
|
| def evaluate_model(model, dataset, device, batch_size=4): |
| model.eval() |
|
|
| dataloader = DataLoader( |
| dataset, |
| batch_size=batch_size, |
| shuffle=False, |
| collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx), |
| num_workers=0 |
| ) |
|
|
| total_loss = 0.0 |
| total_samples = 0 |
|
|
| with torch.no_grad(): |
| for batch in dataloader: |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| labels = batch['labels'].to(device) |
|
|
| outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) |
| total_loss += outputs['loss'].item() * input_ids.size(0) |
| total_samples += input_ids.size(0) |
|
|
| avg_loss = total_loss / max(1, total_samples) |
| perplexity = math.exp(min(avg_loss, 20)) |
|
|
| print(f"\nEvaluation Results:") |
| print(f"Average Loss: {avg_loss:.4f}") |
| print(f"Perplexity: {perplexity:.2f}") |
|
|
| if perplexity < 5.0: |
| print("Status: Excellent") |
| elif perplexity < 10.0: |
| print("Status: Good") |
| elif perplexity < 20.0: |
| print("Status: Fair — try more epochs or lower LR") |
| else: |
| print("Status: Poor — check data quality or model config") |
|
|
| if avg_loss < 0.5: |
| print("Warning: Very low loss might indicate overfitting") |
|
|
| return {"loss": avg_loss, "perplexity": perplexity} |
|
|
|
|
| |
| |
| |
|
|
| def generate_text( |
| model, |
| tokenizer, |
| prompt: str, |
| max_new_tokens: int = 256, |
| temperature: float = 0.7, |
| top_k: int = 50, |
| top_p: float = 0.9, |
| device: torch.device = torch.device('cpu') |
| ): |
| model.eval() |
|
|
| input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device) |
| # BERT-style tokenizers append a trailing [SEP]; training sequences only end |
| # with [SEP], so generation should not start from one. Drop it, keep [CLS]. |
| if (tokenizer.sep_token_id is not None and input_ids.size(1) > 1 |
|         and input_ids[0, -1].item() == tokenizer.sep_token_id): |
|     input_ids = input_ids[:, :-1] |
| generated_ids = input_ids.clone() |
|
|
| eos_token_id = tokenizer.eos_token_id |
| if eos_token_id is None: |
| eos_token_id = tokenizer.sep_token_id |
| if eos_token_id is None: |
| eos_token_id = 2 |
|
|
| # Guard against tokenizers that define no pad/sep token (None entries). |
| stop_tokens = {t for t in (eos_token_id, tokenizer.pad_token_id, |
|                            tokenizer.sep_token_id) if t is not None} |
|
|
| repetition_buffer = [] |
|
|
| with torch.no_grad(): |
| for step in range(max_new_tokens): |
| outputs = model(input_ids=generated_ids) |
| logits = outputs['logits'] |
| next_token_logits = logits[:, -1, :] / max(temperature, 0.1) |
|
|
| if len(repetition_buffer) > 10: |
| for token in set(repetition_buffer[-10:]): |
| if token in stop_tokens: |
| continue |
| next_token_logits[0, token] -= 2.0 |
|
|
| if top_k > 0: |
| indices_to_remove = next_token_logits < torch.topk( |
| next_token_logits, min(top_k, next_token_logits.size(-1)) |
| )[0][..., -1, None] |
| next_token_logits[indices_to_remove] = float('-inf') |
|
|
| if top_p < 1.0: |
| sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) |
| cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
| sorted_indices_to_remove = cumulative_probs > top_p |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| sorted_indices_to_remove[..., 0] = 0 |
| indices_to_remove = sorted_indices_to_remove.scatter( |
| 1, sorted_indices, sorted_indices_to_remove) |
| next_token_logits[indices_to_remove] = float('-inf') |
|
|
| probs = F.softmax(next_token_logits, dim=-1) |
| next_token = torch.multinomial(probs, num_samples=1) |
|
|
| if next_token.item() in stop_tokens: |
| break |
|
|
| repetition_buffer.append(next_token.item()) |
| if len(repetition_buffer) > 20: |
| repetition_buffer.pop(0) |
|
|
| generated_ids = torch.cat([generated_ids, next_token], dim=1) |
|
|
| max_ctx = model.config.max_position_embeddings |
| if generated_ids.size(1) > max_ctx: |
| generated_ids = generated_ids[:, -max_ctx:] |
|
|
| if step > 10: |
| decoded = tokenizer.decode( |
| generated_ids[0][input_ids.size(1):], skip_special_tokens=False) |
| if '\n\n' in decoded or 'User:' in decoded or 'Assistant:' in decoded[-20:]: |
| break |
|
|
| # Return the original prompt string plus the decoded continuation. Decoding |
| # the whole sequence would re-insert [CLS] and the tokenizer's lowercasing, |
| # misaligning the character-based `full_response[len(prompt):]` slice that |
| # callers below rely on. |
| return prompt + tokenizer.decode(generated_ids[0][input_ids.size(1):], skip_special_tokens=False) |
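|
|
| # Usage sketch (illustrative prompt text): end prompts with "<cot>" to trigger |
| # the reasoning format the model was trained on, e.g. |
| #   text = generate_text(model, tokenizer, "Apa ibu kota Indonesia? <cot>", |
| #                        max_new_tokens=100, temperature=0.3, top_k=20, |
| #                        device=device) |
| # run_benchmark() uses these low-temperature settings for scoring, while |
| # interactive_chat() samples hotter (temperature=0.9, top_p=0.92). |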
|
|
|
|
| |
| |
| |
|
|
| def _clean_response(response: str) -> str: |
| import re |
|
|
| if "<cot>" in response and "</cot>" in response: |
| response = response.split("</cot>", 1)[-1] |
| elif "<cot>" in response: |
| response = response.split("<cot>", 1)[0] |
|
|
| response = re.sub(r'\[\w+\]', '', response) |
| response = re.sub(r'<[^>]+>', '', response) |
|
|
| for marker in [ |
| "user :", "user:", "User :", "User:", |
| "assistant :", "assistant:", "Assistant :", "Assistant:", |
| "memahami permintaan", "jawaban singkat", "penjelasan harus", |
| "\n\n", |
| ]: |
| if marker in response: |
| response = response.split(marker)[0] |
|
|
| response = re.sub(r'^[\s:!,\.\-|\[\]]+', '', response) |
| response = re.sub(r' {2,}', ' ', response).strip() |
| return response |
|
|
|
|
| def _extract_thinking(raw: str) -> tuple: |
| import re |
|
|
| raw = re.sub(r'\[\w+\]', '', raw) |
|
|
| if "</cot>" in raw: |
| thinking_raw, answer_raw = raw.split("</cot>", 1) |
| else: |
| thinking_raw, answer_raw = raw, "" |
|
|
| thinking = thinking_raw.strip() |
| for marker in ["user :", "user:", "memahami permintaan", "\n\n"]: |
| if marker in thinking: |
| thinking = thinking.split(marker)[0] |
| thinking = re.sub(r'<[^>]+>', '', thinking).strip() |
|
|
| answer = _clean_response(answer_raw) |
|
|
| return thinking, answer |
|
|
|
|
| def interactive_chat(model, tokenizer, device, |
| system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia."): |
| print("\n" + "="*80) |
| print("INDONESIAN LLM — INTERACTIVE CHAT") |
| print("="*80) |
| print("Commands: 'exit'/'quit' | 'clear' | 'think' (toggle reasoning display)") |
| print(f"Persona : {system_prompt}") |
| print("="*80 + "\n") |
|
|
| model.eval() |
| show_thinking = False |
|
|
| import time |
| torch.manual_seed(int(time.time()) % 100000) |
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(int(time.time()) % 100000) |
|
|
| while True: |
| try: |
| user_input = input("\nYou: ").strip() |
| if not user_input: |
| continue |
| if user_input.lower() in ['exit', 'quit', 'keluar']: |
| print("\nGoodbye!") |
| break |
| if user_input.lower() in ['clear', 'bersihkan']: |
| print("\nConversation cleared") |
| continue |
| if user_input.lower() == 'think': |
| show_thinking = not show_thinking |
| print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}") |
| continue |
|
|
| prompt = f"{user_input} <cot>" |
| max_tokens = 250 |
|
|
| print("\nA:", end=" ", flush=True) |
| full_response = generate_text( |
| model=model, tokenizer=tokenizer, prompt=prompt, |
| max_new_tokens=max_tokens, temperature=0.9, |
| top_k=50, top_p=0.92, device=device |
| ) |
| response = full_response[len(prompt):].strip() |
|
|
| thinking, answer = _extract_thinking(response) |
|
|
| if show_thinking and thinking: |
| print(f"[Thinking: {thinking}]") |
|
|
| final = answer if answer else _clean_response(response) |
| if not final or len(final) < 3: |
| final = "Maaf, saya tidak mengerti. Bisa diulang?" |
| print(final) |
|
|
| except KeyboardInterrupt: |
| print("\n\nChat interrupted") |
| break |
| except Exception as e: |
| print(f"\nError: {e}") |
|
|
|
|
| |
| |
| |
|
|
| def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 20, verbose: bool = True): |
| import time |
|
|
| if dataset_path is None or not os.path.exists(dataset_path): |
| print(f"Dataset not found at: {dataset_path}") |
| print("Pass --dataset path/to/your.jsonl to benchmark against your data.") |
| return |
|
|
| all_samples = [] |
| with open(dataset_path, 'r', encoding='utf-8') as f: |
| for line in f: |
| line = line.strip() |
| if not line: |
| continue |
| try: |
| d = json.loads(line) |
| if all(k in d for k in ['input', 'output']) and d['input'].strip() and d['output'].strip(): |
| all_samples.append(d) |
| except Exception: |
| continue |
|
|
| if not all_samples: |
| print("No valid samples found in dataset.") |
| return |
|
|
| random.seed(int(time.time())) |
| samples = random.sample(all_samples, min(n, len(all_samples))) |
|
|
| model.eval() |
| torch.manual_seed(42) |
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(42) |
|
|
| print("\n" + "="*80) |
| print(f"BENCHMARK ({len(samples)} random samples from dataset)") |
| print("="*80) |
|
|
| results = [] |
| ppl_loss = 0.0 |
| ppl_toks = 0 |
|
|
| for i, sample in enumerate(samples): |
| inp = sample['input'].strip() |
| expected = sample['output'].strip().lower() |
|
|
| prompt = f"{inp} <cot>" |
| full = generate_text( |
| model=model, tokenizer=tokenizer, prompt=prompt, |
| max_new_tokens=150, temperature=0.3, |
| top_k=20, top_p=0.9, device=device |
| ) |
| raw = full[len(prompt):].strip() |
| _, answer = _extract_thinking(raw) |
| answer_lower = answer.lower() |
|
|
| passed = expected in answer_lower |
|
|
| if not passed: |
| exp_tokens = set(expected.split()) |
| ans_tokens = set(answer_lower.split()) |
| if exp_tokens: |
| overlap = len(exp_tokens & ans_tokens) / len(exp_tokens) |
| passed = overlap >= 0.5 |
|
|
| results.append(passed) |
|
|
| with torch.no_grad(): |
| ids = tokenizer.encode( |
| f"{inp} <cot> {sample['output']}", |
| return_tensors="pt" |
| ).to(device) |
| if ids.size(1) >= 2: |
| out = model(input_ids=ids, labels=ids) |
| toks = ids.size(1) - 1 |
| ppl_loss += out["loss"].item() * toks |
| ppl_toks += toks |
|
|
| if verbose: |
| status = "PASS" if passed else "FAIL" |
| print(f" [{status}] {inp[:60]}") |
| print(f" Expected : {sample['output'][:80]}") |
| print(f" Got : {answer[:80] if answer else '(no answer)'}") |
|
|
| total_pass = sum(results) |
| total = len(results) |
| overall = total_pass / total * 100 |
| bar = "█" * int(overall / 10) + "░" * (10 - int(overall / 10)) |
| ppl = math.exp(min(ppl_loss / max(1, ppl_toks), 20)) |
|
|
| print("\n" + "-"*80) |
| print(f" SCORE {total_pass}/{total} ({overall:.1f}%) {bar}") |
| print(f" PERPLEXITY {ppl:.2f} (lower = better)") |
| print("="*80 + "\n") |
|
|
| return {"score": overall, "pass": total_pass, "total": total, "perplexity": ppl} |
|
|
|
|
| |
| |
| |
|
|
| def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, save_path: str, use_fp16: bool = True): |
| os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True) |
|
|
| state_dict = model.state_dict() |
| if use_fp16: |
| state_dict = {k: v.half() if v.dtype == torch.float32 else v |
| for k, v in state_dict.items()} |
|
|
| torch.save({ |
| 'model_state_dict': state_dict, |
| 'config': config, |
| 'tokenizer_name': tokenizer_name, |
| 'model_params': model.count_parameters(), |
| 'dtype': 'fp16' if use_fp16 else 'fp32', |
| }, save_path) |
|
|
| size_mb = os.path.getsize(save_path) / 1e6 |
| print(f"\nModel saved to: {save_path} ({'fp16' if use_fp16 else 'fp32'}, {size_mb:.1f} MB)") |
| print(f"Parameters: {model.count_parameters():,}") |
|
|
|
|
| def load_model(load_path: str, device: torch.device): |
| if not os.path.exists(load_path): |
| raise FileNotFoundError(f"Model checkpoint not found: {load_path}") |
|
|
| print(f"Loading model from: {load_path}") |
| checkpoint = torch.load(load_path, map_location=device, weights_only=False) |
|
|
| config = checkpoint['config'] |
| tokenizer_name = checkpoint['tokenizer_name'] |
| dtype = checkpoint.get('dtype', 'fp32') |
|
|
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) |
| tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]}) |
|
|
| model = IndonesianLLM(config) |
|
|
| state_dict = checkpoint['model_state_dict'] |
| if dtype == 'fp16': |
| state_dict = {k: v.float() if v.dtype == torch.float16 else v |
| for k, v in state_dict.items()} |
|
|
| model.load_state_dict(state_dict) |
| model.to(device) |
|
|
| extra = {'training_metadata': checkpoint.get('training_metadata', None)} |
|
|
| size_mb = os.path.getsize(load_path) / 1e6 |
| print(f"Model loaded ({dtype}, {size_mb:.1f} MB) | " |
| f"Parameters: {checkpoint.get('model_params', model.count_parameters()):,}") |
|
|
| return model, tokenizer, config, extra |
|
|
|
|
| |
| |
| |
|
|
| def main(): |
| parser = argparse.ArgumentParser( |
| description="Indonesian Conversational LLM — Train, Chat, Finetune, or Continue") |
|
|
| parser.add_argument('--train', action='store_true', help='Train model from scratch') |
| parser.add_argument('--chat', action='store_true', help='Interactive chat mode') |
| parser.add_argument('--finetune', action='store_true', help='Fine-tune on NEW data (lr/10)') |
| parser.add_argument('--continue-train', action='store_true', |
| help='Continue training on SAME data with proper LR re-warmup') |
| parser.add_argument('--inspect-data', action='store_true', help='Inspect dataset quality') |
| parser.add_argument('--benchmark', action='store_true', help='Run benchmark suite on a saved model') |
| parser.add_argument('--no-eval', action='store_true', help='Skip evaluation after training') |
| parser.add_argument('--grad-accum', type=int, default=None, |
| help='Override gradient accumulation steps') |
| parser.add_argument('--ewc-lambda', type=float, default=5000.0, |
| help='EWC penalty strength for --finetune (default 5000). 0 = disabled.') |
| parser.add_argument('--ewc-samples', type=int, default=2000, |
| help='Samples used to estimate Fisher Information for EWC (default 2000).') |
| parser.add_argument('--no-ewc', action='store_true', |
| help='Disable EWC during finetuning.') |
|
|
| parser.add_argument('--dataset', type=str, default='indonesian_cot_dataset.jsonl') |
| parser.add_argument('--model', type=str, default='indonesian_llm_model.pt') |
| parser.add_argument('--epochs', type=int, default=5) |
| parser.add_argument('--batch-size', type=int, default=4) |
| parser.add_argument('--lr', type=float, default=2e-4) |
| parser.add_argument('--max-length', type=int, default=512, |
| help='Max sequence length. Dataset max is 448 so 512 is optimal.') |
| parser.add_argument('--hidden-size', type=int, default=320) |
| parser.add_argument('--num-layers', type=int, default=16) |
| parser.add_argument('--num-heads', type=int, default=8) |
| parser.add_argument('--num-kv-heads', type=int, default=2) |
| parser.add_argument('--seed', type=int, default=42) |
| parser.add_argument('--save-fp16', action='store_true', default=True) |
| parser.add_argument('--save-fp32', action='store_true') |
| parser.add_argument('--no-cot', action='store_true') |
| parser.add_argument('--use-cot', action='store_true', default=True) |
| parser.add_argument('--simple-curriculum', action='store_true') |
| parser.add_argument('--cot-ratio', type=float, default=1.0) |
| parser.add_argument('--system-prompt', type=str, |
| default='Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.') |
|
|
| parser.add_argument('--rewarm-steps', type=int, default=150) |
| parser.add_argument('--skip-stages', type=int, default=2, |
| help='Curriculum stages to skip in continue-train (default 2).') |
| parser.add_argument('--plateau-patience', type=int, default=3) |
| parser.add_argument('--no-restore-optimizer', action='store_true') |
|
|
| args = parser.parse_args() |
|
|
| if not any([args.train, args.chat, args.finetune, args.continue_train, |
| args.inspect_data, args.benchmark]): |
| parser.print_help() |
| print("\nError: Specify a mode: --train, --chat, --finetune, --continue-train, " |
| "--inspect-data, or --benchmark") |
| return |
|
|
| save_fp16 = not args.save_fp32 |
| use_cot_training = not args.no_cot |
|
|
| set_seed(args.seed) |
|
|
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| print(f"\nDevice: {device}") |
| if torch.cuda.is_available(): |
| print(f"GPU: {torch.cuda.get_device_name(0)}") |
| print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n") |
|
|
| |
| |
| |
| if args.inspect_data: |
| print("\nInspecting dataset...") |
| print("="*80) |
|
|
| tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased") |
| tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]}) |
|
|
| dataset = IndonesianCoTDataset( |
| file_path=args.dataset, tokenizer=tokenizer, |
| max_length=args.max_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio |
| ) |
|
|
| print(f"\nDataset Statistics:") |
| print(f"Total samples: {len(dataset)}") |
| print(f"Skipped samples: {dataset.skipped_count}") |
|
|
| lengths = [] |
| cot_lengths = [] |
| for i in range(min(len(dataset), 1000)): |
| sample = dataset[i] |
| lengths.append(sample['length']) |
| cot_lengths.append(sample['cot_length']) |
|
|
| print(f"\nSequence Length Stats:") |
| print(f" Min: {min(lengths)}") |
| print(f" Max: {max(lengths)}") |
| print(f" Avg: {sum(lengths)/len(lengths):.1f}") |
| print(f" Median: {sorted(lengths)[len(lengths)//2]}") |
|
|
| print(f"\nCoT Length Stats:") |
| print(f" Min: {min(cot_lengths)}") |
| print(f" Max: {max(cot_lengths)}") |
| print(f" Avg: {sum(cot_lengths)/len(cot_lengths):.1f}") |
|
|
| long_cot = sum(1 for x in cot_lengths if x > 50) |
| print(f" Samples with long CoT (>50 tokens): {long_cot} ({long_cot/len(cot_lengths)*100:.1f}%)") |
|
|
| print(f"\n{'='*80}") |
| print("Sample Examples (first 5):") |
| print("="*80) |
| for i in range(min(5, len(dataset.samples))): |
| s = dataset.samples[i] |
| print(f"\n--- Sample {i+1} ---") |
| print(f"Input: {s['input'][:100]}...") |
| print(f"CoT: {s['cot'][:150]}...") |
| print(f"Output: {s['output'][:100]}...") |
|
|
| print("\n" + "="*80) |
| print("Dataset Quality Checks:") |
| issues = [] |
| for i, sample in enumerate(dataset.samples[:100]): |
| if len(sample['input']) < 10: |
| issues.append(f"Sample {i}: Input too short") |
| if len(sample['output']) < 10: |
| issues.append(f"Sample {i}: Output too short") |
| if len(sample['cot']) < 20: |
| issues.append(f"Sample {i}: CoT too short") |
| if sample['input'].lower() == sample['output'].lower(): |
| issues.append(f"Sample {i}: Input == Output (copy)") |
|
|
| if issues: |
| print(f"\nFound {len(issues)} potential issues in first 100 samples:") |
| for issue in issues[:10]: |
| print(f" - {issue}") |
| if len(issues) > 10: |
| print(f" ... and {len(issues)-10} more") |
| else: |
| print("\nNo obvious issues detected in first 100 samples") |
| print("\n" + "="*80) |
| return |
|
|
| |
| |
| |
| if args.chat: |
| print("\nStarting CHAT mode...") |
| if not os.path.exists(args.model): |
| print(f"Error: Model checkpoint not found: {args.model}") |
| return |
| model, tokenizer, config, _ = load_model(args.model, device) |
| interactive_chat(model, tokenizer, device, system_prompt=args.system_prompt) |
| return |
|
|
| |
| |
| |
| if args.benchmark: |
| print("\nRunning benchmark...") |
| if not os.path.exists(args.model): |
| print(f"Error: Model not found: {args.model}") |
| return |
| model, tokenizer, config, _ = load_model(args.model, device) |
| run_benchmark(model, tokenizer, device, dataset_path=args.dataset) |
| return |
|
|
| |
| |
| |
| if args.train: |
| print("\nStarting TRAINING mode (from scratch)...") |
|
|
| tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased") |
| tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]}) |
|
|
| model_config = ModelConfig( |
| vocab_size=len(tokenizer), |
| hidden_size=args.hidden_size, |
| num_layers=args.num_layers, |
| num_attention_heads=args.num_heads, |
| num_key_value_heads=args.num_kv_heads, |
| intermediate_size=args.hidden_size * 3, |
| max_position_embeddings=2048, |
| attention_dropout=0.1, |
| residual_dropout=0.1, |
| tie_word_embeddings=True |
| ) |
|
|
| model = IndonesianLLM(model_config) |
| print(f"Model parameters: {model.count_parameters():,}") |
|
|
| _ga = args.grad_accum if args.grad_accum else 32 |
| train_config = TrainingConfig( |
| dataset_path=args.dataset, |
| num_epochs=args.epochs, |
| batch_size=args.batch_size, |
| gradient_accumulation_steps=_ga, |
| max_seq_length=args.max_length, |
| learning_rate=args.lr, |
| warmup_steps=500, |
| use_fp16=torch.cuda.is_available(), |
| curriculum_stages=[128, 256, args.max_length] |
| ) |
|
|
| dataset = IndonesianCoTDataset( |
| file_path=train_config.dataset_path, tokenizer=tokenizer, |
| max_length=train_config.max_seq_length, use_cot=use_cot_training, |
| cot_ratio=args.cot_ratio |
| ) |
|
|
| model = train_model( |
| model, dataset, train_config, device, |
| use_simple_curriculum=args.simple_curriculum |
| ) |
|
|
| if not args.no_eval: |
| evaluate_model(model, dataset, device) |
|
|
| save_model(model, model_config, "indolem/indobert-base-uncased", args.model, use_fp16=save_fp16) |
|
|
| test_prompts = [ |
| "Berapa hasil dari 1+1?", |
| "Jelaskan cara kerja komputer", |
| "Bagaimana cara membuat kopi yang enak?" |
| ] |
| print("\n" + "="*80) |
| print("GENERATION TEST") |
| print("="*80 + "\n") |
| for prompt in test_prompts: |
| print(f"Prompt: {prompt}") |
| generated = generate_text(model, tokenizer, prompt, max_new_tokens=150, device=device) |
| print(f"Generated: {generated}\n") |
| print("-" * 80 + "\n") |
|
|
| print(f"\nTo chat: python {__file__} --chat --model {args.model}") |
|
|
| |
| |
| |
| if args.finetune: |
| print("\nStarting FINETUNE mode (for NEW data)...") |
|
|
| if not os.path.exists(args.model): |
| print(f"Error: Model checkpoint not found: {args.model}") |
| return |
|
|
| model, tokenizer, model_config, extra = load_model(args.model, device) |
|
|
| _ga = args.grad_accum if args.grad_accum else 32 |
| train_config = TrainingConfig( |
| dataset_path=args.dataset, |
| num_epochs=args.epochs, |
| batch_size=args.batch_size, |
| gradient_accumulation_steps=_ga, |
| max_seq_length=args.max_length, |
| learning_rate=args.lr / 10, |
| warmup_steps=100, |
| use_fp16=torch.cuda.is_available(), |
| curriculum_stages=[128, 256, args.max_length] |
| ) |
|
|
| dataset = IndonesianCoTDataset( |
| file_path=train_config.dataset_path, tokenizer=tokenizer, |
| max_length=train_config.max_seq_length, use_cot=use_cot_training, |
| cot_ratio=args.cot_ratio |
| ) |
|
|
| print(f"\nStarting fine-tuning with LR={train_config.learning_rate:.2e}...") |
|
|
| ewc_obj = None |
| if not args.no_ewc and args.ewc_lambda > 0: |
| print(f"\nComputing EWC Fisher Information " |
| f"(lambda={args.ewc_lambda}, samples={args.ewc_samples})...") |
| print(" This takes ~1-2 min on T4. Prevents forgetting old training data.") |
| old_dataset = IndonesianCoTDataset( |
| file_path=args.dataset, tokenizer=tokenizer, |
| max_length=train_config.max_seq_length, use_cot=use_cot_training, |
| cot_ratio=args.cot_ratio |
| ) |
| old_loader = DataLoader( |
| old_dataset, |
| batch_size=args.batch_size, |
| shuffle=True, |
| collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx), |
| num_workers=0 |
| ) |
| train_config.ewc_lambda = args.ewc_lambda |
| train_config.ewc_samples = args.ewc_samples |
| ewc_obj = EWC(model, old_loader, device, n_samples=args.ewc_samples) |
| print(f" Fisher computed. EWC penalty will be added during training.") |
| else: |
| print(" EWC disabled — model may forget previous training.") |
|
|
| model = train_model( |
| model, dataset, train_config, device, |
| use_simple_curriculum=args.simple_curriculum, |
| ewc=ewc_obj, |
| ) |
|
|
| if not args.no_eval: |
| evaluate_model(model, dataset, device) |
|
|
| finetuned_path = args.model.replace('.pt', '_finetuned.pt') |
| save_model(model, model_config, "indolem/indobert-base-uncased", |
| finetuned_path, use_fp16=save_fp16) |
|
|
| print(f"\nFine-tuning completed. Model saved to: {finetuned_path}") |
| print(f"To chat: python {__file__} --chat --model {finetuned_path}") |
|
|
| |
| |
| |
| if args.continue_train: |
| print("\nStarting CONTINUE-TRAIN mode...") |
| print("="*80) |
| print("NOTE: For a 15-30M param model on Indonesian CoT, perplexity 5-15 is normal.") |
| print("Key improvements:") |
| print(" 1. Early curriculum stages skipped → no wasted time on short sequences") |
| print(" 2. Plateau LR reduction → auto-halves LR if perplexity stalls") |
| print(" 3. Micro LR on cold Adam → no destructive spike at restart") |
| print("="*80) |
|
|
| if not os.path.exists(args.model): |
| print(f"Error: Model checkpoint not found: {args.model}") |
| return |
|
|
| model, tokenizer, model_config, extra = load_model(args.model, device) |
|
|
| effective_lr = args.lr * 0.05 |
| print(f"\nContinue-train LR: {args.lr:.2e} x 0.05 = {effective_lr:.2e}") |
| print(f" Adam cold-starts — micro LR prevents overshooting the trained minimum.") |
|
|
| curriculum_stages = [192, 320, args.max_length] |
| if args.simple_curriculum: |
| effective_skip = len(curriculum_stages) - 1 |
| else: |
| effective_skip = args.skip_stages |
| print(f"Skipping first {effective_skip} curriculum stage(s) — " |
| f"training on full-length data only.") |
|
|
| _ga = args.grad_accum if args.grad_accum else 32 |
| train_config = TrainingConfig( |
| dataset_path=args.dataset, |
| num_epochs=args.epochs, |
| batch_size=args.batch_size, |
| gradient_accumulation_steps=_ga, |
| max_seq_length=args.max_length, |
| learning_rate=effective_lr, |
| warmup_steps=0, |
| use_fp16=torch.cuda.is_available(), |
| curriculum_stages=curriculum_stages, |
| skip_curriculum_stages=effective_skip, |
| plateau_patience=2, |
| plateau_factor=0.5, |
| plateau_min_delta=0.02, |
| ) |
|
|
| dataset = IndonesianCoTDataset( |
| file_path=train_config.dataset_path, tokenizer=tokenizer, |
| max_length=train_config.max_seq_length, use_cot=use_cot_training, |
| cot_ratio=args.cot_ratio |
| ) |
|
|
| model = train_model( |
| model, dataset, train_config, device, |
| use_simple_curriculum=args.simple_curriculum, |
| is_continue=True, |
| skip_curriculum_stages=effective_skip, |
| ) |
|
|
| if not args.no_eval: |
| print("\nEvaluating continued model...") |
| evaluate_model(model, dataset, device) |
|
|
| save_model(model, model_config, "indolem/indobert-base-uncased", |
| args.model, use_fp16=save_fp16) |
|
|
| print(f"\nContinued training completed.") |
| print(f"Model saved to: {args.model}") |
| print(f"To chat: python {__file__} --chat --model {args.model}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |