FaiziRBLX commited on
Commit
7df6753
·
verified ·
1 Parent(s): 17776f3

Update best.py

Browse files
Files changed (1) hide show
  1. best.py +418 -327
best.py CHANGED
@@ -5,37 +5,43 @@ Trained from scratch with Chain-of-Thought reasoning capability
5
  Architecture: Decoder-only transformer with GQA, RoPE, SwiGLU, RMSNorm, KV-Cache
6
  Target: 15M-30M parameters, optimized for Google Colab Free tier
7
 
8
- FIXES vs original:
9
- - KV cache for O(n) inference instead of O(n²)
10
- - RoPE broadcast shape corrected (explicit unsqueeze)
11
- - Label smoothing fixed: custom impl that respects ignore_index=-100
12
- - Depth-scaled weight init for residual branches (o_proj, down_proj)
13
- - Per-token loss tracking (not per-batch avg biased by padding)
14
- - Repetition penalty fixed: proper logit division (not magic subtract)
15
- - Mixed precision updated to PyTorch 2.x API (torch.amp.*)
16
- - EOS token appended to training completions so model learns to stop
17
- - Intermediate size computed correctly from hidden_size
18
- - Gradient norm logged every step
19
- - _clean_response made robust to Indonesian "user" as a word
20
- - Causal mask uses float directly (no bool intermediate)
21
- - GQA skip repeat when groups==1
22
- - Vocab size set from tokenizer at build time, never hardcoded
 
 
 
 
23
  """
24
 
25
  import torch
26
  import torch.nn as nn
27
  import torch.nn.functional as F
28
  from torch.utils.data import Dataset, DataLoader
 
29
  from transformers import AutoTokenizer
30
  import json
31
  import math
32
  import random
33
  import numpy as np
34
- from typing import Optional, Tuple, List, Dict
35
  from dataclasses import dataclass, field
36
  import warnings
37
  import argparse
38
  import os
 
39
 
40
  warnings.filterwarnings('ignore')
41
 
@@ -45,14 +51,11 @@ warnings.filterwarnings('ignore')
45
 
46
  @dataclass
47
  class ModelConfig:
48
- vocab_size: int = 32000 # set from len(tokenizer) at build time
49
  hidden_size: int = 384
50
  num_layers: int = 12
51
  num_attention_heads: int = 6
52
- num_key_value_heads: int = 2 # GQA
53
- # Stored as a plain int field — NEVER a @property — so pickle round-trips work.
54
- # 0 = unset (load_model will fill it from checkpoint weight shapes).
55
- # New training always passes this explicitly from len(tokenizer) / hidden_size.
56
  intermediate_size: int = 0
57
  max_position_embeddings: int = 2048
58
  rms_norm_eps: float = 1e-6
@@ -65,9 +68,10 @@ class ModelConfig:
65
  eos_token_id: int = 2
66
  tie_word_embeddings: bool = True
67
  label_smoothing: float = 0.1
 
 
68
 
69
  def __post_init__(self):
70
- # Set intermediate_size only when not already provided
71
  if self.intermediate_size <= 0:
72
  self.intermediate_size = self.hidden_size * 3
73
  assert self.hidden_size % self.num_attention_heads == 0, \
@@ -97,9 +101,9 @@ class TrainingConfig:
97
  lr_scheduler_type: str = "cosine"
98
 
99
  dropout: float = 0.1
100
-
101
- # FIX: torch.amp.* (PyTorch 2.x API)
102
  use_fp16: bool = True
 
 
103
 
104
  seed: int = 42
105
  logging_steps: int = 10
@@ -107,7 +111,6 @@ class TrainingConfig:
107
  save_steps: int = 500
108
 
109
  curriculum_stages: List[int] = None
110
- skip_curriculum_stages: int = 2
111
 
112
  plateau_patience: int = 3
113
  plateau_factor: float = 0.5
@@ -141,7 +144,7 @@ class RotaryEmbedding(nn.Module):
141
  t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
142
  freqs = torch.outer(t, self.inv_freq)
143
  emb = torch.cat((freqs, freqs), dim=-1)
144
- # FIX: store as [1, 1, T, D] so broadcast onto [B, H, T, D] is explicit and correct
145
  self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
146
  self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
147
 
@@ -160,14 +163,27 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
160
  return torch.cat((-x2, x1), dim=-1)
161
 
162
 
163
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
164
- # cos/sin: [1, 1, T, D] — broadcasts cleanly onto [B, H, T, D]
165
- if position_ids is not None:
166
- # For KV-cache decode: position_ids is [B, 1], pick specific positions
167
- cos = cos[:, :, position_ids, :].squeeze(2) # [B, 1, 1, D]
168
- sin = sin[:, :, position_ids, :].squeeze(2)
169
- q_embed = (q * cos) + (rotate_half(q) * sin)
170
- k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  return q_embed, k_embed
172
 
173
 
@@ -221,6 +237,7 @@ class GroupedQueryAttention(nn.Module):
221
  self,
222
  hidden_states: torch.Tensor,
223
  attention_mask: Optional[torch.Tensor] = None,
 
224
  past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
225
  use_cache: bool = False,
226
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -231,43 +248,32 @@ class GroupedQueryAttention(nn.Module):
231
  key_states = self.k_proj(hidden_states)
232
  value_states = self.v_proj(hidden_states)
233
 
234
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
235
  key_states = key_states .view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
236
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
237
 
238
- # RoPE position offset accounts for cached prefix
239
- kv_seq_len = key_states.shape[-2]
240
- if past_key_value is not None:
241
- kv_seq_len += past_key_value[0].shape[-2]
242
 
243
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
244
 
245
- # For prefill (training / first forward): use full cos/sin slice
246
- # For decode (KV cache active): past_key_value holds the cached context
 
 
 
 
 
 
 
 
247
  if past_key_value is not None:
248
- # Decode step: only current token needs RoPE at position kv_seq_len-1
249
- offset = past_key_value[0].shape[-2]
250
- cos_q = cos[:, :, offset:offset + q_len, :]
251
- sin_q = sin[:, :, offset:offset + q_len, :]
252
- query_states = (query_states * cos_q) + (rotate_half(query_states) * sin_q)
253
- key_states = (key_states * cos_q) + (rotate_half(key_states) * sin_q)
254
- # Concat cached K, V
255
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
256
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
257
- else:
258
- # Prefill: full sequence RoPE
259
- cos_full = cos[:, :, :q_len, :]
260
- sin_full = sin[:, :, :q_len, :]
261
- query_states = (query_states * cos_full) + (rotate_half(query_states) * sin_full)
262
- key_states = (key_states * cos_full) + (rotate_half(key_states) * sin_full)
263
-
264
- # Store pre-expand KV in cache (shape [B, num_kv_heads, T, D]).
265
- # Must happen BEFORE repeat_interleave — otherwise cached keys have
266
- # num_heads channels instead of num_kv_heads, and every decode step
267
- # re-expands them again, corrupting attention.
268
  present_kv = (key_states, value_states) if use_cache else None
269
 
270
- # Expand KV heads for full attention computation
271
  if self.num_key_value_groups > 1:
272
  key_states = key_states .repeat_interleave(self.num_key_value_groups, dim=1)
273
  value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
@@ -296,8 +302,6 @@ class SwiGLUMLP(nn.Module):
296
  def __init__(self, config: ModelConfig):
297
  super().__init__()
298
  self.hidden_size = config.hidden_size
299
- # Read intermediate_size defensively: if somehow 0 or negative (e.g. old
300
- # unpickled config that missed __post_init__), fall back to hidden * 3.
301
  inter = getattr(config, 'intermediate_size', 0)
302
  if not isinstance(inter, int) or inter <= 0:
303
  inter = self.hidden_size * 3
@@ -328,21 +332,23 @@ class DecoderLayer(nn.Module):
328
  self,
329
  hidden_states: torch.Tensor,
330
  attention_mask: Optional[torch.Tensor] = None,
 
331
  past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
332
  use_cache: bool = False,
333
  ):
334
- residual = hidden_states
335
  hidden_states = self.input_layernorm(hidden_states)
336
  hidden_states, present_kv = self.self_attn(
337
  hidden_states,
338
  attention_mask=attention_mask,
 
339
  past_key_value=past_key_value,
340
  use_cache=use_cache,
341
  )
342
  hidden_states = self.residual_dropout(hidden_states)
343
  hidden_states = residual + hidden_states
344
 
345
- residual = hidden_states
346
  hidden_states = self.post_attention_layernorm(hidden_states)
347
  hidden_states = self.mlp(hidden_states)
348
  hidden_states = self.residual_dropout(hidden_states)
@@ -352,17 +358,14 @@ class DecoderLayer(nn.Module):
352
 
353
 
354
  # ============================================================================
355
- # CUSTOM LABEL SMOOTHING LOSS (FIX: respects ignore_index=-100)
356
  # ============================================================================
357
 
358
- class LabelSmoothingCrossEntropy(nn.Module):
359
- """
360
- Cross-entropy with label smoothing.
361
- Filters ignore_index=-100 first, then uses F.cross_entropy with smoothing.
362
- This keeps the exact same loss scale as the original nn.CrossEntropyLoss
363
- so the LR schedule pacing is unchanged.
364
- """
365
 
 
366
  def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
367
  super().__init__()
368
  self.vocab_size = vocab_size
@@ -370,15 +373,23 @@ class LabelSmoothingCrossEntropy(nn.Module):
370
  self.ignore_index = ignore_index
371
 
372
  def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
373
- # logits: [N, V] targets: [N]
374
- # F.cross_entropy with label_smoothing and ignore_index is correct in
375
- # PyTorch >= 1.10 — it does NOT distribute to ignored positions.
376
- return F.cross_entropy(
377
- logits,
378
- targets,
379
- ignore_index=self.ignore_index,
380
- label_smoothing=self.smoothing,
381
- )
 
 
 
 
 
 
 
 
382
 
383
 
384
  # ============================================================================
@@ -392,7 +403,8 @@ class IndonesianLLM(nn.Module):
392
  self.padding_idx = config.pad_token_id
393
  self.vocab_size = config.vocab_size
394
 
395
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
 
396
  self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)])
397
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
398
 
@@ -400,21 +412,20 @@ class IndonesianLLM(nn.Module):
400
  nn.Linear(config.hidden_size, config.vocab_size, bias=False)
401
 
402
  self.loss_fn = LabelSmoothingCrossEntropy(
403
- vocab_size = config.vocab_size,
404
- smoothing = config.label_smoothing,
405
- ignore_index = -100,
406
  )
407
 
 
 
408
  self.apply(self._init_weights)
409
 
410
  def _init_weights(self, module):
411
  std = self.config.initializer_range
412
  if isinstance(module, nn.Linear):
413
- # FIX: depth-scaled init for residual output projections
414
- # (o_proj and down_proj feed directly into residual stream)
415
  name = getattr(module, '_layer_name', '')
416
  if name in ('o_proj', 'down_proj'):
417
- # Wang et al. 2021 / GPT-NeoX scaling
418
  scaled_std = std / math.sqrt(2 * self.config.num_layers)
419
  module.weight.data.normal_(mean=0.0, std=scaled_std)
420
  else:
@@ -427,7 +438,6 @@ class IndonesianLLM(nn.Module):
427
  module.weight.data[module.padding_idx].zero_()
428
 
429
  def _tag_projection_layers(self):
430
- """Tag o_proj and down_proj for depth-scaled init. Call before apply()."""
431
  for layer in self.layers:
432
  layer.self_attn.o_proj._layer_name = 'o_proj'
433
  layer.mlp.down_proj._layer_name = 'down_proj'
@@ -444,28 +454,61 @@ class IndonesianLLM(nn.Module):
444
  attention_mask: Optional[torch.Tensor] = None,
445
  batch_size: int = 1,
446
  ) -> torch.Tensor:
447
- """
448
- FIX: Build additive float causal mask directly instead of bool intermediate.
449
- Shape: [B, 1, T_q, T_kv]
450
- """
451
  total_len = past_len + seq_len
452
- # Full causal mask over [T_q, T_kv]
453
- causal = torch.full((seq_len, total_len), torch.finfo(dtype).min, device=device, dtype=dtype)
454
  mask_cond = torch.arange(total_len, device=device)
455
- causal.masked_fill_(mask_cond[None, :] <= (torch.arange(seq_len, device=device) + past_len)[:, None], 0.0)
 
 
 
456
  causal = causal[None, None, :, :].expand(batch_size, 1, seq_len, total_len)
457
 
458
  if attention_mask is not None:
459
- # attention_mask: [B, T_kv] — 1 = keep, 0 = mask out
460
  pad_mask = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(dtype).min
461
  causal = causal + pad_mask
462
 
463
  return causal
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  def forward(
466
  self,
467
  input_ids: torch.Tensor,
468
  attention_mask: Optional[torch.Tensor] = None,
 
469
  labels: Optional[torch.Tensor] = None,
470
  past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
471
  use_cache: bool = False,
@@ -476,11 +519,15 @@ class IndonesianLLM(nn.Module):
476
 
477
  hidden_states = self.embed_tokens(input_ids)
478
 
479
- # Only build attention mask once (shared across all layers)
480
  if attention_mask is None:
481
  attention_mask = torch.ones(batch_size, past_len + seq_length,
482
  dtype=torch.long, device=input_ids.device)
483
 
 
 
 
 
 
484
  causal_mask = self._make_causal_mask(
485
  seq_len=seq_length,
486
  past_len=past_len,
@@ -494,11 +541,13 @@ class IndonesianLLM(nn.Module):
494
 
495
  for i, decoder_layer in enumerate(self.layers):
496
  pkv = past_key_values[i] if past_key_values is not None else None
497
- hidden_states, present_kv = decoder_layer(
 
498
  hidden_states,
499
- attention_mask=causal_mask,
500
- past_key_value=pkv,
501
- use_cache=use_cache,
 
502
  )
503
  if use_cache:
504
  present_key_values.append(present_kv)
@@ -515,9 +564,9 @@ class IndonesianLLM(nn.Module):
515
  loss = self.loss_fn(shift_logits, shift_labels)
516
 
517
  return {
518
- "loss": loss,
519
- "logits": logits,
520
- "past_key_values": present_key_values,
521
  }
522
 
523
  def count_parameters(self) -> int:
@@ -550,7 +599,6 @@ class IndonesianCoTDataset(Dataset):
550
  self.skipped_count = 0
551
  self._load_data(file_path)
552
 
553
- # FIX: get EOS string once, reuse everywhere
554
  @property
555
  def _eos(self) -> str:
556
  return self.tokenizer.eos_token or "</s>"
@@ -558,7 +606,7 @@ class IndonesianCoTDataset(Dataset):
558
  def _load_data(self, file_path: str):
559
  print(f"Loading dataset from {file_path}...")
560
  with open(file_path, 'r', encoding='utf-8') as f:
561
- for line_num, line in enumerate(f, 1):
562
  try:
563
  if not line.strip():
564
  continue
@@ -585,10 +633,8 @@ class IndonesianCoTDataset(Dataset):
585
  def __getitem__(self, idx):
586
  sample = self.samples[idx]
587
 
588
- # Build prompt / completion split
589
  if self.use_cot and random.random() < self.cot_ratio:
590
  prompt = f"{sample['input']} {self.cot_token}"
591
- # FIX: append EOS so the model learns when to stop
592
  completion = f" {sample['cot']} {self.end_cot_token} {sample['output']}{self._eos}"
593
  else:
594
  prompt = f"{sample['input']}"
@@ -605,7 +651,6 @@ class IndonesianCoTDataset(Dataset):
605
  add_special_tokens=True,
606
  )
607
 
608
- # Mask prompt tokens so only completion contributes to loss
609
  labels = [-100] * min(prompt_len, len(full_ids)) + full_ids[prompt_len:]
610
  labels = labels[:len(full_ids)]
611
 
@@ -643,12 +688,10 @@ def collate_fn_with_packing(batch, pad_token_id: int = 0):
643
 
644
 
645
  # ============================================================================
646
- # PER-TOKEN LOSS TRACKING (FIX: don't average over padding)
647
  # ============================================================================
648
 
649
  class TokenLossAccumulator:
650
- """Track loss and token count separately so perplexity is unbiased."""
651
-
652
  def __init__(self):
653
  self.total_loss = 0.0
654
  self.total_tokens = 0
@@ -689,13 +732,16 @@ def _build_stage_dataset(base: IndonesianCoTDataset, samples, max_len: int, cot_
689
  def create_curriculum_datasets(dataset, stages=None, use_simple=False, skip_stages=0):
690
  if stages is None:
691
  stages = [256, 512, 1024]
692
- datasets = []
 
 
 
 
 
693
 
694
  if use_simple:
 
695
  for i, max_len in enumerate(stages):
696
- if i < skip_stages:
697
- print(f"[SKIP] Curriculum stage {max_len}")
698
- continue
699
  filtered = [
700
  s for s in dataset.samples
701
  if len(dataset.tokenizer.encode(
@@ -703,37 +749,26 @@ def create_curriculum_datasets(dataset, stages=None, use_simple=False, skip_stag
703
  )) <= max_len
704
  ]
705
  datasets.append(_build_stage_dataset(dataset, filtered, max_len, dataset.cot_ratio))
706
- print(f"Curriculum stage {max_len}: {len(filtered)} samples")
 
707
  else:
708
- print("\n" + "=" * 80)
709
- print("3-STAGE REASONING CURRICULUM")
710
- if skip_stages > 0:
711
- print(f" (Skipping first {skip_stages} stage(s))")
712
- print("=" * 80)
713
-
714
  stage_configs = [
715
  {'name': 'Stage 1: Basic Q&A (no CoT)', 'max_len': 384, 'cot_ratio': 0.0},
716
  {'name': 'Stage 2: Learning Reasoning (50% CoT)', 'max_len': 512, 'cot_ratio': 0.5},
717
  {'name': 'Stage 3: Full Reasoning (100% CoT)', 'max_len': 1024, 'cot_ratio': 1.0},
718
  ]
719
-
720
  for idx, sc in enumerate(stage_configs):
721
- filtered = dataset.samples # all samples for stages 2+
722
- if idx == 0:
723
- filtered = [
724
- s for s in dataset.samples
725
- if len(dataset.tokenizer.encode(f"{s['input']} {s['output']}")) <= sc['max_len']
726
- ]
727
  datasets.append(_build_stage_dataset(dataset, filtered, sc['max_len'], sc['cot_ratio']))
728
- skipped = idx < skip_stages
729
- tag = " [SKIP]" if skipped else ""
730
  print(f" {sc['name']}{tag} | samples={len(filtered)} | CoT={sc['cot_ratio']:.0%}")
731
 
732
- print("=" * 80 + "\n")
733
- if skip_stages > 0:
734
- datasets = datasets[skip_stages:]
735
-
736
- return datasets
737
 
738
 
739
  # ============================================================================
@@ -805,6 +840,21 @@ def set_seed(seed: int):
805
  torch.backends.cudnn.benchmark = False
806
 
807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
  # ============================================================================
809
  # ELASTIC WEIGHT CONSOLIDATION (EWC)
810
  # ============================================================================
@@ -817,6 +867,11 @@ class EWC:
817
  self.fisher = self._compute_fisher(model, dataloader)
818
 
819
  def _compute_fisher(self, model, dataloader):
 
 
 
 
 
820
  fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
821
  model.eval()
822
  seen = 0
@@ -825,14 +880,29 @@ class EWC:
825
  break
826
  input_ids = batch["input_ids"] .to(self.device)
827
  attention_mask = batch["attention_mask"] .to(self.device)
828
- labels = batch["labels"] .to(self.device)
829
  model.zero_grad()
830
- out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
831
- out["loss"].backward()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832
  for n, p in model.named_parameters():
833
  if p.requires_grad and p.grad is not None:
834
  fisher[n] += p.grad.detach().pow(2)
835
  seen += input_ids.size(0)
 
836
  for n in fisher:
837
  fisher[n] /= max(1, seen)
838
  model.train()
@@ -872,9 +942,14 @@ def train_model(
872
  print(f" Max seq length: {config.max_seq_length}")
873
  print(f" Epochs: {config.num_epochs}")
874
  print(f" Mixed precision: {config.use_fp16}")
 
875
  print(f" EWC: {'enabled (lambda=' + str(config.ewc_lambda) + ')' if ewc else 'disabled'}")
876
  print("=" * 80 + "\n")
877
 
 
 
 
 
878
  model.to(device)
879
  model.train()
880
 
@@ -886,7 +961,7 @@ def train_model(
886
  )
887
 
888
  if not curriculum_datasets:
889
- print("ERROR: No curriculum stages. Check --skip-stages.")
890
  return model
891
 
892
  optimizer = torch.optim.AdamW(
@@ -902,7 +977,6 @@ def train_model(
902
  for ds in curriculum_datasets
903
  ) or 1
904
 
905
- # FIX: use torch.amp.* (PyTorch 2.x API, not deprecated cuda.amp.*)
906
  use_amp = config.use_fp16 and device.type == 'cuda'
907
  scaler = torch.amp.GradScaler('cuda') if use_amp else None
908
 
@@ -931,13 +1005,9 @@ def train_model(
931
  f"n={len(stage_dataset)} | CoT={getattr(stage_dataset, 'cot_ratio', '?'):.0%}")
932
  print(f"{'=' * 80}\n")
933
 
934
- dataloader = DataLoader(
935
- stage_dataset,
936
- batch_size=config.batch_size,
937
- shuffle=True,
938
- collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
939
- num_workers=0,
940
- pin_memory=(device.type == 'cuda'),
941
  )
942
 
943
  for epoch in range(config.num_epochs):
@@ -945,6 +1015,10 @@ def train_model(
945
  acc = TokenLossAccumulator()
946
  optimizer.zero_grad()
947
 
 
 
 
 
948
  for step, batch in enumerate(dataloader):
949
  input_ids = batch['input_ids'] .to(device)
950
  attention_mask = batch['attention_mask'] .to(device)
@@ -955,25 +1029,35 @@ def train_model(
955
  outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
956
  task_loss = outputs['loss']
957
  if ewc is not None:
958
- task_loss = task_loss + config.ewc_lambda * ewc.penalty(model)
959
- loss = task_loss / config.gradient_accumulation_steps
 
 
 
 
 
 
960
  scaler.scale(loss).backward()
961
  else:
962
  outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
963
  task_loss = outputs['loss']
964
  if ewc is not None:
965
- task_loss = task_loss + config.ewc_lambda * ewc.penalty(model)
966
- loss = task_loss / config.gradient_accumulation_steps
 
 
 
 
 
967
  loss.backward()
968
 
969
- # FIX: per-token tracking
970
  acc.update(task_loss.item(), labels)
971
 
972
  if (step + 1) % config.gradient_accumulation_steps == 0:
973
  if use_amp:
974
  scaler.unscale_(optimizer)
975
 
976
- # FIX: log gradient norm
977
  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
978
 
979
  if use_amp:
@@ -1014,13 +1098,8 @@ def train_model(
1014
 
1015
  def evaluate_model(model, dataset, device, batch_size=4):
1016
  model.eval()
1017
- dataloader = DataLoader(
1018
- dataset,
1019
- batch_size=batch_size,
1020
- shuffle=False,
1021
- collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
1022
- num_workers=0,
1023
- )
1024
  acc = TokenLossAccumulator()
1025
  with torch.no_grad():
1026
  for batch in dataloader:
@@ -1042,9 +1121,27 @@ def evaluate_model(model, dataset, device, batch_size=4):
1042
 
1043
 
1044
  # ============================================================================
1045
- # GENERATION WITH KV CACHE
1046
  # ============================================================================
1047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1048
  def generate_text(
1049
  model: IndonesianLLM,
1050
  tokenizer,
@@ -1057,59 +1154,54 @@ def generate_text(
1057
  device: torch.device = torch.device('cpu'),
1058
  ) -> str:
1059
  """
