# NousAPI / best.py
# FaiziRBLX's picture
# Update best.py
# a24cdb2 verified
"""
Production-Grade Indonesian Conversational Language Model
Trained from scratch with Chain-of-Thought reasoning capability
Architecture: Decoder-only transformer with MQA/GQA, RoPE, SwiGLU, RMSNorm
Target: 15M-30M parameters, optimized for Google Colab Free tier
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import json
import math
import random
import numpy as np
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass
import warnings
import argparse
import os
# Silence every Python warning for the whole process (including any
# deprecation notices raised by the libraries imported above). Comment this
# out while debugging so genuine warnings are visible.
warnings.filterwarnings('ignore')
# ============================================================================
# CONFIGURATION
# ============================================================================
@dataclass
class ModelConfig:
    """Model architecture configuration.

    ``__post_init__`` validates the head layout:

    * both head counts are positive,
    * ``hidden_size`` splits evenly into ``num_attention_heads`` heads,
    * query heads split evenly into ``num_key_value_heads`` groups (GQA/MQA),
    * the per-head dimension is even, because the rotary embedding splits each
      head into two halves that are rotated against each other.

    Raises:
        ValueError: if any of the invariants above does not hold. (Explicit
            checks instead of ``assert`` so validation survives ``python -O``.)
    """
    vocab_size: int = 30000
    hidden_size: int = 384
    num_layers: int = 12
    num_attention_heads: int = 6
    num_key_value_heads: int = 2  # GQA: 2 KV heads, MQA: 1 KV head
    intermediate_size: int = 1024
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attention_dropout: float = 0.1
    residual_dropout: float = 0.1
    initializer_range: float = 0.02
    use_cache: bool = False
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) and "
                f"num_key_value_heads ({self.num_key_value_heads}) must be positive"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        head_dim = self.hidden_size // self.num_attention_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head dimension hidden_size // num_attention_heads = {head_dim} "
                "must be even for rotary position embeddings"
            )
@dataclass
class TrainingConfig:
    """Training hyperparameters.

    ``curriculum_stages`` may be left as ``None``; ``__post_init__`` then
    fills in ``[256, 512, 1024]`` (a fresh list per instance).

    Raises:
        ValueError: if ``batch_size`` or ``gradient_accumulation_steps`` is
            smaller than 1 (the training loop divides/modulos by both).
    """
    dataset_path: str = "indonesian_cot_dataset.jsonl"
    output_dir: str = "./indonesian_llm_checkpoints"
    # Training
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 12
    max_seq_length: int = 1024
    # Optimization
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    # Scheduler
    warmup_steps: int = 100
    lr_scheduler_type: str = "cosine"
    # Regularization
    dropout: float = 0.1
    # Mixed precision
    use_fp16: bool = True
    # Reproducibility
    seed: int = 42
    # Logging
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 500
    # Curriculum learning (None -> default stages, see __post_init__)
    curriculum_stages: Optional[List[int]] = None
    # Skip the first N curriculum stages so we don't re-train on tiny seqs.
    skip_curriculum_stages: int = 2
    # Patience (in eval periods) before ReduceLROnPlateau fires.
    plateau_patience: int = 3
    # Factor to multiply LR by when plateau is detected.
    plateau_factor: float = 0.5
    # Minimum improvement in perplexity to count as "not stalled".
    plateau_min_delta: float = 0.02
    # EWC — set > 0 to enable anti-forgetting penalty during finetuning
    ewc_lambda: float = 0.0
    ewc_samples: int = 2000  # samples used to estimate Fisher Information

    def __post_init__(self):
        if self.curriculum_stages is None:
            self.curriculum_stages = [256, 512, 1024]
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.gradient_accumulation_steps < 1:
            raise ValueError(
                "gradient_accumulation_steps must be >= 1, "
                f"got {self.gradient_accumulation_steps}"
            )
# ============================================================================
# ROTARY POSITIONAL EMBEDDINGS (RoPE)
# ============================================================================
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE).

    Keeps cos/sin lookup tables of shape ``(positions, dim)`` as non-persistent
    buffers and hands out the first ``seq_len`` rows, rebuilding the tables
    when a longer sequence than currently cached is requested.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # One inverse frequency per channel pair: base ** (-2i / dim).
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angle block so the table width equals the head dimension.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        """Return ``(cos, sin)`` rows for the first ``seq_len`` positions, cast to ``x.dtype``."""
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        target_dtype = x.dtype
        cos = self.cos_cached[:seq_len].to(dtype=target_dtype)
        sin = self.sin_cached[:seq_len].to(dtype=target_dtype)
        return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``; for an odd size the second part
    is the longer one.
    """
    width = x.shape[-1]
    half = width // 2
    front, back = x.split([half, width - half], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    Each of ``q`` and ``k`` is transformed as ``t * cos + rotate_half(t) * sin``;
    ``cos`` and ``sin`` must broadcast against both.

    Returns:
        tuple: ``(q_rotated, k_rotated)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
# ============================================================================
# RMS NORMALIZATION
# ============================================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    The mean square over the last dimension is computed in float32; the
    normalized tensor is cast back to the input dtype before being multiplied
    by ``weight``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalized = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
# ============================================================================
# GROUPED-QUERY ATTENTION
# ============================================================================
class GroupedQueryAttention(nn.Module):
    """Self-attention with grouped key/value heads (GQA; MQA when one KV head).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads: every K/V head is repeated ``num_attention_heads //
    num_key_value_heads`` times along the head axis before the dot product.
    Rotary embeddings are applied to queries and keys.  ``attention_mask``,
    when given, is ADDED to the raw scores, so it must broadcast against the
    ``(batch, heads, q_len, q_len)`` score tensor.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        assert self.hidden_size % self.num_heads == 0
        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _split_heads(self, projected, head_count):
        """Reshape ``(batch, seq, head_count * head_dim)`` to ``(batch, head_count, seq, head_dim)``."""
        bsz, seq_len, _ = projected.size()
        return projected.view(bsz, seq_len, head_count, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        bsz, q_len, _ = hidden_states.size()
        queries = self._split_heads(self.q_proj(hidden_states), self.num_heads)
        keys = self._split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        values = self._split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        cos, sin = self.rotary_emb(values, seq_len=q_len)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        # Share each K/V head across its group of query heads.
        if self.num_key_value_groups > 1:
            keys = keys.repeat_interleave(self.num_key_value_groups, dim=1)
            values = values.repeat_interleave(self.num_key_value_groups, dim=1)

        scores = (queries @ keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask
        # Softmax in float32 for stability, then back to the activation dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        probs = self.attention_dropout(probs)

        context = (probs @ values).transpose(1, 2).contiguous()
        context = context.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(context)
# ============================================================================
# SWIGLU FEEDFORWARD
# ============================================================================
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free; the inner width is
    ``config.intermediate_size``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        d_model, d_inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
# ============================================================================
# DECODER LAYER
# ============================================================================
class DecoderLayer(nn.Module):
    """Pre-norm transformer block.

    Two residual sub-blocks, each ``x + dropout(sublayer(RMSNorm(x)))``:
    grouped-query self-attention first, then the SwiGLU MLP.
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SwiGLUMLP(config)
        self.residual_dropout = nn.Dropout(config.residual_dropout)

    def forward(self, hidden_states, attention_mask=None):
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), attention_mask=attention_mask
        )
        hidden_states = hidden_states + self.residual_dropout(attn_out)
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + self.residual_dropout(mlp_out)
# ============================================================================
# MAIN MODEL
# ============================================================================
class IndonesianLLM(nn.Module):
    """Decoder-only causal language model.

    Pipeline: token embedding -> ``config.num_layers`` decoder layers -> final
    RMSNorm -> vocabulary logits.  With ``config.tie_word_embeddings`` the
    output projection reuses the embedding matrix; otherwise a separate
    ``lm_head`` (``hidden_size -> vocab_size``) is created.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if config.tie_word_embeddings:
            self.lm_head = None
        else:
            # Input is the normalized hidden state, so in_features must be
            # hidden_size (vocab_size here would fail on every forward pass).
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding; zero biases and the padding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.embed_tokens

    def _prepare_attention_mask(self, attention_mask, input_shape, dtype, device=None):
        """Build an additive causal (+ padding) attention mask.

        Args:
            attention_mask: ``(batch, seq)`` tensor where non-zero marks real
                tokens, or ``None`` for a purely causal mask.
            input_shape: ``(batch_size, seq_length)``.
            dtype: floating dtype whose most negative finite value marks
                blocked positions.
            device: device for the mask; defaults to ``attention_mask``'s
                device, or CPU when no mask is given.

        Returns:
            ``(batch, 1, seq, seq)`` tensor: ``0`` where attention is allowed,
            ``torch.finfo(dtype).min`` where it is blocked (future positions
            and padded keys).
        """
        batch_size, seq_length = input_shape
        if device is None:
            device = attention_mask.device if attention_mask is not None else torch.device("cpu")
        blocked = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=device),
            diagonal=1
        )
        blocked = blocked[None, None, :, :].expand(batch_size, 1, seq_length, seq_length)
        if attention_mask is not None:
            keep = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length).bool()
            blocked = blocked | ~keep
        return torch.where(blocked, torch.finfo(dtype).min, 0.0)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: optional ``(batch, seq)`` mask, non-zero for real
                tokens; defaults to all ones.
            labels: optional ``(batch, seq)`` targets aligned with
                ``input_ids``; positions set to ``-100`` are ignored.

        Returns:
            dict with ``"logits"`` ``(batch, seq, vocab)`` and ``"loss"``
            (``None`` without labels, else the mean next-token cross-entropy).
        """
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        hidden_states = self.embed_tokens(input_ids)
        attention_mask = self._prepare_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states.dtype
        )
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask)
        hidden_states = self.norm(hidden_states)
        if self.lm_head is not None:
            logits = self.lm_head(hidden_states)
        else:
            logits = F.linear(hidden_states, self.embed_tokens.weight)
        loss = None
        if labels is not None:
            loss = self._causal_lm_loss(logits, labels)
        return {"loss": loss, "logits": logits}

    def _causal_lm_loss(self, logits, labels):
        """Mean next-token cross-entropy over positions whose label is not -100.

        Computed as ``sum / count`` so a batch with no supervised position
        (e.g. the prompt filled the whole truncated sequence) yields a
        differentiable 0 instead of NaN, which would otherwise corrupt the
        weights when training without a GradScaler.
        """
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        summed = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
        n_valid = (shift_labels != -100).sum().clamp(min=1)
        return summed / n_valid.to(summed.dtype)

    def count_parameters(self):
        """Number of trainable parameters (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ============================================================================
# DATASET
# ============================================================================
class IndonesianCoTDataset(Dataset):
    """JSONL dataset of chain-of-thought examples for causal-LM training.

    Every line must be a JSON object with non-blank string fields ``input``,
    ``cot`` and ``output``. Blank lines are skipped silently. Every other
    rejected line is counted in ``skipped_count`` and reported with a warning.

    ``__getitem__`` builds a prompt/completion pair. The completion
    optionally includes the reasoning trace wrapped in
    ``cot_token`` ... ``end_cot_token``. The pair is tokenized and the prompt
    tokens are masked with ``-100`` so only the completion is learned.

    Args:
        file_path: path of the JSONL file.
        tokenizer: tokenizer with an HF-style ``encode`` method.
        max_length: truncation length of the full tokenized sequence.
        cot_token: marker placed at the end of the prompt when a trace is used.
        end_cot_token: marker closing the trace inside the completion.
        use_cot: whether reasoning traces may be included at all.
        cot_ratio: probability that a drawn sample includes its trace
            (only consulted when ``use_cot``).

    Raises:
        ValueError: if no valid sample could be loaded.
    """

    def __init__(
        self,
        file_path: str,
        tokenizer,
        max_length: int = 1024,
        cot_token: str = "<cot>",
        end_cot_token: str = "</cot>",
        use_cot: bool = True,
        cot_ratio: float = 0.7
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.cot_token = cot_token
        self.end_cot_token = end_cot_token
        self.use_cot = use_cot
        self.cot_ratio = cot_ratio
        self.samples = []
        self.skipped_count = 0
        self._load_data(file_path)

    def _load_data(self, file_path: str):
        """Append every well-formed line of ``file_path`` to ``self.samples``.

        Malformed rows are deliberately skipped (best effort) rather than
        aborting the whole load.
        """
        required = ('input', 'cot', 'output')
        print(f"Loading dataset from {file_path}...")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    if not all(key in data for key in required):
                        self.skipped_count += 1
                        print(f"Warning: Line {line_num} missing required fields, skipping...")
                        continue
                    if not all(isinstance(data[key], str) for key in required):
                        self.skipped_count += 1
                        print(f"Warning: Line {line_num} has invalid data types, skipping...")
                        continue
                    if not all(data[key].strip() for key in required):
                        self.skipped_count += 1
                        print(f"Warning: Line {line_num} has empty fields, skipping...")
                        continue
                    self.samples.append(data)
                except json.JSONDecodeError as e:
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} is not valid JSON ({e}), skipping...")
                    continue
                except Exception as e:
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} caused error ({e}), skipping...")
                    continue
        print(f"Loaded {len(self.samples)} valid samples")
        print(f"Skipped {self.skipped_count} malformed rows")
        if len(self.samples) == 0:
            raise ValueError("No valid samples loaded from dataset!")

    def __len__(self):
        return len(self.samples)

    def _prompt_token_count(self, prompt: str) -> int:
        """Number of leading tokens of ``encode(prompt + completion)`` that belong to the prompt.

        ``encode(prompt, add_special_tokens=True)`` can end with a suffix
        special token (e.g. ``[SEP]`` on BERT-style tokenizers). Once the
        completion is appended, that token is no longer at this position, so
        using that length would mask one real completion token. Instead, the
        bare prompt ids are located inside the special-token encoding and
        only the leading special tokens plus the prompt itself are counted.
        If the bare ids cannot be located, this falls back to the full
        special-token length.

        NOTE(review): this assumes the prompt tokenizes the same way alone as
        it does at the start of prompt + completion (true when the completion
        starts with whitespace for most word-piece/BPE tokenizers).
        """
        with_special = list(self.tokenizer.encode(prompt, add_special_tokens=True))
        bare = list(self.tokenizer.encode(prompt, add_special_tokens=False))
        width = len(bare)
        if width == 0:
            return len(with_special)
        for offset in range(len(with_special) - width + 1):
            if with_special[offset:offset + width] == bare:
                return offset + width
        return len(with_special)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Prompt = what the model is conditioned on; completion = what it must
        # learn to generate. With use_cot, a cot_ratio fraction of draws
        # include the reasoning trace.
        if self.use_cot and random.random() < self.cot_ratio:
            prompt = f"{sample['input']} {self.cot_token}"
            completion = f" {sample['cot']} {self.end_cot_token} {sample['output']}"
        else:
            prompt = f"{sample['input']}"
            completion = f" {sample['output']}"

        prompt_len = self._prompt_token_count(prompt)

        full_ids = self.tokenizer.encode(
            prompt + completion,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors=None,
            add_special_tokens=True,
        )
        # Mask prompt tokens with -100 so CrossEntropyLoss ignores them; only
        # the completion (CoT reasoning + answer) contributes to the gradient.
        # The final slice handles truncation that cut into the prompt.
        labels = ([-100] * prompt_len + list(full_ids[prompt_len:]))[:len(full_ids)]

        return {
            'input_ids': torch.tensor(full_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'length': len(full_ids),
            'cot_length': len(self.tokenizer.encode(
                sample['cot'], add_special_tokens=False
            )) if self.use_cot else 0,
        }
def collate_fn_with_packing(batch, pad_token_id=0):
    """Right-pad a list of dataset items into batch tensors.

    Items are ordered longest-first. ``input_ids`` are padded with
    ``pad_token_id`` and ``labels`` with ``-100``, so padding is ignored by the
    loss. The returned ``attention_mask`` holds 1 for real tokens and 0 for
    padding. Despite the name, no sequence packing is performed.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each
        ``(batch, max_length)``.
    """
    ordered = sorted(batch, key=lambda item: item['length'], reverse=True)
    longest = max(item['length'] for item in ordered)

    def _pad(item):
        real = item['length']
        extra = longest - real
        ids = F.pad(item['input_ids'], (0, extra), value=pad_token_id)
        targets = F.pad(item['labels'], (0, extra), value=-100)
        mask = torch.cat([
            torch.ones(real, dtype=torch.long),
            torch.zeros(extra, dtype=torch.long),
        ])
        return ids, mask, targets

    ids_rows, mask_rows, label_rows = zip(*(_pad(item) for item in ordered))
    return {
        'input_ids': torch.stack(ids_rows),
        'attention_mask': torch.stack(mask_rows),
        'labels': torch.stack(label_rows),
    }
# ============================================================================
# CURRICULUM LEARNING
# ============================================================================
def create_curriculum_datasets(dataset, stages=None, use_simple=False, skip_stages=0):
    """Build per-stage datasets for curriculum training.

    Args:
        dataset: source dataset. Its ``samples``, ``tokenizer``, ``cot_token``,
            ``end_cot_token`` and ``cot_ratio`` are read.
        stages: max token lengths for the simple curriculum (default
            ``[256, 512, 1024]``). Ignored by the 3-stage reasoning
            curriculum, which uses its own fixed lengths 384/512/1024.
        use_simple: use the length-bucket curriculum instead of the 3-stage
            reasoning curriculum.
        skip_stages: number of leading stages to omit (continue-train mode).
            Skipped stages are never built, so no tokenization is spent on them.

    Returns:
        list of stage datasets, excluding skipped stages.
    """
    if stages is None:
        stages = [256, 512, 1024]
    if use_simple:
        return _build_simple_curriculum(dataset, stages, skip_stages)
    return _build_reasoning_curriculum(dataset, skip_stages)


def _build_simple_curriculum(dataset, stages, skip_stages):
    """Length buckets: stage ``i`` keeps samples whose full CoT text encodes to <= ``stages[i]`` tokens."""
    datasets = []
    full_lengths = None  # tokenized once, lazily, and shared by every built stage
    for i, max_len in enumerate(stages):
        if i < skip_stages:
            print(f"[SKIP] Curriculum stage {max_len}: skipped (not built)")
            continue
        if full_lengths is None:
            full_lengths = [
                len(dataset.tokenizer.encode(
                    f"{s['input']} {dataset.cot_token} {s['cot']} {dataset.end_cot_token} {s['output']}"
                ))
                for s in dataset.samples
            ]
        filtered_samples = [
            s for s, n_tokens in zip(dataset.samples, full_lengths) if n_tokens <= max_len
        ]
        stage_dataset = _build_stage_dataset(dataset, filtered_samples, max_len, dataset.cot_ratio)
        datasets.append(stage_dataset)
        print(f"Curriculum stage {max_len}: {len(filtered_samples)} samples")
    return datasets


def _build_reasoning_curriculum(dataset, skip_stages):
    """Three stages: short Q&A without CoT, then 50% CoT, then 100% CoT.

    Stage 1 keeps samples whose ``input + output`` encodes to <= 384 tokens;
    stages 2 and 3 keep every sample. Stages with index < ``skip_stages`` are
    reported but neither filtered nor built.
    """
    print("\n" + "="*80)
    print("3-STAGE REASONING CURRICULUM")
    if skip_stages > 0:
        print(f" (Skipping first {skip_stages} stage(s) — continue-train mode)")
    print("="*80)
    stage_configs = [
        {
            'name': 'Stage 1: Basic Q&A (short, no reasoning)',
            'max_len': 384,
            'cot_ratio': 0.0,
            'filter': lambda s: len(dataset.tokenizer.encode(
                f"{s['input']} {s['output']}"
            )) <= 384
        },
        {
            'name': 'Stage 2: Learning Reasoning (medium, 50% CoT)',
            'max_len': 512,
            'cot_ratio': 0.5,
            'filter': lambda s: True
        },
        {
            'name': 'Stage 3: Full Reasoning (all, 100% CoT)',
            'max_len': 1024,
            'cot_ratio': 1.0,
            'filter': lambda s: True
        }
    ]
    datasets = []
    for idx, stage_config in enumerate(stage_configs):
        if idx < skip_stages:
            # Not filtered/built: stage 1's filter would otherwise tokenize
            # every sample only for the result to be discarded.
            print(f" [SKIP] {stage_config['name']}")
            print(f" (skipped, not built) Max length: {stage_config['max_len']}")
            continue
        filtered_samples = [s for s in dataset.samples if stage_config['filter'](s)]
        datasets.append(_build_stage_dataset(
            dataset, filtered_samples,
            stage_config['max_len'], stage_config['cot_ratio']
        ))
        print(f" {stage_config['name']}")
        print(f"  Samples: {len(filtered_samples)}")
        print(f"  Max length: {stage_config['max_len']}")
        print(f"  CoT ratio: {stage_config['cot_ratio']*100:.0f}%")
    print("="*80 + "\n")
    return datasets
def _build_stage_dataset(base_dataset, samples, max_len, cot_ratio):
    """Create a lightweight stage dataset that shares ``base_dataset``'s tokenizer.

    ``IndonesianCoTDataset.__init__`` is bypassed, so no file is read. The
    tokenizer, CoT markers and ``use_cot`` flag are copied from
    ``base_dataset``. The sample list, ``max_length`` and ``cot_ratio`` come
    from the arguments, and ``skipped_count`` starts at 0. The sample list is
    shared by reference, not copied.
    """
    stage = IndonesianCoTDataset.__new__(IndonesianCoTDataset)
    attributes = {
        'tokenizer': base_dataset.tokenizer,
        'max_length': max_len,
        'cot_token': base_dataset.cot_token,
        'end_cot_token': base_dataset.end_cot_token,
        'samples': samples,
        'skipped_count': 0,
        'use_cot': base_dataset.use_cot,
        'cot_ratio': cot_ratio,
    }
    for attr_name, attr_value in attributes.items():
        setattr(stage, attr_name, attr_value)
    return stage
# ============================================================================
# LEARNING-RATE SCHEDULERS
# ============================================================================
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Linear warmup from 0 to the base LR, then linear decay to 0.

    The multiplier is ``step / warmup`` during warmup. After warmup it is
    ``(total - step) / (total - warmup)``, floored at 0. Denominators are
    clamped to at least 1.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            remaining = float(num_training_steps - current_step)
            return max(0.0, remaining / decay_span)
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Linear warmup from 0, then a half-cosine decay of the LR multiplier to 0.

    The decay progress is clamped to 1.0, so steps beyond
    ``num_training_steps`` keep a multiplier of 0. Without the clamp the
    cosine would swing back up and silently raise the LR again. This matters
    when the step count is underestimated, and matches the clamping in
    ``get_continue_schedule``.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        progress = min(progress, 1.0)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def get_continue_schedule(optimizer, num_training_steps: int, min_fraction: float = 0.1):
    """Schedule for --continue-train without saved optimizer state.

    There is no warmup: the multiplier starts at 1.0 and follows a half
    cosine down to ``min_fraction`` at ``num_training_steps``. It then stays
    there, because progress is clamped to 1.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    span = float(max(1, num_training_steps))
    amplitude = 1.0 - min_fraction

    def lr_lambda(step):
        progress = min(float(step) / span, 1.0)
        cosine_val = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_fraction + amplitude * cosine_val

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
class PlateauLRGuard:
    """Reduce the learning rate when perplexity stops improving.

    Wraps a ``LambdaLR`` scheduler. Call :meth:`step` once per evaluation
    period with the latest perplexity. When ``patience`` consecutive checks
    show a relative improvement over the best value seen that is no greater
    than ``min_delta``, the LR is multiplied by ``factor``.

    The cumulative penalty is folded into the scheduler's ``lr_lambdas``.
    ``LambdaLR.step()`` recomputes ``base_lr * lambda(step)`` from scratch, so
    a cut written only to ``param_groups`` would be undone on the very next
    scheduler step.
    """

    def __init__(self, scheduler, patience=3, factor=0.5, min_delta=0.02):
        self.scheduler = scheduler
        self.patience = patience
        self.factor = factor
        self.min_delta = min_delta
        self._best = float('inf')
        self._no_improve = 0
        self._penalty = 1.0
        base_lambdas = getattr(scheduler, 'lr_lambdas', None)
        if base_lambdas is not None:
            scheduler.lr_lambdas = [self._with_penalty(fn) for fn in base_lambdas]

    def _with_penalty(self, fn):
        """Return ``fn`` scaled by this guard's current cumulative penalty."""
        def penalized(step):
            return fn(step) * self._penalty
        return penalized

    def step(self, perplexity: float) -> bool:
        """Record ``perplexity``; return True if the LR was reduced on this call."""
        if math.isinf(self._best):
            # First observation: (inf - p) / inf would be NaN and never count
            # as an improvement, making the guard fire unconditionally.
            improved = True
        else:
            relative_improvement = (self._best - perplexity) / max(self._best, 1e-8)
            improved = relative_improvement > self.min_delta
        if improved:
            self._best = perplexity
            self._no_improve = 0
            return False
        self._no_improve += 1
        if self._no_improve < self.patience:
            return False
        self._penalty *= self.factor
        self._no_improve = 0
        new_lr = self.scheduler.get_last_lr()[0] * self.factor
        print(f"\n[PlateauLRGuard] No improvement for {self.patience} checks. "
              f"Reducing LR by {self.factor:.2f}x -> {new_lr:.2e}")
        # Apply immediately; the wrapped lambdas keep it on later scheduler steps.
        for pg in self.scheduler.optimizer.param_groups:
            pg['lr'] *= self.factor
        return True

    def get_penalty(self):
        """Cumulative LR multiplier applied so far (1.0 = no reduction)."""
        return self._penalty
