"""
Production-Grade Indonesian Conversational Language Model
Trained from scratch with Chain-of-Thought reasoning capability
Architecture: Decoder-only transformer with MQA/GQA, RoPE, SwiGLU, RMSNorm
Target: 15M-30M parameters, optimized for Google Colab Free tier
"""

import argparse
import json
import math
import os
import random
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

warnings.filterwarnings('ignore')


# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class ModelConfig:
    """Model architecture configuration.

    Raises:
        ValueError: if the head counts are non-positive or do not divide evenly,
            or if the per-head dimension is odd (RoPE rotates pairs of halves).
    """
    vocab_size: int = 30000
    hidden_size: int = 384
    num_layers: int = 12
    num_attention_heads: int = 6
    num_key_value_heads: int = 2  # GQA: 2 KV heads, MQA: 1 KV head
    intermediate_size: int = 1024
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attention_dropout: float = 0.1
    residual_dropout: float = 0.1
    initializer_range: float = 0.02
    use_cache: bool = False
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    tie_word_embeddings: bool = True

    def __post_init__(self):
        # Explicit exceptions rather than `assert`: asserts vanish under `python -O`.
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) and num_key_value_heads "
                f"({self.num_key_value_heads}) must be positive"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        # rotate_half() splits the head dimension into two equal halves, and the RoPE
        # cache is built from cat(freqs, freqs); both require an even head_dim.
        if (self.hidden_size // self.num_attention_heads) % 2 != 0:
            raise ValueError(
                f"head dim ({self.hidden_size // self.num_attention_heads}) must be even for RoPE"
            )


@dataclass
class TrainingConfig:
    """Training hyperparameters."""
    dataset_path: str = "indonesian_cot_dataset.jsonl"
    output_dir: str = "./indonesian_llm_checkpoints"

    # Training
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 12
    max_seq_length: int = 1024

    # Optimization
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0

    # Scheduler
    warmup_steps: int = 100
    lr_scheduler_type: str = "cosine"

    # Regularization
    dropout: float = 0.1

    # Mixed precision
    use_fp16: bool = True

    # Reproducibility
    seed: int = 42

    # Logging
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 500

    # Curriculum learning (None -> [256, 512, 1024], filled in by __post_init__)
    curriculum_stages: Optional[List[int]] = None
    # Skip the first N curriculum stages so we don't re-train on tiny seqs.
    skip_curriculum_stages: int = 2
    # Patience (in eval periods) before ReduceLROnPlateau fires.
    plateau_patience: int = 3
    # Factor to multiply LR by when plateau is detected.
    plateau_factor: float = 0.5
    # Minimum improvement in perplexity to count as "not stalled".
    plateau_min_delta: float = 0.02

    # EWC — set > 0 to enable anti-forgetting penalty during finetuning
    ewc_lambda: float = 0.0
    ewc_samples: int = 2000  # samples used to estimate Fisher Information

    def __post_init__(self):
        if self.curriculum_stages is None:
            self.curriculum_stages = [256, 512, 1024]


# ============================================================================
# ROTARY POSITIONAL EMBEDDINGS (RoPE)
# ============================================================================

class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings with a lazily-grown cos/sin cache.

    The cache holds ``cos``/``sin`` tables of shape (seq_len, dim), built from
    ``cat(freqs, freqs)`` so they line up with :func:`rotate_half`.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len-1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Non-persistent: recomputed from inv_freq, never saved in state_dict.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        ``x`` is only used for its dtype. The cache grows if ``seq_len`` exceeds it.
        """
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dim halves ``[x1, x2]`` to ``[-x2, x1]``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply RoPE to queries and keys.

    ``cos``/``sin`` come from :class:`RotaryEmbedding` with shape (seq, head_dim) and
    broadcast over the leading batch/head dimensions of ``q`` and ``k``.
    """
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# ============================================================================
# RMS NORMALIZATION
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square layer norm over the last dimension, computed in float32."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Upcast so the mean of squares is stable under fp16.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# ============================================================================
# GROUPED-QUERY ATTENTION
# ============================================================================

class GroupedQueryAttention(nn.Module):
    """Multi-head attention with shared key/value heads (GQA; MQA when 1 KV head).

    Each KV head serves ``num_attention_heads // num_key_value_heads`` query heads.
    RoPE is applied to queries and keys before the dot product.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over ``hidden_states`` of shape (bsz, q_len, hidden_size).

        ``attention_mask``, if given, is *additive*: it is added to the raw scores of
        shape (bsz, num_heads, q_len, q_len) before the softmax, so it must broadcast
        to that shape.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, seq_len=q_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Expand shared KV heads so each query head has a matching key/value head.
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax in float32 for stability under mixed precision.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output


# ============================================================================
# SWIGLU FEEDFORWARD
============================================================================ class SwiGLUMLP(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) # ============================================================================ # DECODER LAYER # ============================================================================ class DecoderLayer(nn.Module): def __init__(self, config: ModelConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.self_attn = GroupedQueryAttention(config) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = SwiGLUMLP(config) self.residual_dropout = nn.Dropout(config.residual_dropout) def forward(self, hidden_states, attention_mask=None): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask) hidden_states = self.residual_dropout(hidden_states) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = self.residual_dropout(hidden_states) hidden_states = residual + hidden_states return hidden_states # ============================================================================ # MAIN MODEL # ============================================================================ class IndonesianLLM(nn.Module): def __init__(self, config: ModelConfig): super().__init__() 
self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx) self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.tie_word_embeddings: self.lm_head = None else: self.lm_head = nn.Linear(config.vocab_size, config.vocab_size, bias=False) self.apply(self._init_weights) def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def get_input_embeddings(self): return self.embed_tokens def _prepare_attention_mask(self, attention_mask, input_shape, dtype): batch_size, seq_length = input_shape causal_mask = torch.triu( torch.ones((seq_length, seq_length), dtype=torch.bool, device=attention_mask.device), diagonal=1 ) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_length, seq_length) if attention_mask is not None: expanded_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length) expanded_mask = expanded_mask.bool() causal_mask = causal_mask | ~expanded_mask causal_mask = torch.where(causal_mask, torch.finfo(dtype).min, 0.0) return causal_mask def forward(self, input_ids, attention_mask=None, labels=None): batch_size, seq_length = input_ids.shape if attention_mask is None: attention_mask = torch.ones_like(input_ids) hidden_states = self.embed_tokens(input_ids) attention_mask = self._prepare_attention_mask( attention_mask, (batch_size, seq_length), hidden_states.dtype ) for decoder_layer in self.layers: hidden_states = decoder_layer(hidden_states, 
attention_mask=attention_mask) hidden_states = self.norm(hidden_states) if self.lm_head is not None: logits = self.lm_head(hidden_states) else: logits = F.linear(hidden_states, self.embed_tokens.weight) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # FIX: use -100 as ignore_index (standard PyTorch convention) # prompt tokens are masked to -100 in __getitem__, padding also -100 loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1)) return {"loss": loss, "logits": logits} def count_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) # ============================================================================ # DATASET # ============================================================================ class IndonesianCoTDataset(Dataset): def __init__( self, file_path: str, tokenizer, max_length: int = 1024, cot_token: str = "", end_cot_token: str = "", use_cot: bool = True, cot_ratio: float = 0.7 ): self.tokenizer = tokenizer self.max_length = max_length self.cot_token = cot_token self.end_cot_token = end_cot_token self.use_cot = use_cot self.cot_ratio = cot_ratio self.samples = [] self.skipped_count = 0 self._load_data(file_path) def _load_data(self, file_path: str): print(f"Loading dataset from {file_path}...") with open(file_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): try: if not line.strip(): continue data = json.loads(line) if not all(key in data for key in ['input', 'cot', 'output']): self.skipped_count += 1 print(f"Warning: Line {line_num} missing required fields, skipping...") continue if not all(isinstance(data[key], str) for key in ['input', 'cot', 'output']): self.skipped_count += 1 print(f"Warning: Line {line_num} has invalid data types, skipping...") continue if not all(data[key].strip() for key in ['input', 'cot', 'output']): 
self.skipped_count += 1 print(f"Warning: Line {line_num} has empty fields, skipping...") continue self.samples.append(data) except json.JSONDecodeError as e: self.skipped_count += 1 print(f"Warning: Line {line_num} is not valid JSON ({e}), skipping...") continue except Exception as e: self.skipped_count += 1 print(f"Warning: Line {line_num} caused error ({e}), skipping...") continue print(f"Loaded {len(self.samples)} valid samples") print(f"Skipped {self.skipped_count} malformed rows") if len(self.samples) == 0: raise ValueError("No valid samples loaded from dataset!") def __len__(self): return len(self.samples) def __getitem__(self, idx): sample = self.samples[idx] # Build prompt (what the model sees as input) and completion (what it must learn to generate) if self.use_cot: if random.random() < self.cot_ratio: prompt = f"{sample['input']} " completion = f" {sample['cot']} {self.end_cot_token} {sample['output']}" else: prompt = f"{sample['input']}" completion = f" {sample['output']}" else: prompt = f"{sample['input']}" completion = f" {sample['output']}" # Encode prompt alone to know exactly where it ends in token space prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=True) prompt_len = len(prompt_ids) # Encode full sequence with truncation full_ids = self.tokenizer.encode( prompt + completion, max_length=self.max_length, truncation=True, padding=False, return_tensors=None, add_special_tokens=True, ) # FIX: mask prompt tokens with -100 so CrossEntropyLoss ignores them. # Only the completion (CoT reasoning + answer) contributes to the gradient. 
labels = [-100] * prompt_len + full_ids[prompt_len:] # Guard: if truncation cut into the prompt, clamp to full_ids length labels = labels[:len(full_ids)] return { 'input_ids': torch.tensor(full_ids, dtype=torch.long), 'labels': torch.tensor(labels, dtype=torch.long), 'length': len(full_ids), 'cot_length': len(self.tokenizer.encode( sample['cot'], add_special_tokens=False )) if self.use_cot else 0, } def collate_fn_with_packing(batch, pad_token_id=0): # FIX: use pre-masked labels from dataset instead of copying input_ids. # Padding positions use -100 so they are also ignored by the loss. batch = sorted(batch, key=lambda x: x['length'], reverse=True) max_length = max(item['length'] for item in batch) input_ids_batch = [] attention_mask_batch = [] labels_batch = [] for item in batch: input_ids = item['input_ids'] labels = item['labels'] length = item['length'] pad_len = max_length - length input_ids_padded = F.pad(input_ids, (0, pad_len), value=pad_token_id) labels_padded = F.pad(labels, (0, pad_len), value=-100) attention_mask = torch.cat([ torch.ones(length, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long) ]) input_ids_batch.append(input_ids_padded) attention_mask_batch.append(attention_mask) labels_batch.append(labels_padded) return { 'input_ids': torch.stack(input_ids_batch), 'attention_mask': torch.stack(attention_mask_batch), 'labels': torch.stack(labels_batch), } # ============================================================================ # CURRICULUM LEARNING # ============================================================================ def create_curriculum_datasets(dataset, stages=[256, 512, 1024], use_simple=False, skip_stages=0): """ Build per-stage datasets. FIX (simple curriculum): skipped stages now use `continue` so they are never built — no tokenizer.encode() calls wasted on stages that get thrown away. The slice at the bottom is moved INSIDE the else-branch only, so the simple-curriculum list is never accidentally emptied. 
""" datasets = [] if use_simple: for i, max_len in enumerate(stages): # FIX: skip immediately — no encoding, no dataset build if i < skip_stages: print(f"[SKIP] Curriculum stage {max_len}: skipped (not built)") continue filtered_samples = [ s for s in dataset.samples if len(dataset.tokenizer.encode( f"{s['input']} {dataset.cot_token} {s['cot']} {dataset.end_cot_token} {s['output']}" )) <= max_len ] stage_dataset = _build_stage_dataset(dataset, filtered_samples, max_len, dataset.cot_ratio) datasets.append(stage_dataset) print(f"Curriculum stage {max_len}: {len(filtered_samples)} samples") # NOTE: no slice here — skipped stages were never added to datasets else: print("\n" + "="*80) print("3-STAGE REASONING CURRICULUM") if skip_stages > 0: print(f" (Skipping first {skip_stages} stage(s) — continue-train mode)") print("="*80) stage_configs = [ { 'name': 'Stage 1: Basic Q&A (short, no reasoning)', 'max_len': 384, 'cot_ratio': 0.0, 'filter': lambda s: len(dataset.tokenizer.encode( f"{s['input']} {s['output']}" )) <= 384 }, { 'name': 'Stage 2: Learning Reasoning (medium, 50% CoT)', 'max_len': 512, 'cot_ratio': 0.5, 'filter': lambda s: True }, { 'name': 'Stage 3: Full Reasoning (all, 100% CoT)', 'max_len': 1024, 'cot_ratio': 1.0, 'filter': lambda s: True } ] for idx, stage_config in enumerate(stage_configs): filtered_samples = [s for s in dataset.samples if stage_config['filter'](s)] stage_dataset = _build_stage_dataset( dataset, filtered_samples, stage_config['max_len'], stage_config['cot_ratio'] ) datasets.append(stage_dataset) skipped = idx < skip_stages prefix = " [SKIP] " if skipped else " " print(f"{prefix}{stage_config['name']}") print(f" {'(skipped)' if skipped else ''} Samples: {len(filtered_samples)}") print(f" {'(skipped)' if skipped else ''} Max length: {stage_config['max_len']}") print(f" {'(skipped)' if skipped else ''} CoT ratio: {stage_config['cot_ratio']*100:.0f}%") print("="*80 + "\n") # FIX: slice lives inside else-branch only, safe for 3-stage mode if 
skip_stages > 0: datasets = datasets[skip_stages:] return datasets def _build_stage_dataset(base_dataset, samples, max_len, cot_ratio): """Helper: create a shallow-copy stage dataset from a list of samples.""" stage = IndonesianCoTDataset.__new__(IndonesianCoTDataset) stage.tokenizer = base_dataset.tokenizer stage.max_length = max_len stage.cot_token = base_dataset.cot_token stage.end_cot_token = base_dataset.end_cot_token stage.samples = samples stage.skipped_count = 0 stage.use_cot = base_dataset.use_cot stage.cot_ratio = cot_ratio return stage # ============================================================================ # LEARNING-RATE SCHEDULERS # ============================================================================ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps): def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps): def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def get_continue_schedule(optimizer, num_training_steps: int, min_fraction: float = 0.1): """ Schedule for --continue-train without saved optimizer state. Starts at target LR immediately and decays gently via cosine to min_fraction x LR by the end of training. 
""" def lr_lambda(step): progress = float(step) / float(max(1, num_training_steps)) cosine_val = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))) return min_fraction + (1.0 - min_fraction) * cosine_val return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) class PlateauLRGuard: """ Wraps any LambdaLR scheduler and applies an extra multiplicative penalty when perplexity has not improved for `patience` consecutive checks. """ def __init__(self, scheduler, patience=3, factor=0.5, min_delta=0.02): self.scheduler = scheduler self.patience = patience self.factor = factor self.min_delta = min_delta self._best = float('inf') self._no_improve = 0 self._penalty = 1.0 def step(self, perplexity: float): relative_improvement = (self._best - perplexity) / max(self._best, 1e-8) if relative_improvement > self.min_delta: self._best = perplexity self._no_improve = 0 else: self._no_improve += 1 if self._no_improve >= self.patience: self._penalty *= self.factor self._no_improve = 0 new_lr = self.scheduler.get_last_lr()[0] * self.factor print(f"\n[PlateauLRGuard] No improvement for {self.patience} checks. " f"Reducing LR by {self.factor:.2f}x -> {new_lr:.2e}") for pg in self.scheduler.optimizer.param_groups: pg['lr'] *= self.factor return True return False def get_penalty(self): return self._penalty # ============================================================================ # TRAINING UTILITIES # ============================================================================ def set_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # ============================================================================ # ELASTIC WEIGHT CONSOLIDATION (EWC) # ============================================================================ class EWC: """ Elastic Weight Consolidation — prevents catastrophic forgetting during finetuning. 
""" def __init__(self, model, dataloader, device, n_samples: int = 2000): self.device = device self.n_samples = n_samples self.params = { n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad } self.fisher = self._compute_fisher(model, dataloader) def _compute_fisher(self, model, dataloader): fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad} model.eval() seen = 0 for batch in dataloader: if seen >= self.n_samples: break input_ids = batch["input_ids"].to(self.device) attention_mask = batch["attention_mask"].to(self.device) labels = batch["labels"].to(self.device) model.zero_grad() outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) outputs["loss"].backward() for n, p in model.named_parameters(): if p.requires_grad and p.grad is not None: fisher[n] += p.grad.detach().pow(2) seen += input_ids.size(0) for n in fisher: fisher[n] /= max(1, seen) model.train() return fisher def penalty(self, model) -> torch.Tensor: loss = torch.tensor(0.0, device=self.device) for n, p in model.named_parameters(): if p.requires_grad and n in self.fisher: loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum() return loss * 0.5 # ============================================================================ # TRAINING LOOP # ============================================================================ def train_model( model: IndonesianLLM, train_dataset: IndonesianCoTDataset, config: TrainingConfig, device: torch.device, use_simple_curriculum: bool = False, is_continue: bool = False, skip_curriculum_stages: int = 0, ewc: "EWC | None" = None, ): """Main training loop.""" print("\n" + "="*80) print("TRAINING CONFIGURATION" + (" [CONTINUE MODE]" if is_continue else "")) print("="*80) print(f"Model parameters: {model.count_parameters():,}") print(f"Dataset size: {len(train_dataset)}") print(f"Batch size: {config.batch_size}") print(f"Gradient accumulation steps: 
{config.gradient_accumulation_steps}") print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}") print(f"Learning rate: {config.learning_rate}") print(f"Max sequence length: {config.max_seq_length}") print(f"Number of epochs: {config.num_epochs}") print(f"Mixed precision: {config.use_fp16}") if is_continue: print(f"Skipping curriculum stages: {skip_curriculum_stages}") print(f"Plateau patience: {config.plateau_patience}") print(f"Plateau LR factor: {config.plateau_factor}") if ewc is not None: print(f"EWC lambda: {config.ewc_lambda} (anti-forgetting active)") print("="*80 + "\n") model.to(device) model.train() curriculum_datasets = create_curriculum_datasets( train_dataset, config.curriculum_stages, use_simple=use_simple_curriculum, skip_stages=skip_curriculum_stages, ) if len(curriculum_datasets) == 0: print("ERROR: No curriculum stages to train on. Check --skip-stages value.") return model optimizer = torch.optim.AdamW( model.parameters(), lr=config.learning_rate, betas=(config.adam_beta1, config.adam_beta2), eps=config.adam_epsilon, weight_decay=config.weight_decay ) # Calculate total steps total_steps = 0 for ds in curriculum_datasets: steps_per_epoch = max(1, len(ds) // (config.batch_size * config.gradient_accumulation_steps)) total_steps += steps_per_epoch * config.num_epochs if total_steps == 0: total_steps = 1 # LR Scheduler if is_continue: scheduler = get_continue_schedule(optimizer, num_training_steps=total_steps) plateau_guard = PlateauLRGuard( scheduler, patience=config.plateau_patience, factor=config.plateau_factor, min_delta=config.plateau_min_delta, ) print(f"[Scheduler] Continue-train: flat cosine decay from {config.learning_rate:.2e}") else: if config.lr_scheduler_type == "cosine": scheduler = get_cosine_schedule_with_warmup( optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=total_steps ) else: scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=config.warmup_steps, 
num_training_steps=total_steps ) plateau_guard = None scaler = torch.cuda.amp.GradScaler() if config.use_fp16 and torch.cuda.is_available() else None global_step = 0 perplexity_history = [] for stage_idx, stage_dataset in enumerate(curriculum_datasets): actual_stage = stage_idx + skip_curriculum_stages print(f"\n{'='*80}") print(f"CURRICULUM STAGE {actual_stage + 1}/{len(curriculum_datasets) + skip_curriculum_stages} " f"(running {stage_idx + 1} of {len(curriculum_datasets)})") print(f"Max sequence length: {stage_dataset.max_length}") print(f"Samples: {len(stage_dataset)}") if hasattr(stage_dataset, 'cot_ratio'): print(f"CoT ratio: {stage_dataset.cot_ratio:.0%}") print(f"{'='*80}\n") dataloader = DataLoader( stage_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx), num_workers=0, pin_memory=True if torch.cuda.is_available() else False ) for epoch in range(config.num_epochs): print(f"\nEpoch {epoch + 1}/{config.num_epochs}") epoch_loss = 0.0 optimizer.zero_grad() for step, batch in enumerate(dataloader): input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) labels = batch['labels'].to(device) if scaler is not None: with torch.cuda.amp.autocast(): outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) task_loss = outputs['loss'] if ewc is not None: task_loss = task_loss + config.ewc_lambda * ewc.penalty(model) loss = task_loss / config.gradient_accumulation_steps scaler.scale(loss).backward() else: outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) task_loss = outputs['loss'] if ewc is not None: task_loss = task_loss + config.ewc_lambda * ewc.penalty(model) loss = task_loss / config.gradient_accumulation_steps loss.backward() epoch_loss += loss.item() * config.gradient_accumulation_steps if (step + 1) % config.gradient_accumulation_steps == 0: if scaler is not None: 
scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) if scaler is not None: scaler.step(optimizer) scaler.update() else: optimizer.step() scheduler.step() optimizer.zero_grad() global_step += 1 if global_step % config.logging_steps == 0: avg_loss = epoch_loss / (step + 1) current_lr = scheduler.get_last_lr()[0] print(f"Step {global_step:>6} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}") avg_epoch_loss = epoch_loss / max(1, len(dataloader)) perplexity = math.exp(min(avg_epoch_loss, 20)) perplexity_history.append(perplexity) print(f"Epoch {epoch + 1} completed | Avg Loss: {avg_epoch_loss:.4f} " f"| Perplexity: {perplexity:.2f}") if plateau_guard is not None: plateau_guard.step(perplexity) print("\n" + "="*80) print("TRAINING COMPLETED") if perplexity_history: print(f"Final perplexity: {perplexity_history[-1]:.2f}") print(f"Best perplexity: {min(perplexity_history):.2f}") print("="*80 + "\n") return model # ============================================================================ # EVALUATION # ============================================================================ def evaluate_model(model, dataset, device, batch_size=4): model.eval() dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx), num_workers=0 ) total_loss = 0.0 total_samples = 0 with torch.no_grad(): for batch in dataloader: input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) labels = batch['labels'].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) total_loss += outputs['loss'].item() * input_ids.size(0) total_samples += input_ids.size(0) avg_loss = total_loss / max(1, total_samples) perplexity = math.exp(min(avg_loss, 20)) print(f"\nEvaluation Results:") print(f"Average Loss: {avg_loss:.4f}") print(f"Perplexity: {perplexity:.2f}") if perplexity < 5.0: print("Status: 
Excellent") elif perplexity < 10.0: print("Status: Good") elif perplexity < 20.0: print("Status: Fair — try more epochs or lower LR") else: print("Status: Poor — check data quality or model config") if avg_loss < 0.5: print("Warning: Very low loss might indicate overfitting") return {"loss": avg_loss, "perplexity": perplexity} # ============================================================================ # GENERATION # ============================================================================ def generate_text( model, tokenizer, prompt: str, max_new_tokens: int = 256, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.9, device: torch.device = torch.device('cpu') ): model.eval() input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device) generated_ids = input_ids.clone() eos_token_id = tokenizer.eos_token_id if eos_token_id is None: eos_token_id = tokenizer.sep_token_id if eos_token_id is None: eos_token_id = 2 stop_tokens = {eos_token_id, tokenizer.pad_token_id} if tokenizer.sep_token_id is not None: stop_tokens.add(tokenizer.sep_token_id) repetition_buffer = [] with torch.no_grad(): for step in range(max_new_tokens): outputs = model(input_ids=generated_ids) logits = outputs['logits'] next_token_logits = logits[:, -1, :] / max(temperature, 0.1) if len(repetition_buffer) > 10: for token in set(repetition_buffer[-10:]): if token in stop_tokens: continue next_token_logits[0, token] -= 2.0 if top_k > 0: indices_to_remove = next_token_logits < torch.topk( next_token_logits, min(top_k, next_token_logits.size(-1)) )[0][..., -1, None] next_token_logits[indices_to_remove] = float('-inf') if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = 
sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove) next_token_logits[indices_to_remove] = float('-inf') probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) if next_token.item() in stop_tokens: break repetition_buffer.append(next_token.item()) if len(repetition_buffer) > 20: repetition_buffer.pop(0) generated_ids = torch.cat([generated_ids, next_token], dim=1) max_ctx = model.config.max_position_embeddings if generated_ids.size(1) > max_ctx: generated_ids = generated_ids[:, -max_ctx:] if step > 10: decoded = tokenizer.decode( generated_ids[0][input_ids.size(1):], skip_special_tokens=False) if '\n\n' in decoded or 'User:' in decoded or 'Assistant:' in decoded[-20:]: break return tokenizer.decode(generated_ids[0], skip_special_tokens=False) # ============================================================================ # INTERACTIVE CHAT # ============================================================================ def _clean_response(response: str) -> str: import re if "" in response and "" in response: response = response.split("", 1)[-1] elif "" in response: response = response.split("", 1)[0] response = re.sub(r'\[\w+\]', '', response) response = re.sub(r'<[^>]+>', '', response) for marker in [ "user :", "user:", "User :", "User:", "assistant :", "assistant:", "Assistant :", "Assistant:", "memahami permintaan", "jawaban singkat", "penjelasan harus", "\n\n", ]: if marker in response: response = response.split(marker)[0] response = re.sub(r'^[\s:!,\.\-|\[\]]+', '', response) response = re.sub(r' {2,}', ' ', response).strip() return response def _extract_thinking(raw: str) -> tuple: import re raw = re.sub(r'\[\w+\]', '', raw) if "" in raw: thinking_raw, answer_raw = raw.split("", 1) else: thinking_raw, answer_raw = raw, "" thinking = thinking_raw.strip() for marker in ["user :", "user:", "memahami permintaan", "\n\n"]: if marker in thinking: thinking = thinking.split(marker)[0] thinking = 
re.sub(r'<[^>]+>', '', thinking).strip() answer = _clean_response(answer_raw) return thinking, answer def interactive_chat(model, tokenizer, device, system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia."): print("\n" + "="*80) print("INDONESIAN LLM — INTERACTIVE CHAT") print("="*80) print("Commands: 'exit'/'quit' | 'clear' | 'think' (toggle reasoning display)") print(f"Persona : {system_prompt}") print("="*80 + "\n") model.eval() show_thinking = False import time torch.manual_seed(int(time.time()) % 100000) if torch.cuda.is_available(): torch.cuda.manual_seed_all(int(time.time()) % 100000) while True: try: user_input = input("\nYou: ").strip() if not user_input: continue if user_input.lower() in ['exit', 'quit', 'keluar']: print("\nGoodbye!") break if user_input.lower() in ['clear', 'bersihkan']: print("\nConversation cleared") continue if user_input.lower() == 'think': show_thinking = not show_thinking print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}") continue prompt = f"{user_input} " max_tokens = 250 print("\nA:", end=" ", flush=True) full_response = generate_text( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=max_tokens, temperature=0.9, top_k=50, top_p=0.92, device=device ) response = full_response[len(prompt):].strip() thinking, answer = _extract_thinking(response) if show_thinking and thinking: print(f"[Thinking: {thinking}]") final = answer if answer else _clean_response(response) if not final or len(final) < 3: final = "Maaf, saya tidak mengerti. Bisa diulang?" 
print(final) except KeyboardInterrupt: print("\n\nChat interrupted") break except Exception as e: print(f"\nError: {e}") # ============================================================================ # BENCHMARK # ============================================================================ def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 20, verbose: bool = True): import time if dataset_path is None or not os.path.exists(dataset_path): print(f"Dataset not found at: {dataset_path}") print("Pass --dataset path/to/your.jsonl to benchmark against your data.") return all_samples = [] with open(dataset_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: continue try: d = json.loads(line) if all(k in d for k in ['input', 'output']) and d['input'].strip() and d['output'].strip(): all_samples.append(d) except Exception: continue if not all_samples: print("No valid samples found in dataset.") return random.seed(int(time.time())) samples = random.sample(all_samples, min(n, len(all_samples))) model.eval() torch.manual_seed(42) if torch.cuda.is_available(): torch.cuda.manual_seed_all(42) print("\n" + "="*80) print(f"BENCHMARK ({len(samples)} random samples from dataset)") print("="*80) results = [] ppl_loss = 0.0 ppl_toks = 0 for i, sample in enumerate(samples): inp = sample['input'].strip() expected = sample['output'].strip().lower() prompt = f"{inp} " full = generate_text( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=150, temperature=0.3, top_k=20, top_p=0.9, device=device ) raw = full[len(prompt):].strip() _, answer = _extract_thinking(raw) answer_lower = answer.lower() passed = expected in answer_lower if not passed: exp_tokens = set(expected.split()) ans_tokens = set(answer_lower.split()) if exp_tokens: overlap = len(exp_tokens & ans_tokens) / len(exp_tokens) passed = overlap >= 0.5 results.append(passed) with torch.no_grad(): ids = tokenizer.encode( f"{inp} {sample['output']}", 
return_tensors="pt" ).to(device) if ids.size(1) >= 2: out = model(input_ids=ids, labels=ids) toks = ids.size(1) - 1 ppl_loss += out["loss"].item() * toks ppl_toks += toks if verbose: status = "PASS" if passed else "FAIL" print(f" [{status}] {inp[:60]}") print(f" Expected : {sample['output'][:80]}") print(f" Got : {answer[:80] if answer else '(no answer)'}") total_pass = sum(results) total = len(results) overall = total_pass / total * 100 bar = "█" * int(overall / 10) + "░" * (10 - int(overall / 10)) ppl = math.exp(min(ppl_loss / max(1, ppl_toks), 20)) print("\n" + "-"*80) print(f" SCORE {total_pass}/{total} ({overall:.1f}%) {bar}") print(f" PERPLEXITY {ppl:.2f} (lower = better)") print("="*80 + "\n") return {"score": overall, "pass": total_pass, "total": total, "perplexity": ppl} # ============================================================================ # MODEL SAVING AND LOADING # ============================================================================ def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, save_path: str, use_fp16: bool = True): os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True) state_dict = model.state_dict() if use_fp16: state_dict = {k: v.half() if v.dtype == torch.float32 else v for k, v in state_dict.items()} torch.save({ 'model_state_dict': state_dict, 'config': config, 'tokenizer_name': tokenizer_name, 'model_params': model.count_parameters(), 'dtype': 'fp16' if use_fp16 else 'fp32', }, save_path) size_mb = os.path.getsize(save_path) / 1e6 print(f"\nModel saved to: {save_path} ({'fp16' if use_fp16 else 'fp32'}, {size_mb:.1f} MB)") print(f"Parameters: {model.count_parameters():,}") def load_model(load_path: str, device: torch.device): if not os.path.exists(load_path): raise FileNotFoundError(f"Model checkpoint not found: {load_path}") print(f"Loading model from: {load_path}") checkpoint = torch.load(load_path, map_location=device, weights_only=False) config = 
checkpoint['config'] tokenizer_name = checkpoint['tokenizer_name'] dtype = checkpoint.get('dtype', 'fp32') tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) tokenizer.add_special_tokens({"additional_special_tokens": ["", ""]}) model = IndonesianLLM(config) state_dict = checkpoint['model_state_dict'] if dtype == 'fp16': state_dict = {k: v.float() if v.dtype == torch.float16 else v for k, v in state_dict.items()} model.load_state_dict(state_dict) model.to(device) extra = {'training_metadata': checkpoint.get('training_metadata', None)} size_mb = os.path.getsize(load_path) / 1e6 print(f"Model loaded ({dtype}, {size_mb:.1f} MB) | " f"Parameters: {checkpoint.get('model_params', model.count_parameters()):,}") return model, tokenizer, config, extra # ============================================================================ # MAIN # ============================================================================ def main(): parser = argparse.ArgumentParser( description="Indonesian Conversational LLM — Train, Chat, Finetune, or Continue") parser.add_argument('--train', action='store_true', help='Train model from scratch') parser.add_argument('--chat', action='store_true', help='Interactive chat mode') parser.add_argument('--finetune', action='store_true', help='Fine-tune on NEW data (lr/10)') parser.add_argument('--continue-train', action='store_true', help='Continue training on SAME data with proper LR re-warmup') parser.add_argument('--inspect-data', action='store_true', help='Inspect dataset quality') parser.add_argument('--benchmark', action='store_true', help='Run benchmark suite on a saved model') parser.add_argument('--no-eval', action='store_true', help='Skip evaluation after training') parser.add_argument('--grad-accum', type=int, default=None, help='Override gradient accumulation steps') parser.add_argument('--ewc-lambda', type=float, default=5000.0, help='EWC penalty strength for --finetune (default 5000). 
0 = disabled.') parser.add_argument('--ewc-samples', type=int, default=2000, help='Samples used to estimate Fisher Information for EWC (default 2000).') parser.add_argument('--no-ewc', action='store_true', help='Disable EWC during finetuning.') parser.add_argument('--dataset', type=str, default='indonesian_cot_dataset.jsonl') parser.add_argument('--model', type=str, default='indonesian_llm_model.pt') parser.add_argument('--epochs', type=int, default=5) parser.add_argument('--batch-size', type=int, default=4) parser.add_argument('--lr', type=float, default=2e-4) parser.add_argument('--max-length', type=int, default=512, help='Max sequence length. Dataset max is 448 so 512 is optimal.') parser.add_argument('--hidden-size', type=int, default=320) parser.add_argument('--num-layers', type=int, default=16) parser.add_argument('--num-heads', type=int, default=8) parser.add_argument('--num-kv-heads', type=int, default=2) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--save-fp16', action='store_true', default=True) parser.add_argument('--save-fp32', action='store_true') parser.add_argument('--no-cot', action='store_true') parser.add_argument('--use-cot', action='store_true', default=True) parser.add_argument('--simple-curriculum', action='store_true') parser.add_argument('--cot-ratio', type=float, default=1.0) parser.add_argument('--system-prompt', type=str, default='Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.') parser.add_argument('--rewarm-steps', type=int, default=150) parser.add_argument('--skip-stages', type=int, default=2, help='Curriculum stages to skip in continue-train (default 2).') parser.add_argument('--plateau-patience', type=int, default=3) parser.add_argument('--no-restore-optimizer', action='store_true') args = parser.parse_args() if not any([args.train, args.chat, args.finetune, args.continue_train, args.inspect_data, args.benchmark]): parser.print_help() print("\nError: Specify a mode: 
--train, --chat, --finetune, --continue-train, " "--inspect-data, or --benchmark") return save_fp16 = not args.save_fp32 use_cot_training = not args.no_cot set_seed(args.seed) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"\nDevice: {device}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n") # ------------------------------------------------------------------ # INSPECT DATA # ------------------------------------------------------------------ if args.inspect_data: print("\nInspecting dataset...") print("="*80) tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased") tokenizer.add_special_tokens({"additional_special_tokens": ["", ""]}) dataset = IndonesianCoTDataset( file_path=args.dataset, tokenizer=tokenizer, max_length=args.max_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio ) print(f"\nDataset Statistics:") print(f"Total samples: {len(dataset)}") print(f"Skipped samples: {dataset.skipped_count}") lengths = [] cot_lengths = [] for i in range(min(len(dataset), 1000)): sample = dataset[i] lengths.append(sample['length']) cot_lengths.append(sample['cot_length']) print(f"\nSequence Length Stats:") print(f" Min: {min(lengths)}") print(f" Max: {max(lengths)}") print(f" Avg: {sum(lengths)/len(lengths):.1f}") print(f" Median: {sorted(lengths)[len(lengths)//2]}") print(f"\nCoT Length Stats:") print(f" Min: {min(cot_lengths)}") print(f" Max: {max(cot_lengths)}") print(f" Avg: {sum(cot_lengths)/len(cot_lengths):.1f}") long_cot = sum(1 for x in cot_lengths if x > 50) print(f" Samples with long CoT (>50 tokens): {long_cot} ({long_cot/len(cot_lengths)*100:.1f}%)") print(f"\n{'='*80}") print("Sample Examples (first 5):") print("="*80) for i in range(min(5, len(dataset.samples))): s = dataset.samples[i] print(f"\n--- Sample {i+1} ---") print(f"Input: {s['input'][:100]}...") print(f"CoT: {s['cot'][:150]}...") 
print(f"Output: {s['output'][:100]}...") print("\n" + "="*80) print("Dataset Quality Checks:") issues = [] for i, sample in enumerate(dataset.samples[:100]): if len(sample['input']) < 10: issues.append(f"Sample {i}: Input too short") if len(sample['output']) < 10: issues.append(f"Sample {i}: Output too short") if len(sample['cot']) < 20: issues.append(f"Sample {i}: CoT too short") if sample['input'].lower() == sample['output'].lower(): issues.append(f"Sample {i}: Input == Output (copy)") if issues: print(f"\nFound {len(issues)} potential issues in first 100 samples:") for issue in issues[:10]: print(f" - {issue}") if len(issues) > 10: print(f" ... and {len(issues)-10} more") else: print("\nNo obvious issues detected in first 100 samples") print("\n" + "="*80) return # ------------------------------------------------------------------ # CHAT # ------------------------------------------------------------------ if args.chat: print("\nStarting CHAT mode...") if not os.path.exists(args.model): print(f"Error: Model checkpoint not found: {args.model}") return model, tokenizer, config, _ = load_model(args.model, device) interactive_chat(model, tokenizer, device, system_prompt=args.system_prompt) return # ------------------------------------------------------------------ # BENCHMARK # ------------------------------------------------------------------ if args.benchmark: print("\nRunning benchmark...") if not os.path.exists(args.model): print(f"Error: Model not found: {args.model}") return model, tokenizer, config, _ = load_model(args.model, device) run_benchmark(model, tokenizer, device, dataset_path=args.dataset) return # ------------------------------------------------------------------ # TRAIN FROM SCRATCH # ------------------------------------------------------------------ if args.train: print("\nStarting TRAINING mode (from scratch)...") tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased") tokenizer.add_special_tokens({"additional_special_tokens": 
["", ""]}) model_config = ModelConfig( vocab_size=len(tokenizer), hidden_size=args.hidden_size, num_layers=args.num_layers, num_attention_heads=args.num_heads, num_key_value_heads=args.num_kv_heads, intermediate_size=args.hidden_size * 3, max_position_embeddings=2048, attention_dropout=0.1, residual_dropout=0.1, tie_word_embeddings=True ) model = IndonesianLLM(model_config) print(f"Model parameters: {model.count_parameters():,}") _ga = args.grad_accum if args.grad_accum else 32 train_config = TrainingConfig( dataset_path=args.dataset, num_epochs=args.epochs, batch_size=args.batch_size, gradient_accumulation_steps=_ga, max_seq_length=args.max_length, learning_rate=args.lr, warmup_steps=500, use_fp16=torch.cuda.is_available(), curriculum_stages=[128, 256, args.max_length] ) dataset = IndonesianCoTDataset( file_path=train_config.dataset_path, tokenizer=tokenizer, max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio ) model = train_model( model, dataset, train_config, device, use_simple_curriculum=args.simple_curriculum ) if not args.no_eval: evaluate_model(model, dataset, device) save_model(model, model_config, "indolem/indobert-base-uncased", args.model, use_fp16=save_fp16) test_prompts = [ "Berapa hasil dari 1+1?", "Jelaskan cara kerja komputer", "Bagaimana cara membuat kopi yang enak?" 
] print("\n" + "="*80) print("GENERATION TEST") print("="*80 + "\n") for prompt in test_prompts: print(f"Prompt: {prompt}") generated = generate_text(model, tokenizer, prompt, max_new_tokens=150, device=device) print(f"Generated: {generated}\n") print("-" * 80 + "\n") print(f"\nTo chat: python {__file__} --chat --model {args.model}") # ------------------------------------------------------------------ # FINETUNE (new data, lr/10) # ------------------------------------------------------------------ if args.finetune: print("\nStarting FINETUNE mode (for NEW data)...") if not os.path.exists(args.model): print(f"Error: Model checkpoint not found: {args.model}") return model, tokenizer, model_config, extra = load_model(args.model, device) _ga = args.grad_accum if args.grad_accum else 32 train_config = TrainingConfig( dataset_path=args.dataset, num_epochs=args.epochs, batch_size=args.batch_size, gradient_accumulation_steps=_ga, max_seq_length=args.max_length, learning_rate=args.lr / 10, warmup_steps=100, use_fp16=torch.cuda.is_available(), curriculum_stages=[128, 256, args.max_length] ) dataset = IndonesianCoTDataset( file_path=train_config.dataset_path, tokenizer=tokenizer, max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio ) print(f"\nStarting fine-tuning with LR={train_config.learning_rate:.2e}...") ewc_obj = None if not args.no_ewc and args.ewc_lambda > 0: print(f"\nComputing EWC Fisher Information " f"(lambda={args.ewc_lambda}, samples={args.ewc_samples})...") print(" This takes ~1-2 min on T4. 
Prevents forgetting old training data.") old_dataset = IndonesianCoTDataset( file_path=args.dataset, tokenizer=tokenizer, max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio ) old_loader = DataLoader( old_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx), num_workers=0 ) train_config.ewc_lambda = args.ewc_lambda train_config.ewc_samples = args.ewc_samples ewc_obj = EWC(model, old_loader, device, n_samples=args.ewc_samples) print(f" Fisher computed. EWC penalty will be added during training.") else: print(" EWC disabled — model may forget previous training.") model = train_model( model, dataset, train_config, device, use_simple_curriculum=args.simple_curriculum, ewc=ewc_obj, ) if not args.no_eval: evaluate_model(model, dataset, device) finetuned_path = args.model.replace('.pt', '_finetuned.pt') save_model(model, model_config, "indolem/indobert-base-uncased", finetuned_path, use_fp16=save_fp16) print(f"\nFine-tuning completed. Model saved to: {finetuned_path}") print(f"To chat: python {__file__} --chat --model {finetuned_path}") # ------------------------------------------------------------------ # CONTINUE TRAINING # ------------------------------------------------------------------ if args.continue_train: print("\nStarting CONTINUE-TRAIN mode...") print("="*80) print("NOTE: For a 15-30M param model on Indonesian CoT, perplexity 5-15 is normal.") print("Key improvements:") print(" 1. Early curriculum stages skipped → no wasted time on short sequences") print(" 2. Plateau LR reduction → auto-halves LR if perplexity stalls") print(" 3. 
Micro LR on cold Adam → no destructive spike at restart") print("="*80) if not os.path.exists(args.model): print(f"Error: Model checkpoint not found: {args.model}") return model, tokenizer, model_config, extra = load_model(args.model, device) effective_lr = args.lr * 0.05 print(f"\nContinue-train LR: {args.lr:.2e} x 0.05 = {effective_lr:.2e}") print(f" Adam cold-starts — micro LR prevents overshooting the trained minimum.") curriculum_stages = [192, 320, args.max_length] if args.simple_curriculum: effective_skip = len(curriculum_stages) - 1 else: effective_skip = args.skip_stages print(f"Skipping first {effective_skip} curriculum stage(s) — " f"training on full-length data only.") _ga = args.grad_accum if args.grad_accum else 32 train_config = TrainingConfig( dataset_path=args.dataset, num_epochs=args.epochs, batch_size=args.batch_size, gradient_accumulation_steps=_ga, max_seq_length=args.max_length, learning_rate=effective_lr, warmup_steps=0, use_fp16=torch.cuda.is_available(), curriculum_stages=curriculum_stages, skip_curriculum_stages=effective_skip, plateau_patience=2, plateau_factor=0.5, plateau_min_delta=0.02, ) dataset = IndonesianCoTDataset( file_path=train_config.dataset_path, tokenizer=tokenizer, max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio ) model = train_model( model, dataset, train_config, device, use_simple_curriculum=args.simple_curriculum, is_continue=True, skip_curriculum_stages=effective_skip, ) if not args.no_eval: print("\nEvaluating continued model...") evaluate_model(model, dataset, device) save_model(model, model_config, "indolem/indobert-base-uncased", args.model, use_fp16=save_fp16) print(f"\nContinued training completed.") print(f"Model saved to: {args.model}") print(f"To chat: python {__file__} --chat --model {args.model}") if __name__ == "__main__": main()