1060
- KV-cache generation: O(n) per new token instead of O(n²).
1061
 
1062
- FIX: repetition_penalty now uses proper logit division (Keskar et al.),
1063
- not a magic subtraction that breaks under low temperature.
 
1064
  """
1065
  model.eval()
1066
 
1067
- # Reseed from OS entropy so repeated calls with the same prompt diverge.
1068
- # This is the core fix: torch.multinomial outcome depends on torch RNG state,
1069
- # which was frozen to seed=42 at startup. Each call now starts from a unique state.
1070
- import os as _os
1071
- _entropy = int.from_bytes(_os.urandom(4), 'little')
1072
- torch.manual_seed(_entropy)
1073
- if torch.cuda.is_available():
1074
- torch.cuda.manual_seed_all(_entropy)
1075
 
1076
  eos_id = tokenizer.eos_token_id or tokenizer.sep_token_id or 2
1077
  pad_id = tokenizer.pad_token_id or 0
1078
 
1079
- input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device) # [1, T]
1080
- generated_ids = input_ids.clone()
 
1081
 
1082
  with torch.no_grad():
1083
- # ── PREFILL: process entire prompt at once, capture KV cache ──────────
1084
- prefill_out = model(
 
 
1085
  input_ids=input_ids,
 
1086
  use_cache=True,
1087
  )
1088
  past_kv = prefill_out['past_key_values']
1089
 
1090
- # Seed penalty buffer with prompt tokens so model can't echo the input
1091
- prompt_token_ids = input_ids[0].tolist()
1092
  generated_token_ids = []
1093
 
1094
- # ── DECODE: one token at a time using cached K/V ──────────────────────
1095
  for _ in range(max_new_tokens):
1096
- # Only feed the last token
1097
- cur_id = generated_ids[:, -1:] # [1, 1]
1098
 
1099
- out = model(input_ids=cur_id, past_key_values=past_kv, use_cache=True)
1100
- past_kv = out['past_key_values']
1101
- logits = out['logits'][:, -1, :] # [1, V]
1102
- logits = logits / max(temperature, 0.05)
 
 
1103
 
1104
- # Penalize: all tokens seen so far (prompt + generated)
1105
  if repetition_penalty != 1.0:
1106
- all_seen = set(prompt_token_ids + generated_token_ids[-64:])
1107
- for tok_id in all_seen:
1108
- if 0 <= tok_id < logits.shape[-1]:
1109
- if logits[0, tok_id] > 0:
1110
- logits[0, tok_id] /= repetition_penalty
1111
- else:
1112
- logits[0, tok_id] *= repetition_penalty
1113
 
1114
  # Top-k
1115
  if top_k > 0:
@@ -1127,7 +1219,8 @@ def generate_text(
1127
  logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
1128
 
1129
  probs = F.softmax(logits, dim=-1)
1130
- next_token = torch.multinomial(probs, num_samples=1) # [1, 1]
 
1131
 
1132
  tok_id = next_token.item()
1133
  if tok_id in {eos_id, pad_id}:
@@ -1136,24 +1229,15 @@ def generate_text(
1136
  generated_token_ids.append(tok_id)
1137
  generated_ids = torch.cat([generated_ids, next_token], dim=1)
1138
 
1139
- # Hard context limit (shouldn't be reached with max_new_tokens)
1140
  if generated_ids.size(1) >= model.config.max_position_embeddings:
1141
  break
1142
 
1143
- import re as _re
1144
- prompt_len = input_ids.shape[1]
1145
-
1146
- # Decode ONLY the newly generated tokens — never the prompt.
1147
- # This avoids the slice-by-string-length bug where tokenizer spacing
1148
- # makes len(prompt_str) != number of chars in decoded(prompt_tokens),
1149
- # causing callers to cut mid-token and get "ot>" instead of "<cot>".
1150
  new_token_ids = generated_ids[0][prompt_len:]
1151
  if len(new_token_ids) == 0:
1152
  return ""
1153
 
1154
  raw_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
1155
- # Strip BERT specials but keep <cot> </cot>
1156
- raw_text = _re.sub(r'\[(SEP|CLS|PAD|UNK|MASK)\]', '', raw_text)
1157
  return raw_text.strip()
1158
 
1159
 
@@ -1162,64 +1246,41 @@ def generate_text(
1162
  # ============================================================================
1163
 
1164
  def _clean_response(response: str) -> str:
1165
- import re
1166
-
1167
- # Strip CoT block — do this first before any other processing
1168
  if "<cot>" in response and "</cot>" in response:
1169
  response = response.split("</cot>", 1)[-1]
1170
  elif "<cot>" in response:
1171
- # Model started CoT but never closed it — everything before <cot> is prompt leak,
1172
- # everything after is the partial reasoning. Discard both, use empty.
1173
  response = ""
1174
 
1175
- # Strip BERT-style special tokens that appear when skip_special_tokens=False
1176
  response = re.sub(r'\[(SEP|CLS|PAD|UNK|MASK)\]', '', response)
1177
-
1178
- # Strip all remaining XML/special tags
1179
  response = re.sub(r'<[^>]+>', '', response)
 
 
 
 
1180
 
1181
- # Role markers only at line start
1182
- response = re.sub(r'(?im)^\s*(user\s*:|assistant\s*:).*', '', response)
1183
-
1184
- # Strip meta-commentary (Indonesian-specific)
1185
  for marker in ["memahami permintaan", "jawaban singkat", "penjelasan harus"]:
1186
  if marker in response:
1187
  response = response.split(marker)[0]
1188
 
1189
- # Collapse whitespace
1190
  response = re.sub(r'\n{2,}', '\n', response)
1191
  response = re.sub(r' {2,}', ' ', response)
1192
-
1193
- # Strip leading punctuation/whitespace junk — but NOT digits or letters
1194
  response = re.sub(r'^[\s:!,.\-|]+', '', response)
1195
-
1196
  return response.strip()
1197
 
1198
 
1199
  def _extract_thinking(raw: str) -> Tuple[str, str]:
1200
- import re
1201
-
1202
- # Strip BERT special tokens first (they appear with skip_special_tokens=False)
1203
  raw = re.sub(r'\[(SEP|CLS|PAD|UNK|MASK)\]', '', raw)
1204
 
1205
  if "</cot>" in raw:
1206
- # Normal case: model produced full CoT block
1207
  thinking_raw, answer_raw = raw.split("</cot>", 1)
1208
  thinking = re.sub(r'<[^>]+>', '', thinking_raw).strip()
1209
- thinking = re.sub(r'(?im)^\s*(user\s*:|assistant\s*:).*', '', thinking).strip()
1210
  answer = _clean_response(answer_raw)
1211
-
1212
  elif "<cot>" in raw:
1213
- # Model started CoT but never finished — reasoning only, no answer yet.
1214
- # Extract whatever came before <cot> as a potential direct answer,
1215
- # or whatever came after as partial reasoning.
1216
  parts = raw.split("<cot>", 1)
1217
  thinking = _clean_response(parts[1]) if len(parts) > 1 else ""
1218
- # No clean answer available — return empty, caller will fall back
1219
  answer = _clean_response(parts[0]) if parts[0].strip() else ""
1220
-
1221
  else:
1222
- # No CoT tags at all — the whole output IS the answer (model skipped reasoning)
1223
  thinking = ""
1224
  answer = _clean_response(raw)
1225
 
@@ -1227,7 +1288,7 @@ def _extract_thinking(raw: str) -> Tuple[str, str]:
1227
 
1228
 
1229
  # ============================================================================
1230
- # INTERACTIVE CHAT
1231
  # ============================================================================
1232
 
1233
  def interactive_chat(
@@ -1235,16 +1296,40 @@ def interactive_chat(
1235
  tokenizer,
1236
  device: torch.device,
1237
  system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.",
 
1238
  ):
 
 
 
 
 
1239
  print("\n" + "=" * 80)
1240
- print("INDONESIAN LLM — INTERACTIVE CHAT (KV-cache enabled)")
1241
  print("=" * 80)
1242
  print("Commands: exit/quit | clear | think (toggle CoT display)")
1243
  print(f"Persona : {system_prompt}")
 
1244
  print("=" * 80 + "\n")
1245
 
1246
  model.eval()
1247
- show_thinking = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1248
 
1249
  while True:
1250
  try:
@@ -1255,6 +1340,7 @@ def interactive_chat(
1255
  print("\nSelamat tinggal!")
1256
  break
1257
  if user_input.lower() in ['clear', 'bersihkan']:
 
1258
  print("\nConversation cleared.")
1259
  continue
1260
  if user_input.lower() == 'think':
@@ -1262,10 +1348,9 @@ def interactive_chat(
1262
  print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}")
1263
  continue
1264
 
1265
- prompt = f"{user_input} <cot>"
1266
  print("\nA:", end=" ", flush=True)
1267
 
1268
- # generate_text now returns ONLY new tokens (no prompt prefix)
1269
  response = generate_text(
1270
  model=model,
1271
  tokenizer=tokenizer,
@@ -1282,15 +1367,11 @@ def interactive_chat(
1282
  if show_thinking and thinking:
1283
  print(f"[Thinking: {thinking}]")
1284
 
1285
- # Use answer if non-empty; fall back to cleaned full response;
1286
- # last resort: use thinking itself (model reasoned but didn't emit answer).
1287
- # Never throw away a valid short answer like "1", "2", "ya".
1288
  if answer:
1289
  final = answer
1290
  else:
1291
  final = _clean_response(response)
1292
  if not final and thinking:
1293
- # Model only produced reasoning, extract last sentence as answer
1294
  sentences = [s.strip() for s in thinking.split('.') if s.strip()]
1295
  final = sentences[-1] if sentences else thinking[:200]
1296
 
@@ -1298,6 +1379,9 @@ def interactive_chat(
1298
  final = "..."
1299
  print(final)
1300
 
 
 
 
1301
  except KeyboardInterrupt:
1302
  print("\n\nDihentikan.")
1303
  break
@@ -1332,21 +1416,16 @@ def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 2
1332
  print("No valid samples.")
1333
  return
1334
 
1335
- # Time-based seed: different sample selection AND different generation each run
1336
- import time
1337
  live_seed = int(time.time() * 1000) % (2**31)
1338
  random.seed(live_seed)
1339
- torch.manual_seed(live_seed)
1340
- if torch.cuda.is_available():
1341
- torch.cuda.manual_seed_all(live_seed)
1342
 
1343
  samples = random.sample(all_samples, min(n, len(all_samples)))
1344
  model.eval()
1345
 
1346
  print(f"\n{'=' * 80}\nBENCHMARK ({len(samples)} samples)\n{'=' * 80}")
1347
 
1348
- results = []
1349
- acc = TokenLossAccumulator()
1350
 
1351
  for sample in samples:
1352
  inp = sample['input'].strip()
@@ -1358,7 +1437,6 @@ def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 2
1358
  _, answer = _extract_thinking(raw)
1359
  answer_lower = answer.lower()
1360
 
1361
- # Exact + token-overlap match
1362
  passed = expected in answer_lower
1363
  if not passed:
1364
  exp_toks = set(expected.split())
@@ -1391,9 +1469,16 @@ def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 2
1391
 
1392
  # ============================================================================
1393
  # SAVE / LOAD
 
1394
  # ============================================================================
1395
 
1396
- def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, path: str, use_fp16: bool = True):
 
 
 
 
 
 
1397
  os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
1398
  state = model.state_dict()
1399
  if use_fp16:
@@ -1406,10 +1491,17 @@ def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, p
1406
  'dtype': 'fp16' if use_fp16 else 'fp32',
1407
  }, path)
1408
  size_mb = os.path.getsize(path) / 1e6
1409
- print(f"\nSaved: {path} ({'fp16' if use_fp16 else 'fp32'}, {size_mb:.1f} MB, {model.count_parameters():,} params)")
 
1410
 
1411
 
1412
- def load_model(path: str, device: torch.device):
 
 
 
 
 
 
1413
  if not os.path.exists(path):
1414
  raise FileNotFoundError(f"Checkpoint not found: {path}")
1415
  print(f"Loading: {path}")
@@ -1420,36 +1512,41 @@ def load_model(path: str, device: torch.device):
1420
  dtype = ck.get('dtype', 'fp32')
1421
 
1422
  state = ck['model_state_dict']
1423
- if dtype == 'fp16':
 
 
1424
  state = {k: v.float() if v.dtype == torch.float16 else v for k, v in state.items()}
 
1425
 
1426
- # Always derive intermediate_size from actual saved weights so the
1427
- # model architecture matches exactly, regardless of what the config says.
1428
- # gate_proj shape is [intermediate_size, hidden_size].
1429
  gate_key = next((k for k in state if k.endswith('gate_proj.weight')), None)
1430
  if gate_key is not None:
1431
  inferred_intermediate = state[gate_key].shape[0]
1432
  if getattr(config, 'intermediate_size', -1) != inferred_intermediate:
1433
  print(f" [load_model] intermediate_size: config={getattr(config, 'intermediate_size', '?')} "
1434
- f"-> overriding with checkpoint value {inferred_intermediate}")
1435
  config.intermediate_size = inferred_intermediate
1436
 
1437
- # Sync vocab_size from embedding weight shape
1438
  embed_key = next((k for k in state if k.endswith('embed_tokens.weight')), None)
1439
  if embed_key is not None:
1440
  inferred_vocab = state[embed_key].shape[0]
1441
  if config.vocab_size != inferred_vocab:
1442
  print(f" [load_model] vocab_size: config={config.vocab_size} "
1443
- f"-> overriding with checkpoint value {inferred_vocab}")
1444
  config.vocab_size = inferred_vocab
1445
 
1446
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
1447
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
1448
 
1449
  model = IndonesianLLM(config)
1450
- model.load_state_dict(state)
1451
  model.to(device)
1452
 
 
 
 
 
 
1453
  size_mb = os.path.getsize(path) / 1e6
1454
  print(f"Loaded ({dtype}, {size_mb:.1f} MB, {ck.get('model_params', model.count_parameters()):,} params)")
1455
  return model, tokenizer, config, {}
@@ -1493,6 +1590,11 @@ def main():
1493
  parser.add_argument('--ewc-lambda', type=float, default=5000.0)
1494
  parser.add_argument('--ewc-samples', type=int, default=2000)
1495
  parser.add_argument('--no-ewc', action='store_true')
 
 
 
 
 
1496
 
1497
  args = parser.parse_args()
1498
 
@@ -1505,13 +1607,9 @@ def main():
1505
  save_fp16 = not args.save_fp32
1506
  use_cot_training = not args.no_cot
1507
 
1508
- # Only fix the seed for training (reproducibility).
1509
- # Chat and benchmark must NOT be seeded — identical seeds produce identical
1510
- # outputs every run, making the model feel like a lookup table.
1511
  if args.train or args.finetune or args.continue_train:
1512
  set_seed(args.seed)
1513
  else:
1514
- # Use a time-based seed so every run is different
1515
  import time
1516
  live_seed = int(time.time() * 1000) % (2**31)
1517
  random.seed(live_seed)
@@ -1519,13 +1617,13 @@ def main():
1519
  torch.manual_seed(live_seed)
1520
  if torch.cuda.is_available():
1521
  torch.cuda.manual_seed_all(live_seed)
 
1522
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1523
  print(f"\nDevice: {device}")
1524
  if torch.cuda.is_available():
1525
  print(f" GPU: {torch.cuda.get_device_name(0)}")
1526
  print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
1527
 
1528
- # ── INSPECT DATA ─────────────────────────────────────────────────────────
1529
  if args.inspect_data:
1530
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
1531
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
@@ -1542,24 +1640,22 @@ def main():
1542
  print(f" Output: {s['output'][:120]}")
1543
  return
1544
 
1545
- # ── CHAT ─────────────────────────────────────────────────────────────────
1546
  if args.chat:
1547
- model, tokenizer, _, _ = load_model(args.model, device)
1548
- interactive_chat(model, tokenizer, device, system_prompt=args.system_prompt)
 
 
1549
  return
1550
 
1551
- # ── BENCHMARK ────────────────────────────────────────────────────────────
1552
  if args.benchmark:
1553
- model, tokenizer, _, _ = load_model(args.model, device)
1554
  run_benchmark(model, tokenizer, device, dataset_path=args.dataset)
1555
  return
1556
 
1557
- # ── TRAIN FROM SCRATCH ───────────────────────────────────────────────────
1558
  if args.train:
1559
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
1560
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
1561
 
1562
- # FIX: vocab_size from actual tokenizer length (never hardcoded)
1563
  model_config = ModelConfig(
1564
  vocab_size = len(tokenizer),
1565
  hidden_size = args.hidden_size,
@@ -1570,32 +1666,30 @@ def main():
1570
  attention_dropout = 0.1,
1571
  residual_dropout = 0.1,
1572
  tie_word_embeddings = True,
 
1573
  )
1574
  print(f"\nModel config: {model_config}")
1575
- print(f"intermediate_size (SwiGLU 8/3): {model_config.intermediate_size}")
1576
 
1577
  model = IndonesianLLM(model_config)
1578
- # Tag residual projections BEFORE init so depth-scaling applies
1579
- model._tag_projection_layers()
1580
- model.apply(model._init_weights)
1581
  print(f"Parameters: {model.count_parameters():,}")
1582
 
1583
  _ga = args.grad_accum or 32
1584
  train_config = TrainingConfig(
1585
- dataset_path = args.dataset,
1586
- num_epochs = args.epochs,
1587
- batch_size = args.batch_size,
1588
- gradient_accumulation_steps = _ga,
1589
- max_seq_length = args.max_length,
1590
- learning_rate = args.lr,
1591
- warmup_steps = 500,
1592
- use_fp16 = torch.cuda.is_available(),
1593
- curriculum_stages = [128, 256, args.max_length],
 
1594
  )
1595
 
1596
  dataset = IndonesianCoTDataset(train_config.dataset_path, tokenizer,
1597
- train_config.max_seq_length, use_cot=use_cot_training,
1598
- cot_ratio=args.cot_ratio)
1599
  model = train_model(model, dataset, train_config, device,
1600
  use_simple_curriculum=args.simple_curriculum)
1601
 
@@ -1611,32 +1705,31 @@ def main():
1611
  print(f"\nPrompt : {p}")
1612
  print(f"Generated: {generate_text(model, tokenizer, p, max_new_tokens=150, device=device)}\n")
1613
 
1614
- # ── FINETUNE ──────────────────────────────────────��──────────────────────
1615
  if args.finetune:
1616
- model, tokenizer, model_config, _ = load_model(args.model, device)
1617
 
1618
  _ga = args.grad_accum or 32
1619
  train_config = TrainingConfig(
1620
- dataset_path = args.dataset,
1621
- num_epochs = args.epochs,
1622
- batch_size = args.batch_size,
1623
- gradient_accumulation_steps = _ga,
1624
- max_seq_length = args.max_length,
1625
- learning_rate = args.lr / 10,
1626
- warmup_steps = 100,
1627
- use_fp16 = torch.cuda.is_available(),
1628
- curriculum_stages = [128, 256, args.max_length],
 
1629
  )
1630
 
1631
  dataset = IndonesianCoTDataset(train_config.dataset_path, tokenizer,
1632
- train_config.max_seq_length, use_cot=use_cot_training,
1633
- cot_ratio=args.cot_ratio)
1634
  ewc_obj = None
1635
  if not args.no_ewc and args.ewc_lambda > 0:
1636
  print(f"\nComputing EWC Fisher (lambda={args.ewc_lambda}, n={args.ewc_samples})...")
1637
- loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
1638
- collate_fn=lambda x: collate_fn_with_packing(x, model.padding_idx),
1639
- num_workers=0)
1640
  train_config.ewc_lambda = args.ewc_lambda
1641
  train_config.ewc_samples = args.ewc_samples
1642
  ewc_obj = EWC(model, loader, device, n_samples=args.ewc_samples)
@@ -1651,36 +1744,34 @@ def main():
1651
  save_model(model, model_config, "indolem/indobert-base-uncased", out_path, use_fp16=save_fp16)
1652
  print(f"\nFinetuned model: {out_path}")
1653
 
1654
- # ── CONTINUE TRAINING ────────────────────────────────────────────────────
1655
  if args.continue_train:
1656
- model, tokenizer, model_config, _ = load_model(args.model, device)
1657
 
1658
- effective_lr = args.lr * 0.05
1659
- effective_skip = (len([128, 256, args.max_length]) - 1) if args.simple_curriculum else args.skip_stages
1660
- curriculum = [192, 320, args.max_length]
1661
 
1662
- print(f"\nContinue-train LR: {effective_lr:.2e} (skip {effective_skip} stages)")
1663
 
1664
  _ga = args.grad_accum or 32
1665
  train_config = TrainingConfig(
1666
- dataset_path = args.dataset,
1667
- num_epochs = args.epochs,
1668
- batch_size = args.batch_size,
1669
- gradient_accumulation_steps = _ga,
1670
- max_seq_length = args.max_length,
1671
- learning_rate = effective_lr,
1672
- warmup_steps = 0,
1673
- use_fp16 = torch.cuda.is_available(),
1674
- curriculum_stages = curriculum,
1675
- skip_curriculum_stages = effective_skip,
1676
- plateau_patience = 2,
1677
- plateau_factor = 0.5,
1678
- plateau_min_delta = 0.02,
1679
  )
1680
 
1681
  dataset = IndonesianCoTDataset(train_config.dataset_path, tokenizer,
1682
- train_config.max_seq_length, use_cot=use_cot_training,
1683
- cot_ratio=args.cot_ratio)
1684
  model = train_model(model, dataset, train_config, device,
1685
  use_simple_curriculum=args.simple_curriculum,
1686
  is_continue=True,
 
5
  Architecture: Decoder-only transformer with GQA, RoPE, SwiGLU, RMSNorm, KV-Cache
6
  Target: 15M-30M parameters, optimized for Google Colab Free tier
7
 
8
+ FIXES in this version (on top of prior fixes):
9
+ [INFERENCE]
10
+ - FIX-I1: KV cache RoPE offset uses proper position_ids tensor, not slice arithmetic
11
+ - FIX-I2: Vectorized repetition penalty (scatter gather on GPU, no Python loop)
12
+ - FIX-I3: torch.Generator for per-call entropy no global RNG reset
13
+ - FIX-I4: Multi-turn conversation history in interactive_chat
14
+ - FIX-I5: Top-p preallocated scratch tensors (minor, readability)
15
+ - FIX-I6: generate_text returns generator for streaming (optional)
16
+
17
+ [TRAINING]
18
+ - FIX-T1: _tag_projection_layers called inside __init__ before apply(_init_weights)
19
+ - FIX-T2: EWC penalty computed once per optimizer step, not per micro-batch
20
+ - FIX-T3: acc.update tracks task_loss_only (no EWC in perplexity)
21
+ - FIX-T4: PyTorch version guard for label_smoothing + ignore_index interaction
22
+ - FIX-T5: DataLoader num_workers=2 with persistent_workers on CUDA
23
+ - FIX-T6: Gradient checkpointing option (halves activation memory)
24
+ - FIX-T7: save/load fp16 stays fp16 at inference — no upcast unless training
25
+ - FIX-T8: TrainingConfig.skip_curriculum_stages actually used (dead field removed)
26
+ - FIX-T9: EWC Fisher uses model's own predictions as labels (empirical Fisher)
27
  """