# ============================================================================
# TRAINING UTILITIES
# ============================================================================
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# ============================================================================
# ELASTIC WEIGHT CONSOLIDATION (EWC)
# ============================================================================
class EWC:
    """Elastic Weight Consolidation: discourages catastrophic forgetting during finetuning.

    On construction, snapshots the model's trainable parameters. It then
    estimates a diagonal importance weight per parameter from squared loss
    gradients over up to ``n_samples`` examples. :meth:`penalty` returns
    ``0.5 * sum(F * (theta - theta_snapshot) ** 2)``.
    """

    def __init__(self, model, dataloader, device, n_samples: int = 2000):
        self.device = device
        self.n_samples = n_samples
        # Parameter values the penalty anchors to.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(model, dataloader)

    def _compute_fisher(self, model, dataloader):
        """Accumulate squared gradients of each batch loss, divided by the number of samples seen.

        NOTE(review): gradients are taken of the batch-mean loss, so this is a
        per-batch approximation of the empirical Fisher. Its overall scale
        depends on batch size and is absorbed by ``ewc_lambda``.

        Batches whose loss is not finite are skipped, e.g. when every label
        is masked to -100, so a single NaN cannot poison the estimate.
        Afterwards, gradients are cleared and the model's previous
        train/eval mode is restored.
        """
        fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
        was_training = model.training
        model.eval()
        seen = 0
        try:
            for batch in dataloader:
                if seen >= self.n_samples:
                    break
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                model.zero_grad()
                loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)["loss"]
                if not torch.isfinite(loss):
                    continue
                loss.backward()
                for n, p in model.named_parameters():
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2)
                seen += input_ids.size(0)
        finally:
            # Don't leave Fisher-estimation gradients on the model.
            model.zero_grad(set_to_none=True)
            model.train(was_training)
        for n in fisher:
            fisher[n] /= max(1, seen)
        return fisher

    def penalty(self, model) -> torch.Tensor:
        """Return the EWC penalty ``0.5 * sum_i F_i * (theta_i - theta*_i) ** 2``."""
        loss = torch.tensor(0.0, device=self.device)
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss * 0.5
# ============================================================================
# TRAINING LOOP
# ============================================================================
def _micro_batch_loss(model, input_ids, attention_mask, labels, ewc, ewc_lambda, accumulation_steps):
    """Forward one micro-batch and build its training objective.

    Returns:
        tuple ``(lm_loss, scaled_objective)``. ``lm_loss`` is the model's
        language-modelling loss, used for logging and perplexity.
        ``scaled_objective`` is ``(lm_loss + ewc_lambda * EWC penalty) /
        accumulation_steps`` and is used for backward.
    """
    lm_loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)['loss']
    objective = lm_loss
    if ewc is not None:
        objective = objective + ewc_lambda * ewc.penalty(model)
    return lm_loss, objective / accumulation_steps


