import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math, os, json, glob, time, random
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoTokenizer
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
from cut_cross_entropy import linear_cross_entropy
from torch.nn.utils import clip_grad_norm_
from utils.trainutils import count_parameters_layerwise, save_checkpoint, TBLogger

from llama_modeling.front_end import LlamaForCausalLM
from llama_modeling.config import LlamaConfig


class JSONLDataset(Dataset):
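    """Loads every *.jsonl file under `directory_path`, tokenizes the `text_key`
    field of each record (records shorter than 100 characters are skipped), and
    packs the tokens into fixed-length sequences of `seq_length`, padding the
    final chunk of each document with `pad_token_id`."""
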
    def __init__(self, directory_path, tokenizer, seq_length=1024,
                 text_key="text", max_files=None, batch_size=1000,
                 pad_token_id=0):
        self.seq_length = seq_length
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.sequences = []

        files = glob.glob(os.path.join(directory_path, "*.jsonl"))
        if max_files is not None:
            files = files[:max_files]

        text_batch = []
        for file_path in files:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get(text_key, "")
                        if len(text) >= 100:
                            text_batch.append(text)

                            if len(text_batch) >= batch_size:
                                self._process_batch(text_batch)
                                text_batch = []
                    except (json.JSONDecodeError, TypeError):
                        continue

        if text_batch:
            self._process_batch(text_batch)

        if self.sequences:
            self.sequences = torch.tensor(self.sequences, dtype=torch.long)
        else:
            self.sequences = torch.empty((0, seq_length), dtype=torch.long)

    def _process_batch(self, texts):
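        """Tokenize a batch of texts and split each token stream into
        `seq_length`-sized chunks, right-padding the last chunk with `pad_token_id`."""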
        encoded = self.tokenizer(
            texts,
            add_special_tokens=False,
            truncation=True,
            padding=False,
            return_attention_mask=False,
            return_tensors=None
        )['input_ids']

        for token_ids in encoded:
            for i in range(0, len(token_ids), self.seq_length):
                chunk = token_ids[i:i + self.seq_length]

                if len(chunk) < self.seq_length:
                    chunk += [self.pad_token_id] * (self.seq_length - len(chunk))

                self.sequences.append(chunk)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]


def train_model(model, train_loader, optimizer, device, epochs=5, forward_dtype=torch.float32):
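    """Single-GPU training loop: autocast in `forward_dtype`, gradient scaling and
    clipping, per-step cosine LR decay, the fused linear cross-entropy loss from
    cut_cross_entropy, metric logging via TBLogger, and a checkpoint after every epoch."""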
    model.train()
    # Loss scaling only matters for float16 autocast; the scaler is disabled for bfloat16/float32.
    scaler = torch.amp.GradScaler("cuda", enabled=(forward_dtype == torch.float16))

    logger = TBLogger(log_dir=f'logs/run-{time.time()}')

    total_steps = len(train_loader) * epochs
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=5e-6
    )

    model = torch.compile(model)

    global_step = 0
    for epoch in range(epochs):
        running_loss = 0.0
        total_batches = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch_idx, data in enumerate(progress_bar):
            data = data.to(device)
            optimizer.zero_grad(set_to_none=True)

            with torch.autocast(device_type='cuda', dtype=forward_dtype):
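                # The model returns the final hidden states plus the classifier (lm_head)
                # weight matrix, so linear_cross_entropy can fuse the output projection and
                # the cross-entropy loss without ever materializing the full logits.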
                hidden_states, classifier_weights = model(data)

                loss = linear_cross_entropy(
                    hidden_states,
                    classifier_weights,
                    data,
                    shift=True,
                    reduction="mean"
                )

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            running_loss += loss.item()
            total_batches += 1
            global_step += 1
            avg_loss = running_loss / total_batches
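            # Cap the exponent so large early-training losses cannot overflow math.exp().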
            perplexity = math.exp(min(avg_loss, 100))

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'ppl': f'{perplexity:.2f}'
            })

            metrics = {
                'loss': loss.item(),
                'perplexity': perplexity,
                'learning_rate': optimizer.param_groups[0]['lr'],
                'batch_size': data.size(0)
            }

            logger.log(metrics, step=global_step, model=model, grad_checking=True)

            if batch_idx % 100 == 0:
                print(f'\nBatch {batch_idx}/{len(train_loader)}: '
                      f'Loss: {avg_loss:.4f}, '
                      f'Perplexity: {perplexity:.2f}, '
                      f'Batches Processed: {total_batches}')

        epoch_loss = running_loss / total_batches
        epoch_ppl = math.exp(min(epoch_loss, 100))
        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Average Loss: {epoch_loss:.4f}')
        print(f'Perplexity: {epoch_ppl:.2f}')
        print(f'Total Batches Processed: {total_batches}\n')

        save_checkpoint(model, f'epoch_{epoch+1}.safetensors')


def sample_examples(dataset, tokenizer, num_samples=5):
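    """Decode and print a few random sequences from the dataset as a sanity check."""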
    if len(dataset) == 0:
        print("The dataset is empty.")
        return

    num_samples = min(num_samples, len(dataset))

    sampled_indices = random.sample(range(len(dataset)), num_samples)

    for i, idx in enumerate(sampled_indices):
        sequence = dataset[idx]
        print(f"Sample {i + 1} (Index {idx}):")
        print(sequence)
        decoded_text = tokenizer.decode(sequence, skip_special_tokens=False)
        print(decoded_text)
        print("-" * 40)


def main():
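    """Build the tokenizer, config, model, dataset, and optimizer, then run training."""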
    BATCH_SIZE = 36
    SEQ_LENGTH = 512
    EPOCHS = 3
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-135M-Instruct")

    config_path = "config.json"
    with open(config_path) as f:
        config_dict = json.load(f)
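    # Keep only the keys LlamaConfig actually declares, so extra entries in config.json are ignored.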
    config = LlamaConfig(**{k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__})

    model = LlamaForCausalLM(config).to(DEVICE)

    dataset = JSONLDataset(
        directory_path="./Data_big",
        tokenizer=tokenizer,
        seq_length=SEQ_LENGTH,
        text_key="text",
        max_files=None,
    )

    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )

    # Shampoo preconditions parameter blocks (capped at max_preconditioner_dim) and
    # grafts the per-layer update scale from Adam via the grafting config.
    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.0001,
        betas=(0.9, 0.999),
        epsilon=1e-12,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-12,
        ),
    )

    print("*" * 100)
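    # Let remaining float32 matmuls use TF32 on supported GPUs.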
    torch.set_float32_matmul_precision('high')

    count_parameters_layerwise(model)

    train_model(model, train_loader, optimizer, DEVICE, EPOCHS, forward_dtype=torch.bfloat16)


if __name__ == "__main__":
    main()