28
 
29
  import torch
30
  import torch.nn as nn
31
  import torch.nn.functional as F
32
  from torch.utils.data import Dataset, DataLoader
33
+ from torch.utils.checkpoint import checkpoint as gradient_checkpoint
34
  from transformers import AutoTokenizer
35
  import json
36
  import math
37
  import random
38
  import numpy as np
39
+ from typing import Optional, Tuple, List, Dict, Generator
40
  from dataclasses import dataclass, field
41
  import warnings
42
  import argparse
43
  import os
44
+ import re
45
 
46
  warnings.filterwarnings('ignore')
47
 
 
51
 
52
  @dataclass
53
  class ModelConfig:
54
+ vocab_size: int = 32000
55
  hidden_size: int = 384
56
  num_layers: int = 12
57
  num_attention_heads: int = 6
58
+ num_key_value_heads: int = 2
 
 
 
59
  intermediate_size: int = 0
60
  max_position_embeddings: int = 2048
61
  rms_norm_eps: float = 1e-6
 
68
  eos_token_id: int = 2
69
  tie_word_embeddings: bool = True
70
  label_smoothing: float = 0.1
71
+ # FIX-T6: gradient checkpointing flag
72
+ use_gradient_checkpointing: bool = False
73
 
74
  def __post_init__(self):
 