def _optimizer_step(model, optimizer, scheduler, scaler, max_grad_norm):
    """Clip gradients, apply one optimizer + scheduler step, then clear gradients.

    With a GradScaler the gradients are unscaled first, so clipping sees true
    gradient magnitudes.
    """
    if scaler is not None:
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def _optimizer_steps_per_epoch(num_samples, batch_size, accumulation_steps):
    """Optimizer updates per epoch as performed by ``train_model`` (at least 1).

    One update happens per ``accumulation_steps`` micro-batches, plus one for
    a trailing partial window. The DataLoader keeps the last incomplete
    batch, so there are ``ceil(num_samples / batch_size)`` micro-batches.
    """
    micro_batches = math.ceil(num_samples / batch_size)
    return max(1, math.ceil(micro_batches / accumulation_steps))


def train_model(
    model: IndonesianLLM,
    train_dataset: IndonesianCoTDataset,
    config: TrainingConfig,
    device: torch.device,
    use_simple_curriculum: bool = False,
    is_continue: bool = False,
    skip_curriculum_stages: int = 0,
    ewc: "EWC | None" = None,
):
    """Main training loop.

    Trains ``model`` in place through every curriculum stage returned by
    ``create_curriculum_datasets``, for ``config.num_epochs`` epochs each.
    Features:

    * gradient accumulation over ``config.gradient_accumulation_steps``
      micro-batches, with a final partial window flushed at epoch end;
    * gradient-norm clipping;
    * fp16 autocast + GradScaler when requested and CUDA is available;
    * an optional EWC penalty.

    In continue mode the LR follows ``get_continue_schedule`` plus a
    ``PlateauLRGuard`` driven by per-epoch perplexity. Otherwise a cosine or
    linear warmup schedule is used.

    Logged loss and perplexity are computed from the language-modelling loss
    only; the EWC penalty is excluded.

    Returns:
        The trained ``model``. It is also returned unchanged when no
        curriculum stage is left to train.
    """
    print("\n" + "="*80)
    print("TRAINING CONFIGURATION" + (" [CONTINUE MODE]" if is_continue else ""))
    print("="*80)
    print(f"Model parameters: {model.count_parameters():,}")
    print(f"Dataset size: {len(train_dataset)}")
    print(f"Batch size: {config.batch_size}")
    print(f"Gradient accumulation steps: {config.gradient_accumulation_steps}")
    print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Max sequence length: {config.max_seq_length}")
    print(f"Number of epochs: {config.num_epochs}")
    print(f"Mixed precision: {config.use_fp16}")
    if is_continue:
        print(f"Skipping curriculum stages: {skip_curriculum_stages}")
        print(f"Plateau patience: {config.plateau_patience}")
        print(f"Plateau LR factor: {config.plateau_factor}")
    if ewc is not None:
        print(f"EWC lambda: {config.ewc_lambda} (anti-forgetting active)")
    print("="*80 + "\n")

    model.to(device)
    model.train()

    curriculum_datasets = create_curriculum_datasets(
        train_dataset,
        config.curriculum_stages,
        use_simple=use_simple_curriculum,
        skip_stages=skip_curriculum_stages,
    )
    if len(curriculum_datasets) == 0:
        print("ERROR: No curriculum stages to train on. Check --skip-stages value.")
        return model

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon,
        weight_decay=config.weight_decay
    )

    accum = config.gradient_accumulation_steps
    # Must equal the number of _optimizer_step calls made below; otherwise the
    # LR schedule ends early (or late) relative to training.
    total_steps = sum(
        _optimizer_steps_per_epoch(len(ds), config.batch_size, accum) * config.num_epochs
        for ds in curriculum_datasets
        if len(ds) > 0
    )
    total_steps = max(1, total_steps)

    # LR Scheduler
    if is_continue:
        scheduler = get_continue_schedule(optimizer, num_training_steps=total_steps)
        plateau_guard = PlateauLRGuard(
            scheduler,
            patience=config.plateau_patience,
            factor=config.plateau_factor,
            min_delta=config.plateau_min_delta,
        )
        print(f"[Scheduler] Continue-train: flat cosine decay from {config.learning_rate:.2e}")
    else:
        if config.lr_scheduler_type == "cosine":
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=config.warmup_steps,
                num_training_steps=total_steps
            )
        else:
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=config.warmup_steps,
                num_training_steps=total_steps
            )
        plateau_guard = None

    scaler = torch.cuda.amp.GradScaler() if config.use_fp16 and torch.cuda.is_available() else None

    global_step = 0
    perplexity_history = []
    for stage_idx, stage_dataset in enumerate(curriculum_datasets):
        actual_stage = stage_idx + skip_curriculum_stages
        print(f"\n{'='*80}")
        print(f"CURRICULUM STAGE {actual_stage + 1}/{len(curriculum_datasets) + skip_curriculum_stages} "
              f"(running {stage_idx + 1} of {len(curriculum_datasets)})")
        print(f"Max sequence length: {stage_dataset.max_length}")
        print(f"Samples: {len(stage_dataset)}")
        if hasattr(stage_dataset, 'cot_ratio'):
            print(f"CoT ratio: {stage_dataset.cot_ratio:.0%}")
        print(f"{'='*80}\n")
        if len(stage_dataset) == 0:
            # An empty stage would report a fake perplexity of 1.0 and mislead
            # the plateau guard.
            print("Stage has no samples, skipping.\n")
            continue

        dataloader = DataLoader(
            stage_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
            num_workers=0,
            pin_memory=True if torch.cuda.is_available() else False
        )
        num_batches = len(dataloader)

        for epoch in range(config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
            epoch_loss = 0.0  # sum of per-micro-batch LM losses (EWC penalty excluded)
            optimizer.zero_grad()
            for step, batch in enumerate(dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        lm_loss, loss = _micro_batch_loss(
                            model, input_ids, attention_mask, labels,
                            ewc, config.ewc_lambda, accum,
                        )
                    scaler.scale(loss).backward()
                else:
                    lm_loss, loss = _micro_batch_loss(
                        model, input_ids, attention_mask, labels,
                        ewc, config.ewc_lambda, accum,
                    )
                    loss.backward()
                epoch_loss += lm_loss.item()
                # Step at every accumulation boundary and on the epoch's final
                # micro-batch, so a trailing partial window is applied rather
                # than discarded by the next epoch's zero_grad().
                if (step + 1) % accum == 0 or step + 1 == num_batches:
                    _optimizer_step(model, optimizer, scheduler, scaler, config.max_grad_norm)
                    global_step += 1
                    if global_step % config.logging_steps == 0:
                        avg_loss = epoch_loss / (step + 1)
                        current_lr = scheduler.get_last_lr()[0]
                        print(f"Step {global_step:>6} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
            avg_epoch_loss = epoch_loss / max(1, num_batches)
            perplexity = math.exp(min(avg_epoch_loss, 20))  # cap to avoid overflow
            perplexity_history.append(perplexity)
            print(f"Epoch {epoch + 1} completed | Avg Loss: {avg_epoch_loss:.4f} "
                  f"| Perplexity: {perplexity:.2f}")
            if plateau_guard is not None:
                plateau_guard.step(perplexity)

    print("\n" + "="*80)
    print("TRAINING COMPLETED")
    if perplexity_history:
        print(f"Final perplexity: {perplexity_history[-1]:.2f}")
        print(f"Best perplexity: {min(perplexity_history):.2f}")
    print("="*80 + "\n")
    return model
# ============================================================================
# EVALUATION
# ============================================================================
def evaluate_model(model, dataset, device, batch_size=4):
    """Compute token-level average loss and perplexity of ``model`` on ``dataset``.

    The model's loss is a mean over next-token targets, i.e. labels shifted
    left by one with ``-100`` ignored. Each batch's loss is therefore weighted
    by its number of such targets, so the result is the mean over all target
    tokens. Averaging per-batch means by sample count would bias it. Batches
    without any supervised target are skipped, since their mean loss is
    undefined. The model is left in eval mode.

    Returns:
        dict with ``"loss"`` and ``"perplexity"``. The loss is capped at 20
        before ``exp`` to avoid overflow.
    """
    model.eval()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
        num_workers=0
    )
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            n_targets = int((labels[..., 1:] != -100).sum())
            if n_targets == 0:
                continue
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs['loss'].item() * n_targets
            total_tokens += n_targets
    avg_loss = total_loss / max(1, total_tokens)
    perplexity = math.exp(min(avg_loss, 20))
    print(f"\nEvaluation Results:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    if perplexity < 5.0:
        print("Status: Excellent")
    elif perplexity < 10.0:
        print("Status: Good")
    elif perplexity < 20.0:
        print("Status: Fair — try more epochs or lower LR")
    else:
        print("Status: Poor — check data quality or model config")
    if avg_loss < 0.5:
        print("Warning: Very low loss might indicate overfitting")
    return {"loss": avg_loss, "perplexity": perplexity}
# ============================================================================
# GENERATION
# ============================================================================
def generate_text(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9,
    device: torch.device = torch.device('cpu')
):
    """Sample a continuation of ``prompt`` one token at a time.

    Each step runs a full forward pass (no KV cache) over at most
    ``model.config.max_position_embeddings`` of the most recent tokens. The
    logits are then transformed in this order before multinomial sampling:

    1. temperature scaling (temperature floored at 0.1);
    2. a fixed -2.0 penalty on tokens sampled in the last 10 steps;
    3. top-k filtering;
    4. nucleus (top-p) filtering.

    Generation stops in any of these cases:

    - a stop token (EOS, PAD or SEP) is sampled;
    - ``max_new_tokens`` steps have run;
    - after more than 10 steps, the decoded continuation contains a blank
      line or ``User:``, or its last 20 characters contain ``Assistant:``.

    Args:
        model: Called as ``model(input_ids=...)``. It must return a dict whose
            ``'logits'`` entry is indexed ``[batch, position, vocab]`` and
            expose ``config.max_position_embeddings``. The penalty and
            ``.item()`` calls assume batch size 1.
        tokenizer: Hugging Face-style tokenizer with ``encode``/``decode``
            and ``eos_token_id``/``sep_token_id``/``pad_token_id``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on sampled tokens.
        temperature: Softmax temperature; values below 0.1 are clamped.
        top_k: Keep only the k highest logits (0 disables).
        top_p: Nucleus threshold (1.0 disables).
        device: Device the prompt ids are moved to.

    Returns:
        ``tokenizer.decode`` of the prompt ids followed by every sampled id,
        with special tokens kept. The prompt is retained even when the
        sequence grows past the model's context window.
    """
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    prompt_len = input_ids.size(1)
    generated_ids = input_ids.clone()
    max_ctx = model.config.max_position_embeddings

    # Prefer EOS; BERT-style tokenizers have none, so fall back to SEP, then id 2.
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = tokenizer.sep_token_id
    if eos_token_id is None:
        eos_token_id = 2
    stop_tokens = {eos_token_id, tokenizer.pad_token_id, tokenizer.sep_token_id}
    stop_tokens.discard(None)

    repetition_buffer = []  # most recent sampled ids, capped at 20
    with torch.no_grad():
        for step in range(max_new_tokens):
            # Window only the model input. Truncating generated_ids itself would
            # drop the prompt from the returned text and break prompt_len below.
            context = generated_ids[:, -max_ctx:]
            logits = model(input_ids=context)['logits']
            next_token_logits = logits[:, -1, :] / max(temperature, 0.1)

            # Discourage tokens sampled in the last 10 steps (stop tokens exempt).
            if len(repetition_buffer) > 10:
                for token in set(repetition_buffer[-10:]):
                    if token in stop_tokens:
                        continue
                    next_token_logits[0, token] -= 2.0

            if top_k > 0:
                kth_best = torch.topk(
                    next_token_logits, min(top_k, next_token_logits.size(-1))
                )[0][..., -1, None]
                next_token_logits[next_token_logits < kth_best] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses top_p is still kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_id = next_token.item()
            if token_id in stop_tokens:
                break

            repetition_buffer.append(token_id)
            if len(repetition_buffer) > 20:
                repetition_buffer.pop(0)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            if step > 10:
                decoded = tokenizer.decode(
                    generated_ids[0][prompt_len:], skip_special_tokens=False)
                if '\n\n' in decoded or 'User:' in decoded or 'Assistant:' in decoded[-20:]:
                    break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=False)
# ============================================================================
# INTERACTIVE CHAT
# ============================================================================
def _clean_response(response: str) -> str:
import re
if "<cot>" in response and "</cot>" in response:
response = response.split("</cot>", 1)[-1]
elif "<cot>" in response:
response = response.split("<cot>", 1)[0]
response = re.sub(r'\[\w+\]', '', response)
response = re.sub(r'<[^>]+>', '', response)
for marker in [
"user :", "user:", "User :", "User:",
"assistant :", "assistant:", "Assistant :", "Assistant:",
"memahami permintaan", "jawaban singkat", "penjelasan harus",
"\n\n",
]:
if marker in response:
response = response.split(marker)[0]
response = re.sub(r'^[\s:!,\.\-|\[\]]+', '', response)
response = re.sub(r' {2,}', ' ', response).strip()
return response
def _extract_thinking(raw: str) -> tuple:
    """Split raw generated text into ``(thinking, answer)``.

    Bracketed word tokens such as ``[SEP]`` are removed first. Everything
    before the first ``</cot>`` is treated as reasoning. That reasoning is
    trimmed at the first role marker, template phrase, or blank line, and its
    ``<...>`` tags are removed. Everything after ``</cot>`` is passed through
    ``_clean_response``.

    When no ``</cot>`` is present, the whole text counts as thinking and the
    answer is ``_clean_response("")``.
    """
    import re

    without_brackets = re.sub(r'\[\w+\]', '', raw)
    # partition() yields (text, "", "") when "</cot>" is absent.
    reasoning, _, remainder = without_brackets.partition("</cot>")

    thinking = reasoning.strip()
    for marker in ("user :", "user:", "memahami permintaan", "\n\n"):
        thinking = thinking.partition(marker)[0]
    thinking = re.sub(r'<[^>]+>', '', thinking).strip()

    return thinking, _clean_response(remainder)
def interactive_chat(model, tokenizer, device,
                     system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia."):
    """Run a stdin/stdout chat loop against ``model``.

    Each turn is independent. The model receives ``"<user input> <cot>"``
    with no conversation history, and sampling uses temperature 0.9,
    top-k 50, top-p 0.92 and at most 250 new tokens.

    Commands:
        ``exit``/``quit``/``keluar`` leave the loop.
        ``clear``/``bersihkan`` print a confirmation only; there is no
        history to clear.
        ``think`` toggles display of the extracted reasoning.

    Ctrl-C or end-of-input (Ctrl-D, closed or piped stdin) ends the chat.
    Other exceptions are printed and the loop continues.

    Args:
        model: Model accepted by ``generate_text``.
        tokenizer: Tokenizer accepted by ``generate_text``.
        device: Device used for generation.
        system_prompt: Shown in the banner only; it is NOT fed to the model.
    """
    import time

    print("\n" + "="*80)
    print("INDONESIAN LLM — INTERACTIVE CHAT")
    print("="*80)
    print("Commands: 'exit'/'quit' | 'clear' | 'think' (toggle reasoning display)")
    print(f"Persona : {system_prompt}")
    print("="*80 + "\n")
    model.eval()
    show_thinking = False

    # Time-based seed so repeated sessions sample different replies.
    seed = int(time.time()) % 100000
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    while True:
        try:
            user_input = input("\nYou: ").strip()
            if not user_input:
                continue
            command = user_input.lower()
            if command in ['exit', 'quit', 'keluar']:
                print("\nGoodbye!")
                break
            if command in ['clear', 'bersihkan']:
                print("\nConversation cleared")
                continue
            if command == 'think':
                show_thinking = not show_thinking
                print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}")
                continue

            prompt = f"{user_input} <cot>"
            max_tokens = 250
            print("\nA:", end=" ", flush=True)
            full_response = generate_text(
                model=model, tokenizer=tokenizer, prompt=prompt,
                max_new_tokens=max_tokens, temperature=0.9,
                top_k=50, top_p=0.92, device=device
            )

            # decode() re-renders the prompt: it adds [CLS]/[SEP], re-spaces
            # punctuation and lower-cases for uncased vocabularies. So
            # len(prompt) is not a valid offset. Cut right after the prompt's
            # own "<cot>" instead, and fall back to the old offset if that
            # marker cannot be found.
            cot_count = prompt.count("<cot>")
            pieces = full_response.split("<cot>", cot_count)
            if len(pieces) > cot_count:
                response = pieces[cot_count].strip()
            else:
                response = full_response[len(prompt):].strip()

            thinking, answer = _extract_thinking(response)
            if show_thinking and thinking:
                print(f"[Thinking: {thinking}]")
            final = answer if answer else _clean_response(response)
            if not final or len(final) < 3:
                final = "Maaf, saya tidak mengerti. Bisa diulang?"
            print(final)
        except (KeyboardInterrupt, EOFError):
            # EOFError is an Exception subclass. Without this branch, closed
            # stdin would be swallowed below and the loop would never end.
            print("\n\nChat interrupted")
            break
        except Exception as e:
            print(f"\nError: {e}")
# ============================================================================
# BENCHMARK
# ============================================================================
def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 20, verbose: bool = True):
    """Score generated answers on a random subset of a JSONL dataset.

    Rows need non-blank string ``input`` and ``output`` fields; other rows
    are skipped.

    A sample passes when either:

    - the lower-cased expected output occurs verbatim in the generated answer;
    - at least half of its whitespace-separated words appear in the answer.

    Generation uses temperature 0.3, top-k 20, top-p 0.9 and at most 150 new
    tokens, with torch seeded to 42. The subset itself is drawn after
    reseeding Python's global ``random`` from the clock.

    A teacher-forced perplexity over ``"<input> <cot> <output>"`` is also
    reported. It assumes the model shifts ``labels`` internally, so a
    sequence of L tokens is counted as L-1 predictions (TODO confirm
    against the model's loss).

    Args:
        model, tokenizer, device: As for ``generate_text``.
        dataset_path: JSONL file to sample from.
        n: Maximum number of samples to score; must be positive.
        verbose: Print a PASS/FAIL line per sample.

    Returns:
        ``{"score", "pass", "total", "perplexity"}``, or ``None`` when the
        dataset is missing, contains no valid rows, or ``n`` is not positive.
    """
    import time

    if dataset_path is None or not os.path.exists(dataset_path):
        print(f"Dataset not found at: {dataset_path}")
        print("Pass --dataset path/to/your.jsonl to benchmark against your data.")
        return
    if n <= 0:
        # An empty sample would otherwise divide by zero when scoring.
        print(f"Nothing to benchmark: n must be positive (got {n}).")
        return

    all_samples = []
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                d = json.loads(line)
                if all(k in d for k in ['input', 'output']) and d['input'].strip() and d['output'].strip():
                    all_samples.append(d)
            except (ValueError, TypeError, AttributeError):
                # Malformed JSON, a non-object row, or non-string fields: skip.
                continue
    if not all_samples:
        print("No valid samples found in dataset.")
        return

    random.seed(int(time.time()))
    samples = random.sample(all_samples, min(n, len(all_samples)))
    model.eval()
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    print("\n" + "="*80)
    print(f"BENCHMARK ({len(samples)} random samples from dataset)")
    print("="*80)

    results = []
    ppl_loss = 0.0
    ppl_toks = 0
    for i, sample in enumerate(samples):
        inp = sample['input'].strip()
        expected = sample['output'].strip().lower()
        prompt = f"{inp} <cot>"
        full = generate_text(
            model=model, tokenizer=tokenizer, prompt=prompt,
            max_new_tokens=150, temperature=0.3,
            top_k=20, top_p=0.9, device=device
        )

        # The decoded text does not start with `prompt` verbatim: the
        # tokenizer adds special tokens and re-spaces the text. Cut after
        # the prompt's own "<cot>" so no prompt text leaks into the answer.
        cot_count = prompt.count("<cot>")
        pieces = full.split("<cot>", cot_count)
        if len(pieces) > cot_count:
            raw = pieces[cot_count].strip()
        else:
            raw = full[len(prompt):].strip()

        _, answer = _extract_thinking(raw)
        answer_lower = answer.lower()
        passed = expected in answer_lower
        if not passed:
            exp_tokens = set(expected.split())
            ans_tokens = set(answer_lower.split())
            if exp_tokens:
                overlap = len(exp_tokens & ans_tokens) / len(exp_tokens)
                passed = overlap >= 0.5
        results.append(passed)

        with torch.no_grad():
            ids = tokenizer.encode(
                f"{inp} <cot> {sample['output']}",
                return_tensors="pt"
            ).to(device)
            if ids.size(1) >= 2:
                out = model(input_ids=ids, labels=ids)
                toks = ids.size(1) - 1
                ppl_loss += out["loss"].item() * toks
                ppl_toks += toks

        if verbose:
            status = "PASS" if passed else "FAIL"
            print(f" [{status}] {inp[:60]}")
            print(f" Expected : {sample['output'][:80]}")
            print(f" Got : {answer[:80] if answer else '(no answer)'}")

    total_pass = sum(results)
    total = len(results)
    overall = total_pass / total * 100
    bar = "█" * int(overall / 10) + "░" * (10 - int(overall / 10))
    ppl = math.exp(min(ppl_loss / max(1, ppl_toks), 20))
    print("\n" + "-"*80)
    print(f" SCORE {total_pass}/{total} ({overall:.1f}%) {bar}")
    print(f" PERPLEXITY {ppl:.2f} (lower = better)")
    print("="*80 + "\n")
    return {"score": overall, "pass": total_pass, "total": total, "perplexity": ppl}
# ============================================================================
# MODEL SAVING AND LOADING
# ============================================================================
def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, save_path: str, use_fp16: bool = True):
    """Persist weights plus the metadata ``load_model`` needs to rebuild them.

    The parent directory of ``save_path`` is created if needed; a bare
    filename means the current directory. With ``use_fp16``, float32 tensors
    are halved and tensors of any other dtype are stored unchanged.

    The checkpoint dict has these keys:

    - ``model_state_dict``
    - ``config``, stored as the object itself and therefore pickled
    - ``tokenizer_name``
    - ``model_params``
    - ``dtype``, either ``'fp16'`` or ``'fp32'``

    The file size and parameter count are printed afterwards.
    """
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    precision = 'fp16' if use_fp16 else 'fp32'

    weights = model.state_dict()
    if use_fp16:
        weights = {name: (tensor.half() if tensor.dtype == torch.float32 else tensor)
                   for name, tensor in weights.items()}

    payload = {
        'model_state_dict': weights,
        'config': config,
        'tokenizer_name': tokenizer_name,
        'model_params': model.count_parameters(),
        'dtype': precision,
    }
    torch.save(payload, save_path)

    size_mb = os.path.getsize(save_path) / 1e6
    print(f"\nModel saved to: {save_path} ({precision}, {size_mb:.1f} MB)")
    print(f"Parameters: {model.count_parameters():,}")
