| import os |
| import sys |
| import logging |
| import json |
| import pickle |
| from typing import Optional, Tuple, List, Dict, Any, Union |
| from pathlib import Path |
| from tqdm import tqdm |
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
| from datasets import load_dataset, Dataset as HFDataset |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| get_linear_schedule_with_warmup, |
| PreTrainedModel |
| ) |
| from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
| from transformers.models.qwen3.modeling_qwen3 import ( |
| Qwen3Model, |
| Qwen3ForCausalLM, |
| Qwen3PreTrainedModel, |
| Qwen3RMSNorm |
| ) |
|
|
| logging.basicConfig( |
| level=logging.INFO, |
| format="%(asctime)s - %(levelname)s - %(message)s", |
| handlers=[logging.StreamHandler(sys.stdout)], |
| force=True, |
| ) |
| logger = logging.getLogger("grouped_qwen3_training") |
|
|
|
|
| class GroupedInputMLPAdapter(nn.Module): |
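|     """Residual MLP adapter for precomputed grouped input embeddings. |
| |
|     A two-layer SiLU MLP (hidden_size -> 2*hidden_size -> hidden_size) with dropout is |
|     added back to the input through a residual connection and normalized with |
|     Qwen3RMSNorm; the result is fed to the decoder as inputs_embeds. |
|     """ |
| |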
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| hidden_size = config.hidden_size |
| |
| self.grouped_processor = nn.Sequential( |
| nn.Linear(hidden_size, hidden_size * 2), |
| nn.SiLU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_size * 2, hidden_size), |
| nn.Dropout(0.1) |
| ) |
| |
| norm_eps = getattr(config, 'rms_norm_eps', 1e-6) |
| self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps) |
| |
| def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor: |
| processed = self.grouped_processor(grouped_embeds) |
| |
| output = self.layer_norm(grouped_embeds + processed) |
| |
| return output |
|
|
|
|
| class CustomQwen3ForCausalLM(Qwen3ForCausalLM): |
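|     """Qwen3 causal LM that can be prefilled from grouped embeddings instead of token ids. |
| |
|     When `grouped_inputs` is passed with `is_prefill=True`, it is run through the |
|     GroupedInputMLPAdapter and used as `inputs_embeds`; later decoding steps reuse the |
|     resulting KV cache. Only the adapter and the first decoder layer are left trainable. |
|     """ |
| |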
| def __init__(self, config): |
| super().__init__(config) |
| |
| self.grouped_input_mlp = GroupedInputMLPAdapter(config) |
| |
| self.is_grouped_input_mode = False |
| self.grouped_cache_initialized = False |
| |
| self._init_grouped_weights() |
| |
| self._freeze_layers() |
| |
|     def _init_grouped_weights(self): |
|         def _init_weights(module): |
|             if isinstance(module, nn.Linear): |
|                 torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|                 if module.bias is not None: |
|                     torch.nn.init.zeros_(module.bias) |
|             elif isinstance(module, Qwen3RMSNorm): |
|                 # The adapter norm is Qwen3RMSNorm, which has a weight but no bias. |
|                 torch.nn.init.ones_(module.weight) |
| |
|         self.grouped_input_mlp.apply(_init_weights) |
| |
| def _freeze_layers(self): |
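|         """Freeze everything except the first decoder layer and the grouped-input adapter.""" |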
| for param in self.model.embed_tokens.parameters(): |
| param.requires_grad = False |
| |
| for i, layer in enumerate(self.model.layers): |
| if i == 0: |
| for param in layer.parameters(): |
| param.requires_grad = True |
| else: |
| for param in layer.parameters(): |
| param.requires_grad = False |
| |
| for param in self.model.norm.parameters(): |
| param.requires_grad = False |
| |
| for param in self.lm_head.parameters(): |
| param.requires_grad = False |
| |
| for param in self.grouped_input_mlp.parameters(): |
| param.requires_grad = True |
| |
| trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| total_params = sum(p.numel() for p in self.parameters()) |
| logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} " |
| f"({trainable_params/total_params*100:.2f}%)") |
| |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| grouped_inputs: Optional[torch.FloatTensor] = None, |
| is_prefill: Optional[bool] = None, |
| **kwargs |
| ): |
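|         """Forward pass with optional grouped-embedding prefill. |
| |
|         grouped_inputs: precomputed embeddings of shape (batch, seq_len, hidden_size). |
|         is_prefill: when True together with grouped_inputs, the adapter output replaces |
|             input_ids/inputs_embeds and default position_ids / attention_mask are built. |
|         """ |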
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| |
| if grouped_inputs is not None and is_prefill: |
| self.is_grouped_input_mode = True |
| |
| processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs) |
| |
| inputs_embeds = processed_grouped_inputs |
| input_ids = None |
| |
| batch_size, seq_len = inputs_embeds.shape[:2] |
| if position_ids is None: |
| device = inputs_embeds.device |
| position_ids = torch.arange(seq_len, device=device, dtype=torch.long) |
| position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) |
| |
| if attention_mask is None: |
| attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long) |
| |
| self.grouped_cache_initialized = True |
| |
| elif not is_prefill and self.is_grouped_input_mode: |
| pass |
| else: |
| self.is_grouped_input_mode = False |
| |
| |
| outputs = super().forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| labels=labels, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| cache_position=cache_position, |
| **kwargs |
| ) |
| |
| return outputs |
|
|
| class GroupedDataset(Dataset): |
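|     """Dataset of (grouped embeddings, response text) pairs read from a pickle file. |
| |
|     Note: `_build_index` loads the entire pickle into memory; the small chunk cache only |
|     groups items for lookup and does not stream from disk. The tail `validation_split` |
|     fraction of valid samples (capped at 1000) is used for validation. |
|     """ |
| |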
| def __init__(self, dataset_path: str, tokenizer, max_response_length: int = 512, |
| validation_split: float = 0.1, is_validation: bool = False, |
| chunk_size: int = 1000, max_samples: Optional[int] = None): |
| self.dataset_path = dataset_path |
| self.tokenizer = tokenizer |
| self.max_response_length = max_response_length |
| self.validation_split = validation_split |
| self.is_validation = is_validation |
| self.chunk_size = chunk_size |
| self.max_samples = max_samples |
| |
| self._chunk_cache = {} |
| self._cache_size_limit = 3 |
| |
| self._build_index() |
| |
| def _build_index(self): |
| logger.info(f"Building index for {self.dataset_path}") |
| |
| with open(self.dataset_path, 'rb') as f: |
| data = pickle.load(f) |
| |
| valid_indices = [] |
| for i, item in enumerate(data): |
| if not item.get("error", False): |
| valid_indices.append(i) |
| |
| if self.max_samples and len(valid_indices) >= self.max_samples: |
| break |
| |
| total_valid = len(valid_indices) |
| |
| val_size = min(1000, int(self.validation_split * total_valid)) |
| train_size = total_valid - val_size |
| |
| if self.is_validation: |
| self.valid_indices = valid_indices[train_size:train_size + val_size] |
| self.total_samples = val_size |
| else: |
| self.valid_indices = valid_indices[:train_size] |
| self.total_samples = train_size |
| |
| self._full_data = data |
| |
| logger.info(f"{'Validation' if self.is_validation else 'Training'} dataset: {self.total_samples} samples") |
| |
| def _get_chunk_id(self, idx): |
| return idx // self.chunk_size |
| |
| def _load_chunk(self, chunk_id): |
| if chunk_id in self._chunk_cache: |
| return self._chunk_cache[chunk_id] |
| |
| start_idx = chunk_id * self.chunk_size |
| end_idx = min(start_idx + self.chunk_size, self.total_samples) |
| |
| chunk_data = {} |
| for i in range(start_idx, end_idx): |
| actual_idx = self.valid_indices[i] |
| chunk_data[i] = self._full_data[actual_idx] |
| |
| if len(self._chunk_cache) >= self._cache_size_limit: |
| oldest_chunk = min(self._chunk_cache.keys()) |
| del self._chunk_cache[oldest_chunk] |
| |
| self._chunk_cache[chunk_id] = chunk_data |
| return chunk_data |
| |
| def __len__(self): |
| return self.total_samples |
| |
| def __getitem__(self, idx): |
| if idx >= self.total_samples: |
| raise IndexError(f"Index {idx} out of range for dataset of size {self.total_samples}") |
| |
| chunk_id = self._get_chunk_id(idx) |
| chunk_data = self._load_chunk(chunk_id) |
| item = chunk_data[idx] |
| |
| return self._process_item(item) |
| |
| def _process_item(self, item): |
| grouped_embeds = item["inputs_embeds"] |
| if isinstance(grouped_embeds, torch.Tensor): |
| grouped_embeds = grouped_embeds.clone() |
| else: |
| grouped_embeds = torch.tensor(grouped_embeds) |
| |
| if grouped_embeds.dtype != torch.float32: |
| grouped_embeds = grouped_embeds.float() |
| |
| response = item["response"] |
| |
| response_tokens = self.tokenizer( |
| response, |
| max_length=self.max_response_length, |
| truncation=True, |
| padding=False, |
| return_tensors="pt" |
| ) |
| |
| response_input_ids = response_tokens["input_ids"].squeeze(0) |
| |
| return { |
| "grouped_inputs": grouped_embeds, |
| "response_input_ids": response_input_ids, |
| "response_text": response, |
| "input_text": item["input_text"], |
| } |
| |
| def cleanup(self): |
| self._chunk_cache.clear() |
| if hasattr(self, '_full_data'): |
| del self._full_data |
|
|
|
|
| def collate_fn(batch, tokenizer, pad_token_id=None): |
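|     """Pad grouped embeddings and response token ids to the batch maxima. |
| |
|     Grouped embeddings are zero-padded with a matching attention mask; response ids are |
|     padded with the tokenizer's pad (or eos) token. |
|     """ |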
| if pad_token_id is None: |
| pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id |
| |
| grouped_inputs = [item["grouped_inputs"] for item in batch] |
| response_input_ids = [item["response_input_ids"] for item in batch] |
| |
| max_grouped_len = max(gi.shape[0] for gi in grouped_inputs) |
| batch_size = len(grouped_inputs) |
| hidden_size = grouped_inputs[0].shape[-1] |
| |
| padded_grouped_inputs = torch.zeros(batch_size, max_grouped_len, hidden_size) |
| grouped_attention_mask = torch.zeros(batch_size, max_grouped_len, dtype=torch.long) |
| |
| for i, gi in enumerate(grouped_inputs): |
| seq_len = gi.shape[0] |
| padded_grouped_inputs[i, :seq_len] = gi |
| grouped_attention_mask[i, :seq_len] = 1 |
| |
| max_response_len = max(len(rid) for rid in response_input_ids) |
| padded_response_ids = torch.full((batch_size, max_response_len), pad_token_id, dtype=torch.long) |
| |
| for i, rid in enumerate(response_input_ids): |
| padded_response_ids[i, :len(rid)] = rid |
| |
| return { |
| "grouped_inputs": padded_grouped_inputs, |
| "grouped_attention_mask": grouped_attention_mask, |
| "response_input_ids": padded_response_ids, |
| "response_texts": [item["response_text"] for item in batch], |
| "input_texts": [item["input_text"] for item in batch], |
| } |
|
|
| class TrainingState: |
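|     """Persists epoch/step counters and optimizer/scheduler state for resumable training.""" |
| |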
| def __init__(self, output_dir: Path): |
| self.output_dir = output_dir |
| self.state_file = output_dir / "training_state.json" |
| |
|     def save_state(self, epoch: int, global_step: int, best_val_loss: float, |
|                    optimizer_state: Dict, scheduler_state: Dict): |
|         """Save training state. Scalars go to JSON; the tensor-bearing optimizer and |
|         scheduler states are written with torch.save so they can be restored on resume.""" |
|         state = { |
|             "epoch": epoch, |
|             "global_step": global_step, |
|             "best_val_loss": best_val_loss, |
|             "completed_epochs": epoch |
|         } |
| |
|         with open(self.state_file, 'w') as f: |
|             json.dump(state, f, indent=2) |
| |
|         torch.save( |
|             {"optimizer_state": optimizer_state, "scheduler_state": scheduler_state}, |
|             self.output_dir / "optim_state.pt" |
|         ) |
| |
|         logger.info(f"Saved training state at epoch {epoch}, step {global_step}") |
| |
|     def load_state(self): |
|         if not self.state_file.exists(): |
|             return None |
| |
|         try: |
|             with open(self.state_file, 'r') as f: |
|                 state = json.load(f) |
| |
|             optim_file = self.output_dir / "optim_state.pt" |
|             if optim_file.exists(): |
|                 state.update(torch.load(optim_file, map_location="cpu")) |
| |
|             logger.info(f"Loaded training state from epoch {state['epoch']}, step {state['global_step']}") |
|             return state |
|         except Exception as e: |
|             logger.warning(f"Failed to load training state: {e}") |
|             return None |
| |
|     def get_latest_checkpoint(self): |
|         state = self.load_state() |
|         if state is None: |
|             return None |
| |
|         epoch = state["completed_epochs"] |
| |
|         # The epoch checkpoint may have been written with or without the "_best" suffix. |
|         for name in (f"epoch_{epoch}", f"epoch_{epoch}_best"): |
|             checkpoint_path = self.output_dir / name |
|             if checkpoint_path.exists(): |
|                 return checkpoint_path, state |
| |
|         logger.warning(f"Checkpoint for epoch {epoch} not found") |
|         return None |
|
|
| class GroupedTrainer: |
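|     """End-to-end trainer: loads the model and dataset, runs epochs, and writes checkpoints. |
| |
|     The response tokens are teacher-forced on top of a KV cache built from the |
|     grouped-embedding prefill; only the adapter and the first decoder layer get gradients. |
|     """ |
| |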
| def __init__( |
| self, |
| model_name: str = "Qwen/Qwen3-0.6B", |
| dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl", |
| output_dir: str = "./grouped_qwen3_checkpoint", |
| batch_size: int = 4, |
| learning_rate: float = 5e-4, |
| num_epochs: int = 3, |
| warmup_steps: int = 100, |
| max_grad_norm: float = 1.0, |
| save_steps: int = 500, |
| eval_steps: int = 500, |
| logging_steps: int = 50, |
| resume_training: bool = True, |
| debug: bool = False, |
| chunk_size: int = 1000, |
| max_samples: Optional[int] = None, |
| ): |
| self.model_name = model_name |
| self.dataset_path = dataset_path |
| self.output_dir = Path(output_dir) |
| self.batch_size = batch_size |
| self.learning_rate = learning_rate |
| self.num_epochs = num_epochs |
| self.warmup_steps = warmup_steps |
| self.max_grad_norm = max_grad_norm |
| self.save_steps = save_steps |
| self.eval_steps = eval_steps |
| self.logging_steps = logging_steps |
| self.resume_training = resume_training |
| self.debug = debug |
| self.chunk_size = chunk_size |
| self.max_samples = max_samples |
| |
| if self.debug: |
| logger.setLevel(logging.DEBUG) |
| |
| self.output_dir.mkdir(parents=True, exist_ok=True) |
| |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| self.tokenizer = None |
| self.model = None |
| self.train_dataset = None |
| self.val_dataset = None |
| |
| self.training_state = TrainingState(self.output_dir) |
| |
| def load_model_and_tokenizer(self): |
| logger.info(f"Loading tokenizer and model: {self.model_name}") |
| |
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
| |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| |
| config = Qwen3Config.from_pretrained(self.model_name) |
| |
| self.model = CustomQwen3ForCausalLM.from_pretrained( |
| self.model_name, |
| config=config, |
| torch_dtype=torch.float32, |
| attn_implementation="eager" |
| ).to(self.device) |
| |
| logger.info(f"Model loaded on {self.device}") |
| |
| def load_dataset(self, chunk_size: int = 1000, max_samples: Optional[int] = None): |
| logger.info(f"Loading streaming dataset from {self.dataset_path}") |
| |
| |
| self.train_dataset = GroupedDataset( |
| dataset_path=self.dataset_path, |
| tokenizer=self.tokenizer, |
| is_validation=False, |
| chunk_size=chunk_size, |
| max_samples=max_samples |
| ) |
| |
| self.val_dataset = GroupedDataset( |
| dataset_path=self.dataset_path, |
| tokenizer=self.tokenizer, |
| is_validation=True, |
| chunk_size=chunk_size, |
| max_samples=max_samples |
| ) |
| |
| logger.info(f"Train samples: {len(self.train_dataset)}") |
| logger.info(f"Val samples: {len(self.val_dataset)}") |
| |
| |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| memory_used = torch.cuda.memory_allocated() / 1024**3 |
| logger.info(f"GPU memory after dataset loading: {memory_used:.2f} GB") |
| |
| def cleanup_datasets(self): |
| if hasattr(self.train_dataset, 'cleanup'): |
| self.train_dataset.cleanup() |
| if hasattr(self.val_dataset, 'cleanup'): |
| self.val_dataset.cleanup() |
| |
| import gc |
| gc.collect() |
| |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| def load_checkpoint(self, checkpoint_path: Path): |
| logger.info(f"Loading checkpoint from {checkpoint_path}") |
| |
| model_path = checkpoint_path / "pytorch_model.bin" |
| if not model_path.exists(): |
| model_path = checkpoint_path / "model.safetensors" |
| |
|         if model_path.exists(): |
|             if model_path.suffix == ".safetensors": |
|                 # torch.load cannot read safetensors weights. |
|                 from safetensors.torch import load_file |
|                 state_dict = load_file(str(model_path)) |
|             else: |
|                 state_dict = torch.load(model_path, map_location=self.device) |
| missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False) |
| |
| if missing_keys: |
| logger.warning(f"Missing keys when loading checkpoint: {missing_keys}") |
| if unexpected_keys: |
| logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}") |
| |
| logger.info("Model checkpoint loaded successfully") |
| else: |
| raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}") |
| |
| def compute_loss(self, batch, outputs): |
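|         """Standalone cross-entropy helper (training_step computes its loss inline instead).""" |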
| logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0] |
| |
| target_ids = batch["response_input_ids"].to(self.device) |
| |
| logger.debug(f"Logits shape: {logits.shape}, Target shape: {target_ids.shape}") |
| |
| batch_size = target_ids.shape[0] |
| |
| if target_ids.shape[1] > 1: |
| labels = target_ids.clone() |
| pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id |
| labels[labels == pad_token_id] = -100 |
| |
| seq_len = min(logits.shape[1], labels.shape[1]) |
| logits_truncated = logits[:, :seq_len, :] |
| labels_truncated = labels[:, :seq_len] |
| |
| logger.debug(f"After truncation - Logits: {logits_truncated.shape}, Labels: {labels_truncated.shape}") |
| |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| loss = loss_fct( |
| logits_truncated.reshape(-1, logits_truncated.size(-1)), |
| labels_truncated.reshape(-1) |
| ) |
| else: |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| loss = loss_fct(logits.view(-1, logits.size(-1)), target_ids.view(-1)) |
| |
| return loss |
| |
| def training_step(self, batch, step): |
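|         """Single training step: grouped-embedding prefill with caching, then a |
|         teacher-forced pass over the response tokens for next-token cross-entropy.""" |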
| self.model.train() |
| |
| if step < 5 and torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| memory_before = torch.cuda.memory_allocated() / 1024**3 |
| |
| grouped_inputs = batch["grouped_inputs"].to(self.device) |
| grouped_attention_mask = batch["grouped_attention_mask"].to(self.device) |
| response_input_ids = batch["response_input_ids"].to(self.device) |
| |
| batch_size = grouped_inputs.shape[0] |
| grouped_seq_len = grouped_inputs.shape[1] |
| response_seq_len = response_input_ids.shape[1] |
| |
| if self.debug: |
| logger.debug(f"Batch sizes - grouped: {grouped_inputs.shape}, response: {response_input_ids.shape}") |
| |
| grouped_outputs = self.model( |
| grouped_inputs=grouped_inputs, |
| attention_mask=grouped_attention_mask, |
| is_prefill=True, |
| use_cache=True, |
| return_dict=True |
| ) |
| |
|         if response_seq_len > 1: |
|             response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long() |
| |
|             # With cached key/values from the grouped prefill, the attention mask |
|             # must cover the cached prefix plus the new response tokens. |
|             full_attention_mask = torch.cat( |
|                 [grouped_attention_mask, response_attention_mask[:, :-1]], dim=1 |
|             ) |
| |
|             response_outputs = self.model( |
|                 input_ids=response_input_ids[:, :-1], |
|                 attention_mask=full_attention_mask, |
|                 past_key_values=grouped_outputs.past_key_values, |
|                 use_cache=False, |
|                 return_dict=True |
|             ) |
| |
| logits = response_outputs.logits |
| labels = response_input_ids[:, 1:] |
| |
| labels = labels.clone() |
| pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id |
| labels[labels == pad_token_id] = -100 |
| |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)) |
| |
| else: |
| loss = torch.tensor(0.0, requires_grad=True, device=self.device) |
| |
| if step < 5 and torch.cuda.is_available(): |
| memory_after = torch.cuda.memory_allocated() / 1024**3 |
| memory_peak = torch.cuda.max_memory_allocated() / 1024**3 |
| logger.info(f"Step {step} Memory: {memory_before:.2f}GB β {memory_after:.2f}GB (Peak: {memory_peak:.2f}GB)") |
| |
| if memory_peak > 20.0: |
| logger.warning("High memory usage detected! Consider reducing batch_size") |
| |
| class MockOutputs: |
| def __init__(self, loss, logits): |
| self.loss = loss |
| self.logits = logits |
| |
| outputs = MockOutputs(loss, response_outputs.logits if 'response_outputs' in locals() else grouped_outputs.logits) |
| |
| return loss, outputs |
| |
| def validation_step(self, batch): |
| """Single validation step.""" |
| self.model.eval() |
| |
| with torch.no_grad(): |
| grouped_inputs = batch["grouped_inputs"].to(self.device) |
| grouped_attention_mask = batch["grouped_attention_mask"].to(self.device) |
| response_input_ids = batch["response_input_ids"].to(self.device) |
| |
| grouped_outputs = self.model( |
| grouped_inputs=grouped_inputs, |
| attention_mask=grouped_attention_mask, |
| is_prefill=True, |
| use_cache=True, |
| return_dict=True |
| ) |
| |
|             if response_input_ids.shape[1] > 1: |
|                 response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long() |
| |
|                 # Mask must span the cached grouped prefix plus the response tokens. |
|                 full_attention_mask = torch.cat( |
|                     [grouped_attention_mask, response_attention_mask[:, :-1]], dim=1 |
|                 ) |
| |
|                 response_outputs = self.model( |
|                     input_ids=response_input_ids[:, :-1], |
|                     attention_mask=full_attention_mask, |
|                     past_key_values=grouped_outputs.past_key_values, |
|                     use_cache=False, |
|                     return_dict=True |
|                 ) |
| |
| logits = response_outputs.logits |
| labels = response_input_ids[:, 1:] |
| |
| labels = labels.clone() |
| pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id |
| labels[labels == pad_token_id] = -100 |
| |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)) |
| else: |
| loss = torch.tensor(0.0, device=self.device) |
| |
| return loss.item() |
| |
| def save_epoch_checkpoint(self, epoch: int, global_step: int, is_best: bool = False): |
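|         """Save model weights, config, tokenizer, and metadata for one epoch.""" |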
| checkpoint_name = f"epoch_{epoch}" |
| if is_best: |
| checkpoint_name += "_best" |
| |
| checkpoint_dir = self.output_dir / checkpoint_name |
| checkpoint_dir.mkdir(exist_ok=True) |
| |
| torch.save(self.model.state_dict(), checkpoint_dir / "pytorch_model.bin") |
| |
| self.model.config.save_pretrained(checkpoint_dir) |
| |
| self.tokenizer.save_pretrained(checkpoint_dir) |
| |
| metadata = { |
| "epoch": epoch, |
| "global_step": global_step, |
| "model_name": self.model_name, |
| "learning_rate": self.learning_rate, |
| "batch_size": self.batch_size, |
| "is_best": is_best, |
| "model_class": "CustomQwen3ForCausalLM" |
| } |
| |
| with open(checkpoint_dir / "epoch_metadata.json", 'w') as f: |
| json.dump(metadata, f, indent=2) |
| |
| logger.info(f"Saved epoch checkpoint: {checkpoint_dir}") |
| return checkpoint_dir |
| |
| def train(self): |
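|         """Main training loop with periodic evaluation, best-checkpoint tracking, and resume support.""" |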
| logger.info("Starting training...") |
| |
| train_loader = DataLoader( |
| self.train_dataset, |
| batch_size=self.batch_size, |
| shuffle=True, |
| collate_fn=lambda batch: collate_fn(batch, self.tokenizer), |
| num_workers=0 |
| ) |
| |
| val_loader = DataLoader( |
| self.val_dataset, |
| batch_size=self.batch_size, |
| shuffle=False, |
| collate_fn=lambda batch: collate_fn(batch, self.tokenizer), |
| num_workers=0 |
| ) |
| |
| optimizer = AdamW( |
| [p for p in self.model.parameters() if p.requires_grad], |
| lr=self.learning_rate, |
| weight_decay=0.01 |
| ) |
| |
| total_steps = len(train_loader) * self.num_epochs |
| scheduler = get_linear_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=self.warmup_steps, |
| num_training_steps=total_steps |
| ) |
| |
| start_epoch = 0 |
| global_step = 0 |
| best_val_loss = float('inf') |
| |
| if self.resume_training: |
| checkpoint_info = self.training_state.get_latest_checkpoint() |
| if checkpoint_info is not None: |
| checkpoint_path, state = checkpoint_info |
| |
| self.load_checkpoint(checkpoint_path) |
| |
| start_epoch = state["completed_epochs"] |
| global_step = state["global_step"] |
| best_val_loss = state["best_val_loss"] |
| |
| if "optimizer_state" in state and state["optimizer_state"]: |
| try: |
| optimizer.load_state_dict(state["optimizer_state"]) |
| except Exception as e: |
| logger.warning(f"Failed to load optimizer state: {e}") |
| |
| if "scheduler_state" in state and state["scheduler_state"]: |
| try: |
| scheduler.load_state_dict(state["scheduler_state"]) |
| except Exception as e: |
| logger.warning(f"Failed to load scheduler state: {e}") |
| |
| logger.info(f"Resumed training from epoch {start_epoch + 1}") |
| |
| for epoch in range(start_epoch, self.num_epochs): |
| logger.info(f"Epoch {epoch + 1}/{self.num_epochs}") |
| |
| epoch_train_loss = 0 |
| train_steps = 0 |
| |
| progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}") |
| |
| for batch_idx, batch in enumerate(progress_bar): |
| try: |
| loss, outputs = self.training_step(batch, global_step) |
| |
| loss.backward() |
| |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) |
| |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
| |
| epoch_train_loss += loss.item() |
| train_steps += 1 |
| global_step += 1 |
| |
| progress_bar.set_postfix({ |
| 'loss': f'{loss.item():.4f}', |
| 'lr': f'{scheduler.get_last_lr()[0]:.2e}' |
| }) |
| |
| if global_step % self.logging_steps == 0: |
| avg_loss = epoch_train_loss / train_steps |
| logger.info(f"Step {global_step}: train_loss={avg_loss:.4f}, lr={scheduler.get_last_lr()[0]:.2e}") |
| |
| if global_step % self.eval_steps == 0: |
| val_loss = self.validate(val_loader) |
| logger.info(f"Step {global_step}: val_loss={val_loss:.4f}") |
| |
| if val_loss < best_val_loss: |
| best_val_loss = val_loss |
| best_checkpoint = self.save_epoch_checkpoint(epoch, global_step, is_best=True) |
| logger.info(f"New best validation loss: {val_loss:.4f}") |
| |
| except Exception as e: |
| logger.error(f"Error in training step {global_step}: {e}") |
| continue |
| |
| val_loss = self.validate(val_loader) |
| avg_train_loss = epoch_train_loss / train_steps if train_steps > 0 else 0 |
| |
| logger.info(f"Epoch {epoch + 1} completed:") |
| logger.info(f" Average train loss: {avg_train_loss:.4f}") |
| logger.info(f" Validation loss: {val_loss:.4f}") |
| |
| is_best = val_loss < best_val_loss |
| if is_best: |
| best_val_loss = val_loss |
| |
| checkpoint_dir = self.save_epoch_checkpoint(epoch, global_step, is_best=is_best) |
| |
| self.training_state.save_state( |
| epoch=epoch, |
| global_step=global_step, |
| best_val_loss=best_val_loss, |
| optimizer_state=optimizer.state_dict(), |
| scheduler_state=scheduler.state_dict() |
| ) |
| |
| logger.info(f"Epoch {epoch + 1} checkpoint and state saved") |
| |
| logger.info(f"Training completed! Best validation loss: {best_val_loss:.4f}") |
| |
| final_checkpoint = self.save_epoch_checkpoint(self.num_epochs - 1, global_step, is_best=False) |
| logger.info(f"Final checkpoint saved: {final_checkpoint}") |
| |
| def validate(self, val_loader): |
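|         """Average the per-batch validation loss over the validation loader.""" |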
| self.model.eval() |
| total_loss = 0 |
| num_batches = 0 |
| |
| with torch.no_grad(): |
| for batch in tqdm(val_loader, desc="Validation"): |
| try: |
| loss = self.validation_step(batch) |
| total_loss += loss |
| num_batches += 1 |
| except Exception as e: |
| logger.warning(f"Error in validation step: {e}") |
| continue |
| |
| avg_loss = total_loss / num_batches if num_batches > 0 else float('inf') |
| self.model.train() |
| return avg_loss |
| |
| def run(self): |
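|         """Run the full pipeline: load model/tokenizer, load datasets, train, clean up.""" |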
| try: |
| self.load_model_and_tokenizer() |
| |
| self.load_dataset( |
| chunk_size=self.chunk_size, |
| max_samples=self.max_samples |
| ) |
| |
| self.train() |
| |
| logger.info("Training pipeline completed successfully!") |
| |
| self.cleanup_datasets() |
| |
| except Exception as e: |
| logger.error(f"Training pipeline failed: {e}") |
| import traceback |
| logger.error(traceback.format_exc()) |
| |
| try: |
| self.cleanup_datasets() |
| except: |
| pass |
| |
| raise |
|
|
| def load_trained_model(checkpoint_path: str, model_name: str = "Qwen/Qwen3-0.6B"): |
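|     """Rebuild CustomQwen3ForCausalLM from a checkpoint directory and load its weights.""" |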
| logger.info(f"Loading trained model from {checkpoint_path}") |
| |
| tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) |
| |
| config = Qwen3Config.from_pretrained(checkpoint_path) |
| |
| model = CustomQwen3ForCausalLM(config) |
| |
|     model_path = Path(checkpoint_path) / "pytorch_model.bin" |
|     if not model_path.exists(): |
|         model_path = Path(checkpoint_path) / "model.safetensors" |
| |
|     if not model_path.exists(): |
|         raise FileNotFoundError(f"No model weights found in {checkpoint_path}") |
| |
|     if model_path.suffix == ".safetensors": |
|         # torch.load cannot read safetensors weights. |
|         from safetensors.torch import load_file |
|         state_dict = load_file(str(model_path)) |
|     else: |
|         state_dict = torch.load(model_path, map_location="cpu") |
|     missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) |
| |
| if missing_keys: |
| logger.warning(f"Missing keys when loading model: {missing_keys}") |
| if unexpected_keys: |
| logger.warning(f"Unexpected keys when loading model: {unexpected_keys}") |
| |
| model = model.eval().to(torch.float32) |
| |
| return model, tokenizer |
|
|
|
|
| def generate_with_grouped_input( |
| model, |
| tokenizer, |
| grouped_input: torch.Tensor, |
| max_length: int = 512, |
| temperature: float = 0.7, |
| do_sample: bool = True |
| ): |
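|     """Sample or greedily decode a response seeded by a grouped-embedding prefill. |
| |
|     The grouped embeddings are prefilled once to build the KV cache, then tokens are |
|     generated one at a time until eos_token_id or max_length is reached. |
|     """ |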
| device = model.device |
| model_dtype = next(model.parameters()).dtype |
| |
| grouped_input = grouped_input.to(device=device, dtype=model_dtype) |
| |
| if grouped_input.ndim == 2: |
| grouped_input = grouped_input.unsqueeze(0) |
| |
| logger.debug(f"Grouped input shape: {grouped_input.shape}, dtype: {grouped_input.dtype}") |
| logger.debug(f"Model dtype: {model_dtype}, device: {device}") |
| |
| with torch.no_grad(): |
| try: |
| outputs = model( |
| grouped_inputs=grouped_input, |
| is_prefill=True, |
| use_cache=True, |
| return_dict=True |
| ) |
| except Exception as e: |
| logger.error(f"Error in prefill phase: {e}") |
| raise |
| |
| if hasattr(outputs, 'logits') and outputs.logits is not None: |
| next_token_logits = outputs.logits[:, -1, :] |
| elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: |
| last_hidden_state = outputs.hidden_states[-1] if isinstance(outputs.hidden_states, (list, tuple)) else outputs.hidden_states |
| next_token_logits = model.lm_head(last_hidden_state[:, -1, :]) |
| else: |
| raise RuntimeError("Could not extract logits from model output") |
| |
| if do_sample: |
| next_token_logits = next_token_logits / temperature |
| probs = F.softmax(next_token_logits, dim=-1) |
| next_token = torch.multinomial(probs, num_samples=1) |
| else: |
| next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
| |
| generated_ids = next_token |
| past_key_values = getattr(outputs, 'past_key_values', None) |
| |
| for step in range(max_length - 1): |
| with torch.no_grad(): |
| try: |
| outputs = model( |
| input_ids=next_token, |
| past_key_values=past_key_values, |
| use_cache=True, |
| return_dict=True |
| ) |
| except Exception as e: |
| logger.error(f"Error in generation step {step}: {e}") |
| break |
| |
| if hasattr(outputs, 'logits'): |
| next_token_logits = outputs.logits[:, -1, :] |
| else: |
| logger.warning("No logits in generation output, stopping generation") |
| break |
| |
| if do_sample: |
| next_token_logits = next_token_logits / temperature |
| probs = F.softmax(next_token_logits, dim=-1) |
| next_token = torch.multinomial(probs, num_samples=1) |
| else: |
| next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
| |
| generated_ids = torch.cat([generated_ids, next_token], dim=1) |
| past_key_values = getattr(outputs, 'past_key_values', None) |
| |
| if next_token.item() == tokenizer.eos_token_id: |
| break |
| |
| generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) |
| return generated_text |
|
|
| def main(): |
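|     """Entry point for training with the hard-coded configuration below.""" |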
| config = { |
| "model_name": "Qwen/Qwen3-0.6B", |
| "dataset_path": "./processed_qwen3_dataset/processed_dataset.pkl", |
| "output_dir": "./grouped_qwen3_checkpoint", |
| "batch_size": 12, |
| "learning_rate": 5e-4, |
| "num_epochs": 3, |
| "warmup_steps": 500, |
| "max_grad_norm": 1.0, |
| "save_steps": 1000, |
| "eval_steps": 1000, |
| "logging_steps": 100, |
| "resume_training": True, |
| "debug": False, |
| |
| "chunk_size": 2000, |
| "max_samples": None, |
| } |
| |
| logger.info("="*60) |
| logger.info("GROUPED QWEN3 TRAINING CONFIGURATION (STREAMING)") |
| logger.info("="*60) |
| for key, value in config.items(): |
| logger.info(f"{key}: {value}") |
| logger.info("="*60) |
| |
| if torch.cuda.is_available(): |
| logger.info(f"GPU: {torch.cuda.get_device_name()}") |
| logger.info(f"VRAM Total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") |
| |
| import psutil |
| ram_usage = psutil.virtual_memory() |
| logger.info(f"System RAM: {ram_usage.used / 1024**3:.1f} GB / {ram_usage.total / 1024**3:.1f} GB ({ram_usage.percent:.1f}%)") |
| |
| trainer = GroupedTrainer(**config) |
| trainer.run() |
|
|
|
|
| def inference_by_id(sample_id: int, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best", |
| dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl", |
| max_length: int = 512, temperature: float = 0.7, do_sample: bool = True): |
| """Run inference on a specific sample ID from the dataset.""" |
| logger.info(f"Running inference on sample ID: {sample_id}") |
| |
| try: |
| model, tokenizer = load_trained_model(checkpoint_path) |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model = model.to(device) |
| logger.info(f"Model loaded from {checkpoint_path}") |
| except Exception as e: |
| logger.error(f"Failed to load model: {e}") |
| return None |
| |
| try: |
| logger.info(f"Loading sample {sample_id} from dataset...") |
| with open(dataset_path, 'rb') as f: |
| processed_data = pickle.load(f) |
| |
| if sample_id >= len(processed_data): |
| logger.error(f"Sample ID {sample_id} is out of range. Dataset has {len(processed_data)} samples.") |
| return None |
| |
| sample = processed_data[sample_id] |
| |
| if sample.get("error", False): |
| logger.error(f"Sample {sample_id} has errors and cannot be used for inference.") |
| return None |
| |
| except Exception as e: |
| logger.error(f"Failed to load dataset: {e}") |
| return None |
| |
| grouped_embeds_raw = sample["inputs_embeds"] |
| if isinstance(grouped_embeds_raw, torch.Tensor): |
| grouped_input = grouped_embeds_raw.detach().clone().float() |
| else: |
| grouped_input = torch.tensor(grouped_embeds_raw, dtype=torch.float32) |
| |
| original_input = sample["input_text"] |
| expected_response = sample["response"] |
| |
| print("\n" + "="*80) |
| print(f"INFERENCE ON SAMPLE ID: {sample_id}") |
| print("="*80) |
| print(f"π ORIGINAL REQUEST:") |
| print(f"{original_input}") |
| print("\n" + "-"*80) |
| print(f"π― EXPECTED RESPONSE:") |
| print(f"{expected_response}") |
| print("\n" + "-"*80) |
| print(f"π€ MODEL GENERATED RESPONSE:") |
| |
| try: |
| generated_text = generate_with_grouped_input( |
| model=model, |
| tokenizer=tokenizer, |
| grouped_input=grouped_input, |
| max_length=max_length, |
| temperature=temperature, |
| do_sample=do_sample |
| ) |
| |
| print(f"{generated_text}") |
| print("\n" + "="*80) |
| |
| expected_words = expected_response.split() |
| generated_words = generated_text.split() |
| |
| print(f"π METRICS:") |
| print(f"Expected length: {len(expected_words)} words") |
| print(f"Generated length: {len(generated_words)} words") |
| print(f"Temperature: {temperature}") |
| print(f"Max length: {max_length}") |
| print("="*80) |
| |
| return { |
| "sample_id": sample_id, |
| "original_input": original_input, |
| "expected_response": expected_response, |
| "generated_response": generated_text, |
| "expected_length": len(expected_words), |
| "generated_length": len(generated_words) |
| } |
| |
| except Exception as e: |
| logger.error(f"Failed to generate response: {e}") |
| print(f"β GENERATION FAILED: {e}") |
| print("="*80) |
| return None |
|
|
|
|
| def batch_inference(sample_ids: List[int], checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best", |
| dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl", |
| max_length: int = 512, temperature: float = 0.7, do_sample: bool = True): |
| """Run inference on multiple sample IDs.""" |
| logger.info(f"Running batch inference on {len(sample_ids)} samples") |
| |
| results = [] |
| for sample_id in sample_ids: |
| result = inference_by_id( |
| sample_id=sample_id, |
| checkpoint_path=checkpoint_path, |
| dataset_path=dataset_path, |
| max_length=max_length, |
| temperature=temperature, |
| do_sample=do_sample |
| ) |
| if result: |
| results.append(result) |
| |
| print("\n" + "π " + "-"*78 + " π\n") |
| |
| print("\n" + "="*80) |
| print(f"π BATCH INFERENCE SUMMARY") |
| print("="*80) |
| print(f"Total samples processed: {len(results)}") |
| if results: |
| avg_expected_len = sum(r["expected_length"] for r in results) / len(results) |
| avg_generated_len = sum(r["generated_length"] for r in results) / len(results) |
| print(f"Average expected length: {avg_expected_len:.1f} words") |
| print(f"Average generated length: {avg_generated_len:.1f} words") |
| print("="*80) |
| |
| return results |
|
|
|
|
| def random_inference(num_samples: int = 3, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best", |
| dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl", |
| max_length: int = 512, temperature: float = 0.7, do_sample: bool = True): |
| """Run inference on random samples from the dataset.""" |
| import random |
| |
| try: |
| with open(dataset_path, 'rb') as f: |
| processed_data = pickle.load(f) |
| |
| |
| valid_indices = [i for i, item in enumerate(processed_data) if not item.get("error", False)] |
| |
| if len(valid_indices) < num_samples: |
| logger.warning(f"Only {len(valid_indices)} valid samples available, using all of them") |
| num_samples = len(valid_indices) |
| |
| |
| random_ids = random.sample(valid_indices, num_samples) |
| |
| logger.info(f"Selected random sample IDs: {random_ids}") |
| |
| |
| return batch_inference( |
| sample_ids=random_ids, |
| checkpoint_path=checkpoint_path, |
| dataset_path=dataset_path, |
| max_length=max_length, |
| temperature=temperature, |
| do_sample=do_sample |
| ) |
| |
| except Exception as e: |
| logger.error(f"Failed to load dataset for random sampling: {e}") |
| return None |
|
|
|
|
| def interactive_inference(checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best", |
| dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl"): |
| """Interactive inference mode where user can input sample IDs.""" |
| print("\n" + "="*80) |
| print("π€ INTERACTIVE INFERENCE MODE") |
| print("="*80) |
| print("Commands:") |
| print(" <number> - Run inference on sample ID") |
| print(" random <n> - Run inference on n random samples (default: 3)") |
| print(" batch <ids> - Run inference on multiple IDs (e.g., 'batch 1,5,10')") |
| print(" quit - Exit") |
| print("="*80) |
| |
| while True: |
| try: |
| user_input = input("\nπ Enter command: ").strip().lower() |
| |
| if user_input in ['quit', 'exit', 'q']: |
| print("π Goodbye!") |
| break |
| elif user_input.startswith('random'): |
| parts = user_input.split() |
| num_samples = int(parts[1]) if len(parts) > 1 else 3 |
| random_inference(num_samples=num_samples, checkpoint_path=checkpoint_path, dataset_path=dataset_path) |
| elif user_input.startswith('batch'): |
| parts = user_input.split(maxsplit=1) |
| if len(parts) > 1: |
| ids_str = parts[1] |
| sample_ids = [int(x.strip()) for x in ids_str.split(',')] |
| batch_inference(sample_ids=sample_ids, checkpoint_path=checkpoint_path, dataset_path=dataset_path) |
| else: |
| print("β Please provide sample IDs: batch 1,5,10") |
| elif user_input.isdigit(): |
| sample_id = int(user_input) |
| inference_by_id(sample_id=sample_id, checkpoint_path=checkpoint_path, dataset_path=dataset_path) |
| else: |
| print("β Invalid command. Try a number, 'random', 'batch', or 'quit'") |
| |
|         except ValueError: |
|             print("Invalid input. Please enter a valid number or command.") |
|         except KeyboardInterrupt: |
|             print("\nGoodbye!") |
|             break |
|         except Exception as e: |
|             print(f"Error: {e}") |
|
|
|
|
| def test_inference(): |
| logger.info("Running inference tests...") |
| |
| test_ids = [0, 1, 2, 100, 500] |
| |
| print("\nπ§ͺ TESTING INFERENCE ON PREDEFINED SAMPLES") |
| results = batch_inference( |
| sample_ids=test_ids, |
| max_length=300, |
| temperature=0.7, |
| do_sample=True |
| ) |
| |
| return results |
|
|
|
|
| if __name__ == "__main__": |
| import argparse |
| |
| parser = argparse.ArgumentParser(description="Grouped Qwen3 Training and Inference") |
| parser.add_argument("--mode", choices=["train", "test", "inference", "interactive", "random"], |
| default="train", help="Mode to run") |
| parser.add_argument("--sample_id", type=int, help="Sample ID for inference mode") |
| parser.add_argument("--sample_ids", type=str, help="Comma-separated sample IDs for batch inference") |
| parser.add_argument("--num_samples", type=int, default=3, help="Number of random samples for random mode") |
| parser.add_argument("--checkpoint", type=str, default="./grouped_qwen3_checkpoint/epoch_2_best", |
| help="Path to model checkpoint") |
| parser.add_argument("--dataset", type=str, default="./processed_qwen3_dataset/processed_dataset.pkl", |
| help="Path to dataset") |
| parser.add_argument("--max_length", type=int, default=512, help="Maximum generation length") |
| parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature") |
| |
| args = parser.parse_args() |
| |
| if args.mode == "train": |
| main() |
| elif args.mode == "test": |
| test_inference() |
| elif args.mode == "inference": |
| if args.sample_id is not None: |
| inference_by_id( |
| sample_id=args.sample_id, |
| checkpoint_path=args.checkpoint, |
| dataset_path=args.dataset, |
| max_length=args.max_length, |
| temperature=args.temperature |
| ) |
| elif args.sample_ids is not None: |
| sample_ids = [int(x.strip()) for x in args.sample_ids.split(',')] |
| batch_inference( |
| sample_ids=sample_ids, |
| checkpoint_path=args.checkpoint, |
| dataset_path=args.dataset, |
| max_length=args.max_length, |
| temperature=args.temperature |
| ) |
| else: |
| print("β Please provide --sample_id or --sample_ids for inference mode") |
| elif args.mode == "interactive": |
| interactive_inference( |
| checkpoint_path=args.checkpoint, |
| dataset_path=args.dataset |
| ) |
| elif args.mode == "random": |
| random_inference( |
| num_samples=args.num_samples, |
| checkpoint_path=args.checkpoint, |
| dataset_path=args.dataset, |
| max_length=args.max_length, |
| temperature=args.temperature |
| ) |