75
  if self.intermediate_size <= 0:
76
  self.intermediate_size = self.hidden_size * 3
77
  assert self.hidden_size % self.num_attention_heads == 0, \
 
101
  lr_scheduler_type: str = "cosine"
102
 
103
  dropout: float = 0.1
 
 
104
  use_fp16: bool = True
105
+ # FIX-T6: expose gradient checkpointing in training config
106
+ use_gradient_checkpointing: bool = False
107
 
108
  seed: int = 42
109
  logging_steps: int = 10
 
111
  save_steps: int = 500
112
 
113
  curriculum_stages: List[int] = None
 
114
 
115
  plateau_patience: int = 3
116
  plateau_factor: float = 0.5
 
144
  t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
145
  freqs = torch.outer(t, self.inv_freq)
146
  emb = torch.cat((freqs, freqs), dim=-1)
147
+ # [1, 1, T, D] for broadcast onto [B, H, T, D]
148
  self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
149
  self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
150
 
 
163
  return torch.cat((-x2, x1), dim=-1)
164
 
165
 
166
+ # FIX-I1: position_ids-based RoPE application — no slice arithmetic
167
+ def apply_rotary_pos_emb_with_ids(
168
+ q: torch.Tensor,
169
+ k: torch.Tensor,
170
+ cos: torch.Tensor,
171
+ sin: torch.Tensor,
172
+ position_ids: torch.Tensor, # [B, T] always provided
173
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
174
+ """
175
+ Apply RoPE using explicit position_ids.
176
+ cos/sin: [1, 1, max_seq, D]
177
+ position_ids: [B, T] (T=1 during decode, T=seq_len during prefill)
178
+ """
179
+ # Gather cos/sin for the specific positions: [B, T, D]
180
+ cos_pos = cos[0, 0][position_ids] # [B, T, D]
181
+ sin_pos = sin[0, 0][position_ids] # [B, T, D]
182
+ # Unsqueeze head dim for broadcast: [B, 1, T, D]
183
+ cos_pos = cos_pos.unsqueeze(1)
184
+ sin_pos = sin_pos.unsqueeze(1)
185
+ q_embed = (q * cos_pos) + (rotate_half(q) * sin_pos)
186
+ k_embed = (k * cos_pos) + (rotate_half(k) * sin_pos)
187
  return q_embed, k_embed
188
 
189
 
 
237
  self,
238
  hidden_states: torch.Tensor,
239
  attention_mask: Optional[torch.Tensor] = None,
240
+ position_ids: Optional[torch.Tensor] = None, # FIX-I1
241
  past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
242
  use_cache: bool = False,
243
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 
248
  key_states = self.k_proj(hidden_states)
249
  value_states = self.v_proj(hidden_states)
250
 
251
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
252
  key_states = key_states .view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
253
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
254
 
255
+ past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
256
+ kv_seq_len = past_len + q_len
 
 
257
 
258
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
259
 
260
+ # FIX-I1: use explicit position_ids for RoPE works for both prefill and decode
261
+ if position_ids is None:
262
+ position_ids = torch.arange(past_len, past_len + q_len,
263
+ device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
264
+
265
+ query_states, key_states = apply_rotary_pos_emb_with_ids(
266
+ query_states, key_states, cos, sin, position_ids
267
+ )
268
+
269
+ # Append to KV cache BEFORE repeat (store compact num_kv_heads version)
270
  if past_key_value is not None:
 
 
 
 
 
 
 
271
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
272
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
273
+
 
 
 
 
 
 
 
 
 
 
274
  present_kv = (key_states, value_states) if use_cache else None
275
 
276
+ # Expand KV for full multi-head attention
277
  if self.num_key_value_groups > 1:
278
  key_states = key_states .repeat_interleave(self.num_key_value_groups, dim=1)
279
  value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
 
302
  def __init__(self, config: ModelConfig):
303
  super().__init__()
304
  self.hidden_size = config.hidden_size
 
 
305
  inter = getattr(config, 'intermediate_size', 0)
306
  if not isinstance(inter, int) or inter <= 0:
307
  inter = self.hidden_size * 3
 
332
  self,
333
  hidden_states: torch.Tensor,
334
  attention_mask: Optional[torch.Tensor] = None,
335
+ position_ids: Optional[torch.Tensor] = None, # FIX-I1
336
  past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
337
  use_cache: bool = False,
338
  ):
339
+ residual = hidden_states
340
  hidden_states = self.input_layernorm(hidden_states)
341
  hidden_states, present_kv = self.self_attn(
342
  hidden_states,
343
  attention_mask=attention_mask,
344
+ position_ids=position_ids,
345
  past_key_value=past_key_value,
346
  use_cache=use_cache,
347
  )
348
  hidden_states = self.residual_dropout(hidden_states)
349
  hidden_states = residual + hidden_states
350
 
351
+ residual = hidden_states
352
  hidden_states = self.post_attention_layernorm(hidden_states)
353
  hidden_states = self.mlp(hidden_states)
354
  hidden_states = self.residual_dropout(hidden_states)
 
358
 
359
 
360
  # ============================================================================
361
+ # CUSTOM LABEL SMOOTHING LOSS
362
  # ============================================================================
363
 
364
+ # FIX-T4: PyTorch version guard for label_smoothing + ignore_index
365
+ _TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2] if x.isdigit())
366
+ _NATIVE_SMOOTH_SAFE = _TORCH_VERSION >= (1, 10)
 
 
 
 
367
 