def load_model(load_path: str, device: torch.device):
    """Rebuild model and tokenizer from a checkpoint written by ``save_model``.

    The tokenizer is loaded by the checkpoint's ``tokenizer_name``, and the
    ``<cot>``/``</cot>`` special tokens are registered exactly as ``main()``
    does before training. Weights stored as fp16 are upcast to float32 before
    ``load_state_dict``.

    Args:
        load_path: Checkpoint file path.
        device: Target device; tensors are mapped there by ``torch.load``.

    Returns:
        ``(model, tokenizer, config, extra)``. ``extra`` is
        ``{'training_metadata': ...}``, which is ``None`` when the checkpoint
        lacks that key.

    Raises:
        FileNotFoundError: If ``load_path`` does not exist.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Model checkpoint not found: {load_path}")
    print(f"Loading model from: {load_path}")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects (needed for the stored config). Only load trusted checkpoints.
    ckpt = torch.load(load_path, map_location=device, weights_only=False)
    config = ckpt['config']
    tokenizer_name = ckpt['tokenizer_name']
    precision = ckpt.get('dtype', 'fp32')

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})

    model = IndonesianLLM(config)
    weights = ckpt['model_state_dict']
    if precision == 'fp16':
        weights = {name: (tensor.float() if tensor.dtype == torch.float16 else tensor)
                   for name, tensor in weights.items()}
    model.load_state_dict(weights)
    model.to(device)

    extra = {'training_metadata': ckpt.get('training_metadata', None)}
    size_mb = os.path.getsize(load_path) / 1e6
    param_count = ckpt.get('model_params', model.count_parameters())
    print(f"Model loaded ({precision}, {size_mb:.1f} MB) | "
          f"Parameters: {param_count:,}")
    return model, tokenizer, config, extra
# ============================================================================
# MAIN
# ============================================================================
# Base tokenizer used when building a model from scratch (and for --inspect-data).
_CLI_TOKENIZER_NAME = "indolem/indobert-base-uncased"


def _cli_build_parser() -> argparse.ArgumentParser:
    """Return the argument parser shared by every CLI mode."""
    parser = argparse.ArgumentParser(
        description="Indonesian Conversational LLM — Train, Chat, Finetune, or Continue")
    parser.add_argument('--train', action='store_true', help='Train model from scratch')
    parser.add_argument('--chat', action='store_true', help='Interactive chat mode')
    parser.add_argument('--finetune', action='store_true', help='Fine-tune on NEW data (lr/10)')
    parser.add_argument('--continue-train', action='store_true',
                        help='Continue training on SAME data with proper LR re-warmup')
    parser.add_argument('--inspect-data', action='store_true', help='Inspect dataset quality')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark suite on a saved model')
    parser.add_argument('--no-eval', action='store_true', help='Skip evaluation after training')
    parser.add_argument('--grad-accum', type=int, default=None,
                        help='Override gradient accumulation steps')
    parser.add_argument('--ewc-lambda', type=float, default=5000.0,
                        help='EWC penalty strength for --finetune (default 5000). 0 = disabled.')
    parser.add_argument('--ewc-samples', type=int, default=2000,
                        help='Samples used to estimate Fisher Information for EWC (default 2000).')
    parser.add_argument('--ewc-dataset', type=str, default=None,
                        help='Data the checkpoint was ORIGINALLY trained on, used to estimate '
                             'Fisher Information for EWC (default: --dataset).')
    parser.add_argument('--no-ewc', action='store_true',
                        help='Disable EWC during finetuning.')
    parser.add_argument('--dataset', type=str, default='indonesian_cot_dataset.jsonl')
    parser.add_argument('--model', type=str, default='indonesian_llm_model.pt')
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--max-length', type=int, default=512,
                        help='Max sequence length. Dataset max is 448 so 512 is optimal.')
    parser.add_argument('--hidden-size', type=int, default=320)
    parser.add_argument('--num-layers', type=int, default=16)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--num-kv-heads', type=int, default=2)
    parser.add_argument('--seed', type=int, default=42)
    # NOTE: store_true with default=True makes this flag a no-op; --save-fp32 is
    # the switch that actually matters. Kept for CLI compatibility.
    parser.add_argument('--save-fp16', action='store_true', default=True)
    parser.add_argument('--save-fp32', action='store_true')
    parser.add_argument('--no-cot', action='store_true')
    # NOTE: also a no-op (default=True); --no-cot is the effective switch.
    parser.add_argument('--use-cot', action='store_true', default=True)
    parser.add_argument('--simple-curriculum', action='store_true')
    parser.add_argument('--cot-ratio', type=float, default=1.0)
    parser.add_argument('--system-prompt', type=str,
                        default='Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.')
    # NOTE(review): parsed but never read; continue-train uses warmup_steps=0.
    parser.add_argument('--rewarm-steps', type=int, default=150)
    parser.add_argument('--skip-stages', type=int, default=2,
                        help='Curriculum stages to skip in continue-train (default 2).')
    parser.add_argument('--plateau-patience', type=int, default=None,
                        help='Epochs without perplexity improvement before the LR is reduced '
                             'in --continue-train (default 2).')
    # NOTE(review): parsed but never read.
    parser.add_argument('--no-restore-optimizer', action='store_true')
    return parser


def _cli_base_tokenizer():
    """Load the base tokenizer and register the <cot>/</cot> delimiters."""
    tokenizer = AutoTokenizer.from_pretrained(_CLI_TOKENIZER_NAME)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
    return tokenizer


def _cli_tokenizer_name(tokenizer) -> str:
    """Name to record in a checkpoint for ``tokenizer`` (base name as fallback)."""
    return getattr(tokenizer, "name_or_path", None) or _CLI_TOKENIZER_NAME


def _cli_dataset(path, tokenizer, max_length, use_cot, cot_ratio):
    """Build an IndonesianCoTDataset with the CLI's common options."""
    return IndonesianCoTDataset(
        file_path=path, tokenizer=tokenizer,
        max_length=max_length, use_cot=use_cot, cot_ratio=cot_ratio
    )


def _cli_grad_accum(args) -> int:
    """Gradient accumulation steps: --grad-accum if given (non-zero), else 32."""
    return args.grad_accum if args.grad_accum else 32


def _cli_finetuned_path(model_path: str) -> str:
    """Derive the fine-tuned checkpoint path without ever equalling ``model_path``."""
    # str.replace('.pt', ...) would rewrite every '.pt' in the path and, for
    # paths without '.pt', return the input unchanged (overwriting the source).
    root, ext = os.path.splitext(model_path)
    return f"{root}_finetuned{ext or '.pt'}"


def _cli_inspect_data(args, use_cot_training):
    """Print length stats, sample rows and simple quality checks for --dataset."""
    print("\nInspecting dataset...")
    print("="*80)
    tokenizer = _cli_base_tokenizer()
    dataset = _cli_dataset(args.dataset, tokenizer, args.max_length,
                           use_cot_training, args.cot_ratio)
    print(f"\nDataset Statistics:")
    print(f"Total samples: {len(dataset)}")
    print(f"Skipped samples: {dataset.skipped_count}")
    lengths = []
    cot_lengths = []
    for i in range(min(len(dataset), 1000)):
        sample = dataset[i]
        lengths.append(sample['length'])
        cot_lengths.append(sample['cot_length'])
    if not lengths:
        # min()/max() below would raise ValueError on an empty dataset.
        print("\nNo samples to inspect.")
        print("\n" + "="*80)
        return
    print(f"\nSequence Length Stats:")
    print(f" Min: {min(lengths)}")
    print(f" Max: {max(lengths)}")
    print(f" Avg: {sum(lengths)/len(lengths):.1f}")
    print(f" Median: {sorted(lengths)[len(lengths)//2]}")
    print(f"\nCoT Length Stats:")
    print(f" Min: {min(cot_lengths)}")
    print(f" Max: {max(cot_lengths)}")
    print(f" Avg: {sum(cot_lengths)/len(cot_lengths):.1f}")
    long_cot = sum(1 for x in cot_lengths if x > 50)
    print(f" Samples with long CoT (>50 tokens): {long_cot} ({long_cot/len(cot_lengths)*100:.1f}%)")
    print(f"\n{'='*80}")
    print("Sample Examples (first 5):")
    print("="*80)
    for i in range(min(5, len(dataset.samples))):
        s = dataset.samples[i]
        print(f"\n--- Sample {i+1} ---")
        print(f"Input: {s['input'][:100]}...")
        print(f"CoT: {s['cot'][:150]}...")
        print(f"Output: {s['output'][:100]}...")
    print("\n" + "="*80)
    print("Dataset Quality Checks:")
    issues = []
    for i, sample in enumerate(dataset.samples[:100]):
        if len(sample['input']) < 10:
            issues.append(f"Sample {i}: Input too short")
        if len(sample['output']) < 10:
            issues.append(f"Sample {i}: Output too short")
        if len(sample['cot']) < 20:
            issues.append(f"Sample {i}: CoT too short")
        if sample['input'].lower() == sample['output'].lower():
            issues.append(f"Sample {i}: Input == Output (copy)")
    if issues:
        print(f"\nFound {len(issues)} potential issues in first 100 samples:")
        for issue in issues[:10]:
            print(f" - {issue}")
        if len(issues) > 10:
            print(f" ... and {len(issues)-10} more")
    else:
        print("\nNo obvious issues detected in first 100 samples")
    print("\n" + "="*80)