368
+ class LabelSmoothingCrossEntropy(nn.Module):
369
  def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
370
  super().__init__()
371
  self.vocab_size = vocab_size
 
373
  self.ignore_index = ignore_index
374
 
375
  def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
376
+ if _NATIVE_SMOOTH_SAFE and self.smoothing > 0:
377
+ return F.cross_entropy(
378
+ logits,
379
+ targets,
380
+ ignore_index=self.ignore_index,
381
+ label_smoothing=self.smoothing,
382
+ )
383
+ else:
384
+ # Manual fallback: safe for any PyTorch version
385
+ log_probs = F.log_softmax(logits, dim=-1)
386
+ nll_loss = F.nll_loss(log_probs, targets, ignore_index=self.ignore_index, reduction='mean')
387
+ if self.smoothing <= 0:
388
+ return nll_loss
389
+ smooth_loss = -log_probs.mean(dim=-1)
390
+ mask = (targets != self.ignore_index)
391
+ smooth_loss = smooth_loss[mask].mean() if mask.any() else smooth_loss.mean()
392
+ return (1.0 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
393
 
394
 
395
  # ============================================================================
 
403
  self.padding_idx = config.pad_token_id
404
  self.vocab_size = config.vocab_size
405
 
406
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
407
+ padding_idx=self.padding_idx)
408
  self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)])
409
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
410
 
 
412
  nn.Linear(config.hidden_size, config.vocab_size, bias=False)
413
 
414
  self.loss_fn = LabelSmoothingCrossEntropy(
415
+ vocab_size = config.vocab_size,
416
+ smoothing = config.label_smoothing,
417
+ ignore_index = -100,
418
  )
419
 
420
+ # FIX-T1: tag projection layers BEFORE weight init so depth-scaling applies
421
+ self._tag_projection_layers()
422
  self.apply(self._init_weights)
423
 
424
  def _init_weights(self, module):
425
  std = self.config.initializer_range
426
  if isinstance(module, nn.Linear):
 
 
427
  name = getattr(module, '_layer_name', '')
428
  if name in ('o_proj', 'down_proj'):
 
429
  scaled_std = std / math.sqrt(2 * self.config.num_layers)
430
  module.weight.data.normal_(mean=0.0, std=scaled_std)
431
  else:
 
438
  module.weight.data[module.padding_idx].zero_()
439
 
440
  def _tag_projection_layers(self):
 
441
  for layer in self.layers:
442
  layer.self_attn.o_proj._layer_name = 'o_proj'
443
  layer.mlp.down_proj._layer_name = 'down_proj'
 
454
  attention_mask: Optional[torch.Tensor] = None,
455
  batch_size: int = 1,
456
  ) -> torch.Tensor:
 
 
 
 
457
  total_len = past_len + seq_len
458
+ causal = torch.full((seq_len, total_len), torch.finfo(dtype).min,
459
+ device=device, dtype=dtype)
460
  mask_cond = torch.arange(total_len, device=device)
461
+ causal.masked_fill_(
462
+ mask_cond[None, :] <= (torch.arange(seq_len, device=device) + past_len)[:, None],
463
+ 0.0
464
+ )
465
  causal = causal[None, None, :, :].expand(batch_size, 1, seq_len, total_len)
466
 
467
  if attention_mask is not None:
 
468
  pad_mask = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(dtype).min
469
  causal = causal + pad_mask
470
 
471
  return causal
472
 