def _cli_train_from_scratch(args, device, use_cot_training, save_fp16):
    """Build a fresh model, train it, optionally evaluate, save, and sample from it."""
    print("\nStarting TRAINING mode (from scratch)...")
    tokenizer = _cli_base_tokenizer()
    model_config = ModelConfig(
        vocab_size=len(tokenizer),
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        num_key_value_heads=args.num_kv_heads,
        intermediate_size=args.hidden_size * 3,
        max_position_embeddings=2048,
        attention_dropout=0.1,
        residual_dropout=0.1,
        tie_word_embeddings=True
    )
    model = IndonesianLLM(model_config)
    print(f"Model parameters: {model.count_parameters():,}")
    train_config = TrainingConfig(
        dataset_path=args.dataset,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=_cli_grad_accum(args),
        max_seq_length=args.max_length,
        learning_rate=args.lr,
        warmup_steps=500,
        use_fp16=torch.cuda.is_available(),
        curriculum_stages=[128, 256, args.max_length]
    )
    dataset = _cli_dataset(train_config.dataset_path, tokenizer, train_config.max_seq_length,
                           use_cot_training, args.cot_ratio)
    model = train_model(
        model, dataset, train_config, device,
        use_simple_curriculum=args.simple_curriculum
    )
    if not args.no_eval:
        # NOTE: this evaluates on the training data itself, not a held-out split.
        evaluate_model(model, dataset, device)
    save_model(model, model_config, _CLI_TOKENIZER_NAME, args.model, use_fp16=save_fp16)
    test_prompts = [
        "Berapa hasil dari 1+1?",
        "Jelaskan cara kerja komputer",
        "Bagaimana cara membuat kopi yang enak?"
    ]
    print("\n" + "="*80)
    print("GENERATION TEST")
    print("="*80 + "\n")
    for prompt in test_prompts:
        print(f"Prompt: {prompt}")
        generated = generate_text(model, tokenizer, prompt, max_new_tokens=150, device=device)
        print(f"Generated: {generated}\n")
        print("-" * 80 + "\n")
    print(f"\nTo chat: python {__file__} --chat --model {args.model}")


def _cli_finetune(args, device, use_cot_training, save_fp16) -> bool:
    """Fine-tune a checkpoint on new data (LR/10, optional EWC); False if it is missing."""
    print("\nStarting FINETUNE mode (for NEW data)...")
    if not os.path.exists(args.model):
        print(f"Error: Model checkpoint not found: {args.model}")
        return False
    model, tokenizer, model_config, _ = load_model(args.model, device)
    train_config = TrainingConfig(
        dataset_path=args.dataset,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=_cli_grad_accum(args),
        max_seq_length=args.max_length,
        learning_rate=args.lr / 10,
        warmup_steps=100,
        use_fp16=torch.cuda.is_available(),
        curriculum_stages=[128, 256, args.max_length]
    )
    dataset = _cli_dataset(train_config.dataset_path, tokenizer, train_config.max_seq_length,
                           use_cot_training, args.cot_ratio)
    print(f"\nStarting fine-tuning with LR={train_config.learning_rate:.2e}...")
    ewc_obj = None
    if not args.no_ewc and args.ewc_lambda > 0:
        print(f"\nComputing EWC Fisher Information "
              f"(lambda={args.ewc_lambda}, samples={args.ewc_samples})...")
        print(" This takes ~1-2 min on T4. Prevents forgetting old training data.")
        # EWC must estimate parameter importance on the OLD data. Computing it
        # on the new fine-tuning data anchors the weights that data needs,
        # which protects little of the original training.
        ewc_path = args.ewc_dataset or args.dataset
        if args.ewc_dataset is None:
            print(" Warning: --ewc-dataset not set; estimating Fisher on the fine-tuning "
                  "data itself, which does little to protect the original training.")
        old_dataset = _cli_dataset(ewc_path, tokenizer, train_config.max_seq_length,
                                   use_cot_training, args.cot_ratio)
        old_loader = DataLoader(
            old_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
            num_workers=0
        )
        train_config.ewc_lambda = args.ewc_lambda
        train_config.ewc_samples = args.ewc_samples
        ewc_obj = EWC(model, old_loader, device, n_samples=args.ewc_samples)
        print(f" Fisher computed. EWC penalty will be added during training.")
    else:
        print(" EWC disabled — model may forget previous training.")
    model = train_model(
        model, dataset, train_config, device,
        use_simple_curriculum=args.simple_curriculum,
        ewc=ewc_obj,
    )
    if not args.no_eval:
        evaluate_model(model, dataset, device)
    finetuned_path = _cli_finetuned_path(args.model)
    save_model(model, model_config, _cli_tokenizer_name(tokenizer),
               finetuned_path, use_fp16=save_fp16)
    print(f"\nFine-tuning completed. Model saved to: {finetuned_path}")
    print(f"To chat: python {__file__} --chat --model {finetuned_path}")
    return True