473
+ # FIX-T6: gradient checkpointing wrapper for decoder layers
474
+ def _layer_forward_with_ckpt(
475
+ self,
476
+ layer,
477
+ hidden_states,
478
+ attention_mask,
479
+ position_ids,
480
+ past_key_value,
481
+ use_cache,
482
+ ):
483
+ if self.config.use_gradient_checkpointing and self.training and past_key_value is None:
484
+ # Gradient checkpointing is only meaningful during training prefill
485
+ def create_custom_forward(module):
486
+ def custom_forward(*inputs):
487
+ return module(*inputs, use_cache=False)
488
+ return custom_forward
489
+ hidden_states, _ = gradient_checkpoint(
490
+ create_custom_forward(layer),
491
+ hidden_states,
492
+ attention_mask,
493
+ position_ids,
494
+ None,
495
+ use_reentrant=False,
496
+ )
497
+ return hidden_states, None
498
+ else:
499
+ return layer(
500
+ hidden_states,
501
+ attention_mask=attention_mask,
502
+ position_ids=position_ids,
503
+ past_key_value=past_key_value,
504
+ use_cache=use_cache,
505
+ )
506
+
507
  def forward(
508
  self,
509
  input_ids: torch.Tensor,
510
  attention_mask: Optional[torch.Tensor] = None,
511
+ position_ids: Optional[torch.Tensor] = None, # FIX-I1
512
  labels: Optional[torch.Tensor] = None,
513
  past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
514
  use_cache: bool = False,
 
519
 
520
  hidden_states = self.embed_tokens(input_ids)
521
 
 
522
  if attention_mask is None:
523
  attention_mask = torch.ones(batch_size, past_len + seq_length,
524
  dtype=torch.long, device=input_ids.device)
525
 
526
+ # FIX-I1: build position_ids for this forward pass
527
+ if position_ids is None:
528
+ position_ids = torch.arange(past_len, past_len + seq_length,
529
+ device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
530
+
531
  causal_mask = self._make_causal_mask(
532
  seq_len=seq_length,
533
  past_len=past_len,
 
541
 
542
  for i, decoder_layer in enumerate(self.layers):
543
  pkv = past_key_values[i] if past_key_values is not None else None
544
+ hidden_states, present_kv = self._layer_forward_with_ckpt(
545
+ decoder_layer,
546
  hidden_states,
547
+ causal_mask,
548
+ position_ids,
549
+ pkv,
550
+ use_cache,
551
  )
552
  if use_cache:
553
  present_key_values.append(present_kv)
 
564
  loss = self.loss_fn(shift_logits, shift_labels)
565
 
566
  return {
567
+ "loss": loss,
568
+ "logits": logits,
569
+ "past_key_values": present_key_values,
570
  }
571
 
572
  def count_parameters(self) -> int:
 
599
  self.skipped_count = 0
600
  self._load_data(file_path)
601
 
 
602
  @property
603
  def _eos(self) -> str:
604
  return self.tokenizer.eos_token or "</s>"
 
606
  def _load_data(self, file_path: str):
607
  print(f"Loading dataset from {file_path}...")
608
  with open(file_path, 'r', encoding='utf-8') as f:
609
+ for line in f:
610
  try:
611
  if not line.strip():
612
  continue
 
633
  def __getitem__(self, idx):
634
  sample = self.samples[idx]
635
 
 
636
  if self.use_cot and random.random() < self.cot_ratio:
637
  prompt = f"{sample['input']} {self.cot_token}"
 
638
  completion = f" {sample['cot']} {self.end_cot_token} {sample['output']}{self._eos}"
639
  else:
640
  prompt = f"{sample['input']}"
 
651
  add_special_tokens=True,
652
  )
653
 
 
654
  labels = [-100] * min(prompt_len, len(full_ids)) + full_ids[prompt_len:]
655
  labels = labels[:len(full_ids)]
656
 
 
688
 
689
 
690
  # ============================================================================
691
+ # PER-TOKEN LOSS TRACKING
692
  # ============================================================================
693
 
694
  class TokenLossAccumulator:
 
 
695
  def __init__(self):
696
  self.total_loss = 0.0
697
  self.total_tokens = 0
 
732
  def create_curriculum_datasets(dataset, stages=None, use_simple=False, skip_stages=0):
733
  if stages is None:
734
  stages = [256, 512, 1024]
735
+
736
+ print("\n" + "=" * 80)
737
+ print("3-STAGE REASONING CURRICULUM")
738
+ if skip_stages > 0:
739
+ print(f" (Skipping first {skip_stages} stage(s))")
740
+ print("=" * 80)
741
 
742
  if use_simple:
743
+ datasets = []
744
  for i, max_len in enumerate(stages):
 
 
 
745
  filtered = [
746
  s for s in dataset.samples
747
  if len(dataset.tokenizer.encode(
 
749
  )) <= max_len
750
  ]
751
  datasets.append(_build_stage_dataset(dataset, filtered, max_len, dataset.cot_ratio))
752
+ tag = " [SKIP]" if i < skip_stages else ""
753
+ print(f" Stage {max_len}{tag}: {len(filtered)} samples")
754
  else:
 
 
 
 
 
 
755
  stage_configs = [
756
  {'name': 'Stage 1: Basic Q&A (no CoT)', 'max_len': 384, 'cot_ratio': 0.0},
757
  {'name': 'Stage 2: Learning Reasoning (50% CoT)', 'max_len': 512, 'cot_ratio': 0.5},
758
  {'name': 'Stage 3: Full Reasoning (100% CoT)', 'max_len': 1024, 'cot_ratio': 1.0},
759
  ]
760
+ datasets = []
761
  for idx, sc in enumerate(stage_configs):
762
+ filtered = dataset.samples if idx > 0 else [
763
+ s for s in dataset.samples
764
+ if len(dataset.tokenizer.encode(f"{s['input']} {s['output']}")) <= sc['max_len']
765
+ ]
 
 
766
  datasets.append(_build_stage_dataset(dataset, filtered, sc['max_len'], sc['cot_ratio']))
767
+ tag = " [SKIP]" if idx < skip_stages else ""
 
768
  print(f" {sc['name']}{tag} | samples={len(filtered)} | CoT={sc['cot_ratio']:.0%}")
769
 
770
+ print("=" * 80 + "\n")
771
+ return datasets[skip_stages:]
 
 
 
772
 
773
 
774
  # ============================================================================
 
840
  torch.backends.cudnn.benchmark = False
841
 
842
 
843
+ def _make_dataloader(dataset, batch_size, shuffle, pad_token_id, device_type):
844
+ # FIX-T5: use num_workers=2 with persistent_workers on CUDA for better GPU util
845
+ num_workers = 2 if device_type == 'cuda' else 0
846
+ persistent = (num_workers > 0)
847
+ return DataLoader(
848
+ dataset,
849
+ batch_size=batch_size,
850
+ shuffle=shuffle,
851
+ collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=pad_token_id),
852
+ num_workers=num_workers,
853
+ persistent_workers=persistent,
854
+ pin_memory=(device_type == 'cuda'),
855
+ )
856
+
857
+
858
  # ============================================================================
859
  # ELASTIC WEIGHT CONSOLIDATION (EWC)
860
  # ============================================================================
 
867
  self.fisher = self._compute_fisher(model, dataloader)
868
 
869
  def _compute_fisher(self, model, dataloader):
870
+ """
871
+ FIX-T9: Empirical Fisher — uses model's own predictions as labels.
872
+ This avoids the bias from using training labels and is more theoretically
873
+ correct for EWC (Kirkpatrick et al. 2017).
874
+ """
875
  fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
876
  model.eval()
877
  seen = 0
 
880
  break
881
  input_ids = batch["input_ids"] .to(self.device)
882
  attention_mask = batch["attention_mask"] .to(self.device)
883
+
884
  model.zero_grad()
885
+ with torch.no_grad():
886
+ out = model(input_ids=input_ids, attention_mask=attention_mask)
887
+ logits = out["logits"]
888
+ # Use model's own greedy predictions as labels (empirical Fisher)
889
+ pred_labels = logits[:, :-1, :].argmax(dim=-1) # [B, T-1]
890
+ # Shift input_ids for proper alignment
891
+ # pred_labels serve as targets for the shifted logits
892
+ flat_logits = logits[:, :-1, :].contiguous().view(-1, model.vocab_size)
893
+ flat_labels = pred_labels.contiguous().view(-1)
894
+
895
+ # Recompute with grad enabled using the pseudo-labels
896
+ out2 = model(input_ids=input_ids, attention_mask=attention_mask)
897
+ flat_logits_grad = out2["logits"][:, :-1, :].contiguous().view(-1, model.vocab_size)
898
+ loss = F.cross_entropy(flat_logits_grad, flat_labels.detach())
899
+ loss.backward()
900
+
901
  for n, p in model.named_parameters():
902
  if p.requires_grad and p.grad is not None:
903
  fisher[n] += p.grad.detach().pow(2)
904
  seen += input_ids.size(0)
905
+
906
  for n in fisher:
907
  fisher[n] /= max(1, seen)
908
  model.train()
 
942
  print(f" Max seq length: {config.max_seq_length}")
943
  print(f" Epochs: {config.num_epochs}")
944
  print(f" Mixed precision: {config.use_fp16}")
945
+ print(f" Grad checkpointing: {config.use_gradient_checkpointing}")
946
  print(f" EWC: {'enabled (lambda=' + str(config.ewc_lambda) + ')' if ewc else 'disabled'}")
947
  print("=" * 80 + "\n")
948
 
949
+ # FIX-T6: apply gradient checkpointing to model config
950
+ if config.use_gradient_checkpointing:
951
+ model.config.use_gradient_checkpointing = True
952
+
953
  model.to(device)
954
  model.train()
955
 
 
961
  )
962
 
963
  if not curriculum_datasets:
964
+ print("ERROR: No curriculum stages.")
965
  return model
966
 
967
  optimizer = torch.optim.AdamW(
 
977
  for ds in curriculum_datasets
978
  ) or 1
979
 
 
980
  use_amp = config.use_fp16 and device.type == 'cuda'
981
  scaler = torch.amp.GradScaler('cuda') if use_amp else None
982
 
 
1005
  f"n={len(stage_dataset)} | CoT={getattr(stage_dataset, 'cot_ratio', '?'):.0%}")
1006
  print(f"{'=' * 80}\n")
1007
 
1008
+ dataloader = _make_dataloader(
1009
+ stage_dataset, config.batch_size, shuffle=True,
1010
+ pad_token_id=model.padding_idx, device_type=device.type
 
 
 
 
1011
  )
1012
 
1013
  for epoch in range(config.num_epochs):
 
1015
  acc = TokenLossAccumulator()
1016
  optimizer.zero_grad()
1017
 
1018
+ # FIX-T2: compute EWC penalty once per optimizer step, not per micro-batch
1019
+ ewc_penalty_cache = None
1020
+ ewc_cache_step = -1
1021
+
1022
  for step, batch in enumerate(dataloader):
1023
  input_ids = batch['input_ids'] .to(device)
1024
  attention_mask = batch['attention_mask'] .to(device)
 
1029
  outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
1030
  task_loss = outputs['loss']
1031
  if ewc is not None:
1032
+ # FIX-T2: cache penalty for entire accumulation window
1033
+ if ewc_cache_step != (step // config.gradient_accumulation_steps):
1034
+ ewc_cache_step = step // config.gradient_accumulation_steps
1035
+ ewc_penalty_cache = ewc.penalty(model)
1036
+ loss = (task_loss + config.ewc_lambda * ewc_penalty_cache) \
1037
+ / config.gradient_accumulation_steps
1038
+ else:
1039
+ loss = task_loss / config.gradient_accumulation_steps
1040
  scaler.scale(loss).backward()
1041
  else:
1042
  outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
1043
  task_loss = outputs['loss']
1044
  if ewc is not None:
1045
+ if ewc_cache_step != (step // config.gradient_accumulation_steps):
1046
+ ewc_cache_step = step // config.gradient_accumulation_steps
1047
+ ewc_penalty_cache = ewc.penalty(model)
1048
+ loss = (task_loss + config.ewc_lambda * ewc_penalty_cache) \
1049
+ / config.gradient_accumulation_steps
1050
+ else:
1051
+ loss = task_loss / config.gradient_accumulation_steps
1052
  loss.backward()
1053
 
1054
+ # FIX-T3: track task_loss only (no EWC contamination in perplexity)
1055
  acc.update(task_loss.item(), labels)
1056
 
1057
  if (step + 1) % config.gradient_accumulation_steps == 0:
1058
  if use_amp:
1059
  scaler.unscale_(optimizer)
1060
 
 
1061
  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
1062
 
1063
  if use_amp:
 
1098
 
1099
  def evaluate_model(model, dataset, device, batch_size=4):
1100
  model.eval()
1101
+ dataloader = _make_dataloader(dataset, batch_size, shuffle=False,
1102
+ pad_token_id=model.padding_idx, device_type=device.type)
 
 
 
 
 
1103
  acc = TokenLossAccumulator()
1104
  with torch.no_grad():
1105
  for batch in dataloader:
 
1121
 
1122
 
1123
  # ============================================================================
1124
+ # GENERATION WITH KV CACHE — FULLY FIXED
1125
  # ============================================================================
1126
 
1127
+ # FIX-I2: vectorized repetition penalty
1128
+ def _apply_repetition_penalty_vectorized(
1129
+ logits: torch.Tensor, # [1, V]
1130
+ token_ids: List[int],
1131
+ penalty: float,
1132
+ ) -> torch.Tensor:
1133
+ if not token_ids or penalty == 1.0:
1134
+ return logits
1135
+ unique_ids = list(set(token_ids))
1136
+ idx = torch.tensor(unique_ids, dtype=torch.long, device=logits.device)
1137
+ # Gather scores for penalized tokens
1138
+ scores = logits[0].gather(0, idx)
1139
+ # penalty: divide positive scores, multiply negative scores
1140
+ penalized = torch.where(scores > 0, scores / penalty, scores * penalty)
1141
+ logits[0].scatter_(0, idx, penalized)
1142
+ return logits
1143
+
1144
+
1145
  def generate_text(
1146
  model: IndonesianLLM,
1147
  tokenizer,
 
1154
  device: torch.device = torch.device('cpu'),
1155
  ) -> str:
1156
  """
1157
+ KV-cache generation with all inference fixes applied.
1158
 
1159
+ FIX-I1: position_ids propagated correctly through layers
1160
+ FIX-I2: vectorized repetition penalty (no Python loop over vocab)
1161
+ FIX-I3: torch.Generator for entropy — no global RNG reset
1162
  """
1163
  model.eval()
1164
 
1165
+ # FIX-I3: isolated Generator doesn't touch global torch RNG state
1166
+ gen = torch.Generator(device=device)
1167
+ gen.manual_seed(int.from_bytes(os.urandom(4), 'little'))
 
 
 
 
 
1168
 
1169
  eos_id = tokenizer.eos_token_id or tokenizer.sep_token_id or 2
1170
  pad_id = tokenizer.pad_token_id or 0
1171
 
1172
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
1173
+ prompt_len = input_ids.shape[1]
1174
+ generated_ids = input_ids.clone()
1175
 
1176
  with torch.no_grad():
1177
+ # Prefill: process entire prompt, build KV cache
1178
+ # FIX-I1: explicit position_ids for prefill
1179
+ prefill_pos = torch.arange(0, prompt_len, device=device).unsqueeze(0)
1180
+ prefill_out = model(
1181
  input_ids=input_ids,
1182
+ position_ids=prefill_pos,
1183
  use_cache=True,
1184
  )
1185
  past_kv = prefill_out['past_key_values']
1186
 
1187
+ prompt_token_ids = input_ids[0].tolist()
 
1188
  generated_token_ids = []
1189
 
 
1190
  for _ in range(max_new_tokens):
1191
+ cur_id = generated_ids[:, -1:] # [1, 1]
1192
+ cur_pos = torch.tensor([[past_kv[0][0].shape[2]]], device=device) # [1, 1]
1193
 
1194
+ out = model(input_ids=cur_id, position_ids=cur_pos,
1195
+ past_key_values=past_kv, use_cache=True)
1196
+ past_kv = out['past_key_values']
1197
+ logits = out['logits'][:, -1:, :].clone() # [1, 1, V] — clone to avoid in-place aliasing
1198
+ logits = logits.squeeze(1) # [1, V]
1199
+ logits /= max(temperature, 0.05)
1200
 
1201
+ # FIX-I2: vectorized repetition penalty
1202
  if repetition_penalty != 1.0:
1203
+ all_seen = prompt_token_ids + generated_token_ids[-128:]
1204
+ logits = _apply_repetition_penalty_vectorized(logits, all_seen, repetition_penalty)
 
 
 
 
 
1205
 
1206
  # Top-k
1207
  if top_k > 0:
 
1219
  logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
1220
 
1221
  probs = F.softmax(logits, dim=-1)
1222
+ # FIX-I3: use isolated generator for multinomial sampling
1223
+ next_token = torch.multinomial(probs, num_samples=1, generator=gen) # [1, 1]
1224
 
1225
  tok_id = next_token.item()
1226
  if tok_id in {eos_id, pad_id}:
 
1229
  generated_token_ids.append(tok_id)
1230
  generated_ids = torch.cat([generated_ids, next_token], dim=1)
1231
 
 
1232
  if generated_ids.size(1) >= model.config.max_position_embeddings:
1233
  break
1234
 
 
 
 
 
 
 
 
1235
  new_token_ids = generated_ids[0][prompt_len:]
1236
  if len(new_token_ids) == 0:
1237
  return ""
1238
 
1239
  raw_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
1240
+ raw_text = re.sub(r'\[(SEP|CLS|PAD|UNK|MASK)\]', '', raw_text)
 
1241
  return raw_text.strip()
1242
 
1243
 
 
1246
  # ============================================================================
1247
 
1248
  def _clean_response(response: str) -> str:
 
 
 
1249
  if "<cot>" in response and "</cot>" in response:
1250
  response = response.split("</cot>", 1)[-1]
1251
  elif "<cot>" in response:
 
 
1252
  response = ""
1253
 
 
1254
  response = re.sub(r'\[(SEP|CLS|PAD|UNK|MASK)\]', '', response)
 
 
1255
  response = re.sub(r'<[^>]+>', '', response)
1256
+ # FIX: stricter role-marker pattern — only strips if WHOLE LINE is a role label
1257
+ response = re.sub(r'(?im)^(user\s*:|assistant\s*:)\s*$', '', response)
1258
+ # Also strip inline "user: " prefix but only at start of a line followed by content
1259
+ response = re.sub(r'(?im)^(user|assistant)\s*:\s*', '', response)
1260
 
 
 
 
 
1261
  for marker in ["memahami permintaan", "jawaban singkat", "penjelasan harus"]:
1262
  if marker in response:
1263
  response = response.split(marker)[0]
1264
 
 
1265
  response = re.sub(r'\n{2,}', '\n', response)
1266
  response = re.sub(r' {2,}', ' ', response)
 
 
1267
  response = re.sub(r'^[\s:!,.\-|]+', '', response)
 
1268
  return response.strip()
1269
 
1270
 
1271
  def _extract_thinking(raw: str) -> Tuple[str, str]:
 
 
 
1272
  raw = re.sub(r'\[(SEP|CLS|PAD|UNK|MASK)\]', '', raw)
1273
 
1274
  if "</cot>" in raw:
 
1275
  thinking_raw, answer_raw = raw.split("</cot>", 1)
1276
  thinking = re.sub(r'<[^>]+>', '', thinking_raw).strip()
1277
+ thinking = re.sub(r'(?im)^(user|assistant)\s*:\s*', '', thinking).strip()
1278
  answer = _clean_response(answer_raw)
 
1279
  elif "<cot>" in raw:
 
 
 
1280
  parts = raw.split("<cot>", 1)
1281
  thinking = _clean_response(parts[1]) if len(parts) > 1 else ""
 
1282
  answer = _clean_response(parts[0]) if parts[0].strip() else ""
 
1283
  else:
 
1284
  thinking = ""
1285
  answer = _clean_response(raw)
1286
 
 
1288
 
1289
 
1290
  # ============================================================================
1291
+ # INTERACTIVE CHAT — WITH MULTI-TURN HISTORY (FIX-I4)
1292
  # ============================================================================
1293
 
1294
  def interactive_chat(
 
1296
  tokenizer,
1297
  device: torch.device,
1298
  system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.",
1299
+ max_history_turns: int = 6,
1300
  ):
1301
+ """
1302
+ FIX-I4: Maintains a rolling conversation history.
1303
+ History is encoded as a flat context string, prepended to each new turn.
1304
+ The window is capped at max_history_turns to avoid context overflow.
1305
+ """
1306
  print("\n" + "=" * 80)
1307
+ print("INDONESIAN LLM — INTERACTIVE CHAT (KV-cache enabled, multi-turn)")
1308
  print("=" * 80)
1309
  print("Commands: exit/quit | clear | think (toggle CoT display)")
1310
  print(f"Persona : {system_prompt}")
1311
+ print(f"History : last {max_history_turns} turns")
1312
  print("=" * 80 + "\n")
1313
 
1314
  model.eval()
1315
+ show_thinking = False
1316
+ # Each entry: {"user": str, "assistant": str}
1317
+ history: List[Dict[str, str]] = []
1318
+
1319
+ def _build_prompt(user_input: str) -> str:
1320
+ """Build a prompt with rolling context window."""
1321
+ parts = []
1322
+ # System persona as a brief prefix
1323
+ parts.append(f"[Sistem: {system_prompt}]")
1324
+ # Recent history
1325
+ recent = history[-max_history_turns:]
1326
+ for turn in recent:
1327
+ parts.append(f"Pengguna: {turn['user']}")
1328
+ parts.append(f"Asisten: {turn['assistant']}")
1329
+ # Current turn — append CoT trigger
1330
+ parts.append(f"Pengguna: {user_input}")
1331
+ parts.append(f"Asisten: <cot>")
1332
+ return "\n".join(parts)
1333
 
1334
  while True:
1335
  try:
 
1340
  print("\nSelamat tinggal!")
1341
  break
1342
  if user_input.lower() in ['clear', 'bersihkan']:
1343
+ history.clear()
1344
  print("\nConversation cleared.")
1345
  continue
1346
  if user_input.lower() == 'think':
 
1348
  print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}")
1349
  continue
1350
 
1351
+ prompt = _build_prompt(user_input)
1352
  print("\nA:", end=" ", flush=True)
1353
 
 
1354
  response = generate_text(
1355
  model=model,
1356
  tokenizer=tokenizer,
 
1367
  if show_thinking and thinking:
1368
  print(f"[Thinking: {thinking}]")
1369
 
 
 
 
1370
  if answer:
1371
  final = answer
1372
  else:
1373
  final = _clean_response(response)
1374
  if not final and thinking:
 
1375
  sentences = [s.strip() for s in thinking.split('.') if s.strip()]
1376
  final = sentences[-1] if sentences else thinking[:200]
1377
 
 
1379
  final = "..."
1380
  print(final)
1381
 
1382
+ # FIX-I4: store turn in history (use clean answer)
1383
+ history.append({"user": user_input, "assistant": final})
1384
+
1385
  except KeyboardInterrupt:
1386
  print("\n\nDihentikan.")
1387
  break
 
1416
  print("No valid samples.")
1417
  return
1418
 
 
 
1419
  live_seed = int(time.time() * 1000) % (2**31)
1420
  random.seed(live_seed)
 
 
 
1421
 
1422
  samples = random.sample(all_samples, min(n, len(all_samples)))
1423
  model.eval()
1424
 
1425
  print(f"\n{'=' * 80}\nBENCHMARK ({len(samples)} samples)\n{'=' * 80}")
1426
 
1427
+ results = []
1428
+ acc = TokenLossAccumulator()
1429
 
1430
  for sample in samples:
1431
  inp = sample['input'].strip()
 
1437
  _, answer = _extract_thinking(raw)
1438
  answer_lower = answer.lower()
1439
 
 
1440
  passed = expected in answer_lower
1441
  if not passed:
1442
  exp_toks = set(expected.split())
 
1469
 
1470
  # ============================================================================
1471
  # SAVE / LOAD
1472
+ # FIX-T7: save_model keeps fp16 for inference; load_model does NOT upcast by default
1473
  # ============================================================================
1474
 
1475
+ def save_model(
1476
+ model: IndonesianLLM,
1477
+ config: ModelConfig,
1478
+ tokenizer_name: str,
1479
+ path: str,
1480
+ use_fp16: bool = True,
1481
+ ):
1482
  os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
1483
  state = model.state_dict()
1484
  if use_fp16:
 
1491
  'dtype': 'fp16' if use_fp16 else 'fp32',
1492
  }, path)