def _cli_continue_training(args, device, use_cot_training, save_fp16) -> bool:
    """Resume training a checkpoint on the same data at 5% LR; False if it is missing."""
    print("\nStarting CONTINUE-TRAIN mode...")
    print("="*80)
    print("NOTE: For a 15-30M param model on Indonesian CoT, perplexity 5-15 is normal.")
    print("Key improvements:")
    print(" 1. Early curriculum stages skipped → no wasted time on short sequences")
    print(" 2. Plateau LR reduction → auto-halves LR if perplexity stalls")
    print(" 3. Micro LR on cold Adam → no destructive spike at restart")
    print("="*80)
    if not os.path.exists(args.model):
        print(f"Error: Model checkpoint not found: {args.model}")
        return False
    model, tokenizer, model_config, _ = load_model(args.model, device)
    effective_lr = args.lr * 0.05
    print(f"\nContinue-train LR: {args.lr:.2e} x 0.05 = {effective_lr:.2e}")
    print(f" Adam cold-starts — micro LR prevents overshooting the trained minimum.")
    curriculum_stages = [192, 320, args.max_length]
    if args.simple_curriculum:
        effective_skip = len(curriculum_stages) - 1
    else:
        effective_skip = args.skip_stages
    print(f"Skipping first {effective_skip} curriculum stage(s) — "
          f"training on full-length data only.")
    # --plateau-patience was previously parsed but ignored; 2 stays the default.
    plateau_patience = args.plateau_patience if args.plateau_patience is not None else 2
    train_config = TrainingConfig(
        dataset_path=args.dataset,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=_cli_grad_accum(args),
        max_seq_length=args.max_length,
        learning_rate=effective_lr,
        warmup_steps=0,
        use_fp16=torch.cuda.is_available(),
        curriculum_stages=curriculum_stages,
        skip_curriculum_stages=effective_skip,
        plateau_patience=plateau_patience,
        plateau_factor=0.5,
        plateau_min_delta=0.02,
    )
    dataset = _cli_dataset(train_config.dataset_path, tokenizer, train_config.max_seq_length,
                           use_cot_training, args.cot_ratio)
    model = train_model(
        model, dataset, train_config, device,
        use_simple_curriculum=args.simple_curriculum,
        is_continue=True,
        skip_curriculum_stages=effective_skip,
    )
    if not args.no_eval:
        print("\nEvaluating continued model...")
        evaluate_model(model, dataset, device)
    save_model(model, model_config, _cli_tokenizer_name(tokenizer),
               args.model, use_fp16=save_fp16)
    print(f"\nContinued training completed.")
    print(f"Model saved to: {args.model}")
    print(f"To chat: python {__file__} --chat --model {args.model}")
    return True


def main():
    """CLI entry point: parse arguments and dispatch to the selected mode(s).

    ``--inspect-data``, ``--chat`` and ``--benchmark`` are exclusive and
    return right after running, checked in that order. ``--train``,
    ``--finetune`` and ``--continue-train`` run in sequence when combined.
    A missing checkpoint in ``--finetune`` stops everything after it.
    """
    parser = _cli_build_parser()
    args = parser.parse_args()
    if not any([args.train, args.chat, args.finetune, args.continue_train,
                args.inspect_data, args.benchmark]):
        parser.print_help()
        print("\nError: Specify a mode: --train, --chat, --finetune, --continue-train, "
              "--inspect-data, or --benchmark")
        return
    save_fp16 = not args.save_fp32
    use_cot_training = not args.no_cot
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nDevice: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")

    if args.inspect_data:
        _cli_inspect_data(args, use_cot_training)
        return

    if args.chat:
        print("\nStarting CHAT mode...")
        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return
        model, tokenizer, _, _ = load_model(args.model, device)
        interactive_chat(model, tokenizer, device, system_prompt=args.system_prompt)
        return

    if args.benchmark:
        print("\nRunning benchmark...")
        if not os.path.exists(args.model):
            print(f"Error: Model not found: {args.model}")
            return
        model, tokenizer, _, _ = load_model(args.model, device)
        run_benchmark(model, tokenizer, device, dataset_path=args.dataset)
        return

    if args.train:
        _cli_train_from_scratch(args, device, use_cot_training, save_fp16)
    if args.finetune and not _cli_finetune(args, device, use_cot_training, save_fp16):
        return
    if args.continue_train:
        _cli_continue_training(args, device, use_cot_training, save_fp16)


if __name__ == "__main__":
    main()