1493
  size_mb = os.path.getsize(path) / 1e6
1494
+ print(f"\nSaved: {path} ({'fp16' if use_fp16 else 'fp32'}, {size_mb:.1f} MB, "
1495
+ f"{model.count_parameters():,} params)")
1496
 
1497
 
1498
+ def load_model(path: str, device: torch.device, force_fp32_training: bool = False):
1499
+ """
1500
+ FIX-T7:
1501
+ - For inference (force_fp32_training=False): keep model in fp16 when saved as fp16.
1502
+ This halves VRAM usage during chat and benchmark.
1503
+ - For training continuation (force_fp32_training=True): upcast to fp32.
1504
+ """
1505
  if not os.path.exists(path):
1506
  raise FileNotFoundError(f"Checkpoint not found: {path}")
1507
  print(f"Loading: {path}")
 
1512
  dtype = ck.get('dtype', 'fp32')
1513
 
1514
  state = ck['model_state_dict']
1515
+
1516
+ # Only upcast when we need fp32 for training
1517
+ if force_fp32_training and dtype == 'fp16':
1518
  state = {k: v.float() if v.dtype == torch.float16 else v for k, v in state.items()}
1519
+ print(" [load_model] Upcasting fp16 -> fp32 for training")
1520
 
1521
+ # Derive intermediate_size from weights
 
 
1522
  gate_key = next((k for k in state if k.endswith('gate_proj.weight')), None)
1523
  if gate_key is not None:
1524
  inferred_intermediate = state[gate_key].shape[0]
1525
  if getattr(config, 'intermediate_size', -1) != inferred_intermediate:
1526
  print(f" [load_model] intermediate_size: config={getattr(config, 'intermediate_size', '?')} "
1527
+ f"-> overriding with {inferred_intermediate}")
1528
  config.intermediate_size = inferred_intermediate
1529
 
 
1530
  embed_key = next((k for k in state if k.endswith('embed_tokens.weight')), None)
1531
  if embed_key is not None:
1532
  inferred_vocab = state[embed_key].shape[0]
1533
  if config.vocab_size != inferred_vocab:
1534
  print(f" [load_model] vocab_size: config={config.vocab_size} "
1535
+ f"-> overriding with {inferred_vocab}")
1536
  config.vocab_size = inferred_vocab
1537
 
1538
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
1539
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
1540
 
1541
  model = IndonesianLLM(config)
1542
+ model.load_state_dict(state, strict=False)
1543
  model.to(device)
1544
 
1545
+ # Keep model in fp16 for inference if that's what was saved
1546
+ if not force_fp32_training and dtype == 'fp16':
1547
+ model = model.half()
1548
+ print(" [load_model] Keeping model in fp16 for inference (use force_fp32_training=True for training)")
1549
+
1550
  size_mb = os.path.getsize(path) / 1e6
1551
  print(f"Loaded ({dtype}, {size_mb:.1f} MB, {ck.get('model_params', model.count_parameters()):,} params)")
1552
  return model, tokenizer, config, {}
 
1590
  parser.add_argument('--ewc-lambda', type=float, default=5000.0)
1591
  parser.add_argument('--ewc-samples', type=int, default=2000)
1592
  parser.add_argument('--no-ewc', action='store_true')
1593
+ # FIX-T6: expose gradient checkpointing via CLI
1594
+ parser.add_argument('--grad-ckpt', action='store_true',
1595
+ help='Enable gradient checkpointing (saves ~50%% activation memory)')
1596
+ parser.add_argument('--max-history', type=int, default=6,
1597
+ help='Max conversation turns to keep in chat context')
1598
 
1599
  args = parser.parse_args()
1600
 
 
1607
  save_fp16 = not args.save_fp32
1608
  use_cot_training = not args.no_cot
1609
 
 
 
 
1610
  if args.train or args.finetune or args.continue_train:
1611
  set_seed(args.seed)
1612
  else:
 
1613
  import time
1614
  live_seed = int(time.time() * 1000) % (2**31)
1615
  random.seed(live_seed)
 
1617
  torch.manual_seed(live_seed)
1618
  if torch.cuda.is_available():
1619
  torch.cuda.manual_seed_all(live_seed)
1620
+
1621
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1622
  print(f"\nDevice: {device}")
1623
  if torch.cuda.is_available():
1624
  print(f" GPU: {torch.cuda.get_device_name(0)}")
1625
  print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
1626
 
 
1627
  if args.inspect_data:
1628
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
1629
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
 
1640
  print(f" Output: {s['output'][:120]}")
1641
  return
1642
 
 
1643
  if args.chat:
1644
+ model, tokenizer, _, _ = load_model(args.model, device, force_fp32_training=False)
1645
+ interactive_chat(model, tokenizer, device,
1646
+ system_prompt=args.system_prompt,
1647
+ max_history_turns=args.max_history)
1648
  return
1649
 
 
1650
  if args.benchmark:
1651
+ model, tokenizer, _, _ = load_model(args.model, device, force_fp32_training=False)
1652
  run_benchmark(model, tokenizer, device, dataset_path=args.dataset)
1653
  return
1654
 
 
1655
  if args.train:
1656
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
1657
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
1658
 
 
1659
  model_config = ModelConfig(
1660
  vocab_size = len(tokenizer),
1661
  hidden_size = args.hidden_size,
 
1666
  attention_dropout = 0.1,
1667
  residual_dropout = 0.1,
1668
  tie_word_embeddings = True,
1669
+ use_gradient_checkpointing = args.grad_ckpt,
1670
  )
1671
  print(f"\nModel config: {model_config}")
 
1672
 
1673
  model = IndonesianLLM(model_config)
 
 
 
1674
  print(f"Parameters: {model.count_parameters():,}")
1675
 
1676
  _ga = args.grad_accum or 32
1677
  train_config = TrainingConfig(
1678
+ dataset_path = args.dataset,
1679
+ num_epochs = args.epochs,
1680
+ batch_size = args.batch_size,
1681
+ gradient_accumulation_steps= _ga,
1682
+ max_seq_length = args.max_length,
1683
+ learning_rate = args.lr,
1684
+ warmup_steps = 500,
1685
+ use_fp16 = torch.cuda.is_available(),
1686
+ use_gradient_checkpointing = args.grad_ckpt,
1687
+ curriculum_stages = [128, 256, args.max_length],
1688
  )
1689
 
1690
  dataset = IndonesianCoTDataset(train_config.dataset_path, tokenizer,
1691
+ train_config.max_seq_length, use_cot=use_cot_training,
1692
+ cot_ratio=args.cot_ratio)
1693
  model = train_model(model, dataset, train_config, device,
1694
  use_simple_curriculum=args.simple_curriculum)
1695
 
 
1705
  print(f"\nPrompt : {p}")
1706
  print(f"Generated: {generate_text(model, tokenizer, p, max_new_tokens=150, device=device)}\n")
1707
 
 
1708
  if args.finetune:
1709
+ model, tokenizer, model_config, _ = load_model(args.model, device, force_fp32_training=True)
1710
 
1711
  _ga = args.grad_accum or 32
1712
  train_config = TrainingConfig(
1713
+ dataset_path = args.dataset,
1714
+ num_epochs = args.epochs,
1715
+ batch_size = args.batch_size,
1716
+ gradient_accumulation_steps= _ga,
1717
+ max_seq_length = args.max_length,
1718
+ learning_rate = args.lr / 10,
1719
+ warmup_steps = 100,
1720
+ use_fp16 = torch.cuda.is_available(),
1721
+ use_gradient_checkpointing = args.grad_ckpt,
1722
+ curriculum_stages = [128, 256, args.max_length],
1723
  )
1724
 
1725
  dataset = IndonesianCoTDataset(train_config.dataset_path, tokenizer,
1726
+ train_config.max_seq_length, use_cot=use_cot_training,
1727
+ cot_ratio=args.cot_ratio)
1728
  ewc_obj = None
1729
  if not args.no_ewc and args.ewc_lambda > 0:
1730
  print(f"\nComputing EWC Fisher (lambda={args.ewc_lambda}, n={args.ewc_samples})...")
1731
+ loader = _make_dataloader(dataset, args.batch_size, shuffle=True,
1732
+ pad_token_id=model.padding_idx, device_type=device.type)
 
1733
  train_config.ewc_lambda = args.ewc_lambda
1734
  train_config.ewc_samples = args.ewc_samples
1735
  ewc_obj = EWC(model, loader, device, n_samples=args.ewc_samples)
 
1744
  save_model(model, model_config, "indolem/indobert-base-uncased", out_path, use_fp16=save_fp16)
1745
  print(f"\nFinetuned model: {out_path}")
1746
 
 
1747
  if args.continue_train:
1748
+ model, tokenizer, model_config, _ = load_model(args.model, device, force_fp32_training=True)
1749
 
1750
+ effective_skip = (len([128, 256, args.max_length]) - 1) if args.simple_curriculum else args.skip_stages
1751
+ curriculum = [192, 320, args.max_length]
 
1752
 
1753
+ print(f"\nContinue-train LR: {args.lr:.2e} (skip {effective_skip} stages)")
1754
 
1755
  _ga = args.grad_accum or 32
1756
  train_config = TrainingConfig(
1757
+ dataset_path = args.dataset,
1758
+ num_epochs = args.epochs,
1759
+ batch_size = args.batch_size,
1760
+ gradient_accumulation_steps= _ga,
1761
+ max_seq_length = args.max_length,
1762
+ learning_rate = args.lr,
1763
+ warmup_steps = 500,
1764
+ use_fp16 = torch.cuda.is_available(),
1765
+ use_gradient_checkpointing = args.grad_ckpt,
1766
+ curriculum_stages = curriculum,
1767
+ plateau_patience = 2,
1768
+ plateau_factor = 0.6,
1769
+ plateau_min_delta = 0.01,
1770
  )
1771
 
1772
  dataset = IndonesianCoTDataset(train_config.dataset_path, tokenizer,
1773
+ train_config.max_seq_length, use_cot=use_cot_training,
1774
+ cot_ratio=args.cot_ratio)
1775
  model = train_model(model, dataset, train_config, device,
1776
  use_simple_curriculum=args.simple_curriculum,
1777
  is_continue=True,