# context-merging / train_custom_qwen3.py
import sys
import logging
import json
import pickle
from typing import Optional, List, Dict
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3ForCausalLM,
    Qwen3RMSNorm,
)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
force=True,
)
logger = logging.getLogger("grouped_qwen3_training")
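# Residual MLP adapter applied to precomputed "grouped" embeddings before they
# enter the decoder stack: Linear -> SiLU -> Dropout -> Linear -> Dropout, with
# a post-residual RMSNorm, mirroring Qwen3's own activation and norm choices.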
class GroupedInputMLPAdapter(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
hidden_size = config.hidden_size
self.grouped_processor = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 2),
nn.SiLU(), # Using SiLU activation like Qwen3
nn.Dropout(0.1),
nn.Linear(hidden_size * 2, hidden_size),
nn.Dropout(0.1)
)
norm_eps = getattr(config, 'rms_norm_eps', 1e-6)
self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)
def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor:
processed = self.grouped_processor(grouped_embeds)
output = self.layer_norm(grouped_embeds + processed)
return output
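# Qwen3 causal LM that can take grouped embeddings instead of token ids during
# prefill. All base weights are frozen except the first decoder layer; only
# that layer and the adapter above are trained, so optimization focuses on
# mapping grouped inputs into the frozen model's representation space.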
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
def __init__(self, config):
super().__init__(config)
self.grouped_input_mlp = GroupedInputMLPAdapter(config)
self.is_grouped_input_mode = False
self.grouped_cache_initialized = False
self._init_grouped_weights()
self._freeze_layers()
def _init_grouped_weights(self):
def _init_weights(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
            elif isinstance(module, Qwen3RMSNorm):
                # Qwen3RMSNorm carries a weight but no bias.
                torch.nn.init.ones_(module.weight)
self.grouped_input_mlp.apply(_init_weights)
def _freeze_layers(self):
for param in self.model.embed_tokens.parameters():
param.requires_grad = False
for i, layer in enumerate(self.model.layers):
if i == 0:
for param in layer.parameters():
param.requires_grad = True
else:
for param in layer.parameters():
param.requires_grad = False
for param in self.model.norm.parameters():
param.requires_grad = False
for param in self.lm_head.parameters():
param.requires_grad = False
for param in self.grouped_input_mlp.parameters():
param.requires_grad = True
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in self.parameters())
logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
f"({trainable_params/total_params*100:.2f}%)")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
grouped_inputs: Optional[torch.FloatTensor] = None, # New parameter for grouped inputs
is_prefill: Optional[bool] = None, # Flag to indicate prefill phase
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
        if grouped_inputs is not None and is_prefill:
            # Prefill phase: project the precomputed grouped embeddings through
            # the adapter and feed them as inputs_embeds instead of token ids.
            self.is_grouped_input_mode = True
            processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs)
            inputs_embeds = processed_grouped_inputs
            input_ids = None  # input_ids are unused once grouped embeddings are supplied
            batch_size, seq_len = inputs_embeds.shape[:2]
            if position_ids is None:
                device = inputs_embeds.device
                position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
            self.grouped_cache_initialized = True
        elif not is_prefill and self.is_grouped_input_mode:
            # Decode phase after a grouped prefill: pass token ids straight
            # through and rely on the KV cache built during prefill.
            pass
        else:
            self.is_grouped_input_mode = False
# Call parent forward
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs
)
return outputs
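# Dataset over a pickled list of dicts from the preprocessing step. Each valid
# item carries "inputs_embeds" (the grouped embedding tensor), "response", and
# "input_text"; items flagged with "error" are skipped during indexing.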
class GroupedDataset(Dataset):
def __init__(self, dataset_path: str, tokenizer, max_response_length: int = 512,
validation_split: float = 0.1, is_validation: bool = False,
chunk_size: int = 1000, max_samples: Optional[int] = None):
self.dataset_path = dataset_path
self.tokenizer = tokenizer
self.max_response_length = max_response_length
self.validation_split = validation_split
self.is_validation = is_validation
self.chunk_size = chunk_size
self.max_samples = max_samples
self._chunk_cache = {}
self._cache_size_limit = 3 # Keep max 3 chunks in memory
self._build_index()
def _build_index(self):
logger.info(f"Building index for {self.dataset_path}")
with open(self.dataset_path, 'rb') as f:
data = pickle.load(f)
valid_indices = []
for i, item in enumerate(data):
if not item.get("error", False):
valid_indices.append(i)
if self.max_samples and len(valid_indices) >= self.max_samples:
break
total_valid = len(valid_indices)
val_size = min(1000, int(self.validation_split * total_valid))
train_size = total_valid - val_size
if self.is_validation:
self.valid_indices = valid_indices[train_size:train_size + val_size]
self.total_samples = val_size
else:
self.valid_indices = valid_indices[:train_size]
self.total_samples = train_size
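        # Note: the whole pickle stays resident in memory; the chunk cache in
        # _load_chunk only avoids rebuilding per-index dicts rather than
        # providing true disk streaming.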
self._full_data = data
logger.info(f"{'Validation' if self.is_validation else 'Training'} dataset: {self.total_samples} samples")
def _get_chunk_id(self, idx):
return idx // self.chunk_size
def _load_chunk(self, chunk_id):
if chunk_id in self._chunk_cache:
return self._chunk_cache[chunk_id]
start_idx = chunk_id * self.chunk_size
end_idx = min(start_idx + self.chunk_size, self.total_samples)
chunk_data = {}
for i in range(start_idx, end_idx):
actual_idx = self.valid_indices[i]
chunk_data[i] = self._full_data[actual_idx]
        if len(self._chunk_cache) >= self._cache_size_limit:
            # Evict the lowest-numbered chunk (approximates FIFO for sequential access)
            oldest_chunk = min(self._chunk_cache.keys())
            del self._chunk_cache[oldest_chunk]
self._chunk_cache[chunk_id] = chunk_data
return chunk_data
def __len__(self):
return self.total_samples
def __getitem__(self, idx):
if idx >= self.total_samples:
raise IndexError(f"Index {idx} out of range for dataset of size {self.total_samples}")
chunk_id = self._get_chunk_id(idx)
chunk_data = self._load_chunk(chunk_id)
item = chunk_data[idx]
return self._process_item(item)
def _process_item(self, item):
grouped_embeds = item["inputs_embeds"]
if isinstance(grouped_embeds, torch.Tensor):
grouped_embeds = grouped_embeds.clone()
else:
grouped_embeds = torch.tensor(grouped_embeds)
if grouped_embeds.dtype != torch.float32:
grouped_embeds = grouped_embeds.float()
response = item["response"]
response_tokens = self.tokenizer(
response,
max_length=self.max_response_length,
truncation=True,
padding=False,
return_tensors="pt"
)
response_input_ids = response_tokens["input_ids"].squeeze(0)
return {
"grouped_inputs": grouped_embeds,
"response_input_ids": response_input_ids,
"response_text": response,
"input_text": item["input_text"],
}
def cleanup(self):
self._chunk_cache.clear()
if hasattr(self, '_full_data'):
del self._full_data
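# Batch collation: grouped embeddings are left-aligned and zero-padded with a
# matching attention mask; response token ids are padded with pad_token_id.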
def collate_fn(batch, tokenizer, pad_token_id=None):
if pad_token_id is None:
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
grouped_inputs = [item["grouped_inputs"] for item in batch]
response_input_ids = [item["response_input_ids"] for item in batch]
max_grouped_len = max(gi.shape[0] for gi in grouped_inputs)
batch_size = len(grouped_inputs)
hidden_size = grouped_inputs[0].shape[-1]
padded_grouped_inputs = torch.zeros(batch_size, max_grouped_len, hidden_size)
grouped_attention_mask = torch.zeros(batch_size, max_grouped_len, dtype=torch.long)
for i, gi in enumerate(grouped_inputs):
seq_len = gi.shape[0]
padded_grouped_inputs[i, :seq_len] = gi
grouped_attention_mask[i, :seq_len] = 1
max_response_len = max(len(rid) for rid in response_input_ids)
padded_response_ids = torch.full((batch_size, max_response_len), pad_token_id, dtype=torch.long)
for i, rid in enumerate(response_input_ids):
padded_response_ids[i, :len(rid)] = rid
return {
"grouped_inputs": padded_grouped_inputs,
"grouped_attention_mask": grouped_attention_mask,
"response_input_ids": padded_response_ids,
"response_texts": [item["response_text"] for item in batch],
"input_texts": [item["input_text"] for item in batch],
}
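# Persists epoch, global step, best validation loss, and optimizer/scheduler
# state to JSON so that interrupted runs can be resumed.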
class TrainingState:
def __init__(self, output_dir: Path):
self.output_dir = output_dir
self.state_file = output_dir / "training_state.json"
def save_state(self, epoch: int, global_step: int, best_val_loss: float,
optimizer_state: Dict, scheduler_state: Dict):
"""Save training state."""
state = {
"epoch": epoch,
"global_step": global_step,
"best_val_loss": best_val_loss,
"optimizer_state": optimizer_state,
"scheduler_state": scheduler_state,
"completed_epochs": epoch
}
with open(self.state_file, 'w') as f:
            json.dump(state, f, indent=2, default=str)  # default=str stringifies tensors, so optimizer state may not round-trip; loading is guarded by try/except
logger.info(f"Saved training state at epoch {epoch}, step {global_step}")
def load_state(self):
if not self.state_file.exists():
return None
try:
with open(self.state_file, 'r') as f:
state = json.load(f)
logger.info(f"Loaded training state from epoch {state['epoch']}, step {state['global_step']}")
return state
except Exception as e:
logger.warning(f"Failed to load training state: {e}")
return None
def get_latest_checkpoint(self):
state = self.load_state()
if state is None:
return None
        epoch = state["completed_epochs"]
        # Epoch checkpoints may be saved with a "_best" suffix, so check both names.
        for name in (f"epoch_{epoch}", f"epoch_{epoch}_best"):
            checkpoint_path = self.output_dir / name
            if checkpoint_path.exists():
                return checkpoint_path, state
        logger.warning(f"Checkpoint for epoch {epoch} not found")
        return None
class GroupedTrainer:
def __init__(
self,
model_name: str = "Qwen/Qwen3-0.6B",
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
output_dir: str = "./grouped_qwen3_checkpoint",
batch_size: int = 4,
learning_rate: float = 5e-4,
num_epochs: int = 3,
warmup_steps: int = 100,
max_grad_norm: float = 1.0,
save_steps: int = 500,
eval_steps: int = 500,
logging_steps: int = 50,
resume_training: bool = True,
debug: bool = False,
chunk_size: int = 1000, # Chunk size for streaming
max_samples: Optional[int] = None, # Limit dataset size for testing
):
self.model_name = model_name
self.dataset_path = dataset_path
self.output_dir = Path(output_dir)
self.batch_size = batch_size
self.learning_rate = learning_rate
self.num_epochs = num_epochs
self.warmup_steps = warmup_steps
self.max_grad_norm = max_grad_norm
self.save_steps = save_steps
self.eval_steps = eval_steps
self.logging_steps = logging_steps
self.resume_training = resume_training
self.debug = debug
self.chunk_size = chunk_size
self.max_samples = max_samples
if self.debug:
logger.setLevel(logging.DEBUG)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = None
self.model = None
self.train_dataset = None
self.val_dataset = None
self.training_state = TrainingState(self.output_dir)
def load_model_and_tokenizer(self):
logger.info(f"Loading tokenizer and model: {self.model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
config = Qwen3Config.from_pretrained(self.model_name)
self.model = CustomQwen3ForCausalLM.from_pretrained(
self.model_name,
config=config,
torch_dtype=torch.float32, # Use float32 for training
attn_implementation="eager"
).to(self.device)
logger.info(f"Model loaded on {self.device}")
def load_dataset(self, chunk_size: int = 1000, max_samples: Optional[int] = None):
logger.info(f"Loading streaming dataset from {self.dataset_path}")
# Create streaming datasets
self.train_dataset = GroupedDataset(
dataset_path=self.dataset_path,
tokenizer=self.tokenizer,
is_validation=False,
chunk_size=chunk_size,
max_samples=max_samples
)
self.val_dataset = GroupedDataset(
dataset_path=self.dataset_path,
tokenizer=self.tokenizer,
is_validation=True,
chunk_size=chunk_size,
max_samples=max_samples
)
logger.info(f"Train samples: {len(self.train_dataset)}")
logger.info(f"Val samples: {len(self.val_dataset)}")
# Log memory usage
if torch.cuda.is_available():
torch.cuda.empty_cache()
memory_used = torch.cuda.memory_allocated() / 1024**3
logger.info(f"GPU memory after dataset loading: {memory_used:.2f} GB")
def cleanup_datasets(self):
if hasattr(self.train_dataset, 'cleanup'):
self.train_dataset.cleanup()
if hasattr(self.val_dataset, 'cleanup'):
self.val_dataset.cleanup()
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def load_checkpoint(self, checkpoint_path: Path):
logger.info(f"Loading checkpoint from {checkpoint_path}")
        model_path = checkpoint_path / "pytorch_model.bin"
        use_safetensors = False
        if not model_path.exists():
            model_path = checkpoint_path / "model.safetensors"
            use_safetensors = True
        if model_path.exists():
            if use_safetensors:
                # torch.load cannot read safetensors files; use the safetensors loader
                from safetensors.torch import load_file
                state_dict = load_file(str(model_path))
            else:
                state_dict = torch.load(model_path, map_location=self.device)
            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
if missing_keys:
logger.warning(f"Missing keys when loading checkpoint: {missing_keys}")
if unexpected_keys:
logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
logger.info("Model checkpoint loaded successfully")
else:
raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
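    # Standalone loss helper kept for reference; training_step computes its
    # teacher-forced loss inline using the same masking scheme.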
def compute_loss(self, batch, outputs):
logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
target_ids = batch["response_input_ids"].to(self.device) # [batch_size, target_len]
logger.debug(f"Logits shape: {logits.shape}, Target shape: {target_ids.shape}")
batch_size = target_ids.shape[0]
if target_ids.shape[1] > 1:
labels = target_ids.clone()
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
labels[labels == pad_token_id] = -100
seq_len = min(logits.shape[1], labels.shape[1])
logits_truncated = logits[:, :seq_len, :] # [batch_size, seq_len, vocab_size]
labels_truncated = labels[:, :seq_len] # [batch_size, seq_len]
logger.debug(f"After truncation - Logits: {logits_truncated.shape}, Labels: {labels_truncated.shape}")
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(
logits_truncated.reshape(-1, logits_truncated.size(-1)),
labels_truncated.reshape(-1)
)
else:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, logits.size(-1)), target_ids.view(-1))
return loss
def training_step(self, batch, step):
self.model.train()
if step < 5 and torch.cuda.is_available():
torch.cuda.empty_cache()
memory_before = torch.cuda.memory_allocated() / 1024**3
grouped_inputs = batch["grouped_inputs"].to(self.device)
grouped_attention_mask = batch["grouped_attention_mask"].to(self.device)
response_input_ids = batch["response_input_ids"].to(self.device)
batch_size = grouped_inputs.shape[0]
grouped_seq_len = grouped_inputs.shape[1]
response_seq_len = response_input_ids.shape[1]
if self.debug:
logger.debug(f"Batch sizes - grouped: {grouped_inputs.shape}, response: {response_input_ids.shape}")
grouped_outputs = self.model(
grouped_inputs=grouped_inputs,
attention_mask=grouped_attention_mask,
is_prefill=True,
use_cache=True,
return_dict=True
)
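        # Phase 2 (teacher-forced decode): feed the response tokens shifted right,
        # conditioning on the grouped-prefill KV cache.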
if response_seq_len > 1:
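            # pad_token may alias eos_token (set in load_model_and_tokenizer), so
            # genuine EOS positions are masked out of attention and loss as well.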
response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long()
response_outputs = self.model(
input_ids=response_input_ids[:, :-1], # All but last token as input
attention_mask=response_attention_mask[:, :-1],
past_key_values=grouped_outputs.past_key_values,
use_cache=False,
return_dict=True
)
logits = response_outputs.logits
labels = response_input_ids[:, 1:] # All but first token as targets
labels = labels.clone()
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
labels[labels == pad_token_id] = -100
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        else:
            # Degenerate single-token responses contribute a zero loss that still
            # participates in backward without producing parameter gradients.
            loss = torch.tensor(0.0, requires_grad=True, device=self.device)
if step < 5 and torch.cuda.is_available():
memory_after = torch.cuda.memory_allocated() / 1024**3
memory_peak = torch.cuda.max_memory_allocated() / 1024**3
logger.info(f"Step {step} Memory: {memory_before:.2f}GB β†’ {memory_after:.2f}GB (Peak: {memory_peak:.2f}GB)")
if memory_peak > 20.0: # 20GB threshold for L4
logger.warning("High memory usage detected! Consider reducing batch_size")
        # Minimal container exposing .loss/.logits like a HF ModelOutput.
        class MockOutputs:
            def __init__(self, loss, logits):
                self.loss = loss
                self.logits = logits
        final_logits = response_outputs.logits if response_seq_len > 1 else grouped_outputs.logits
        outputs = MockOutputs(loss, final_logits)
        return loss, outputs
def validation_step(self, batch):
"""Single validation step."""
self.model.eval()
with torch.no_grad():
grouped_inputs = batch["grouped_inputs"].to(self.device)
grouped_attention_mask = batch["grouped_attention_mask"].to(self.device)
response_input_ids = batch["response_input_ids"].to(self.device)
grouped_outputs = self.model(
grouped_inputs=grouped_inputs,
attention_mask=grouped_attention_mask,
is_prefill=True,
use_cache=True,
return_dict=True
)
if response_input_ids.shape[1] > 1:
response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long()
response_outputs = self.model(
input_ids=response_input_ids[:, :-1],
attention_mask=response_attention_mask[:, :-1],
past_key_values=grouped_outputs.past_key_values,
use_cache=False,
return_dict=True
)
logits = response_outputs.logits
labels = response_input_ids[:, 1:]
labels = labels.clone()
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
labels[labels == pad_token_id] = -100
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
else:
loss = torch.tensor(0.0, device=self.device)
return loss.item()
def save_epoch_checkpoint(self, epoch: int, global_step: int, is_best: bool = False):
checkpoint_name = f"epoch_{epoch}"
if is_best:
checkpoint_name += "_best"
checkpoint_dir = self.output_dir / checkpoint_name
checkpoint_dir.mkdir(exist_ok=True)
torch.save(self.model.state_dict(), checkpoint_dir / "pytorch_model.bin")
self.model.config.save_pretrained(checkpoint_dir)
self.tokenizer.save_pretrained(checkpoint_dir)
metadata = {
"epoch": epoch,
"global_step": global_step,
"model_name": self.model_name,
"learning_rate": self.learning_rate,
"batch_size": self.batch_size,
"is_best": is_best,
"model_class": "CustomQwen3ForCausalLM"
}
with open(checkpoint_dir / "epoch_metadata.json", 'w') as f:
json.dump(metadata, f, indent=2)
logger.info(f"Saved epoch checkpoint: {checkpoint_dir}")
return checkpoint_dir
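    # Full training loop: linear warmup/decay schedule over all steps, periodic
    # validation with best-checkpoint tracking, and per-epoch state saves that
    # enable resumption via TrainingState.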
def train(self):
logger.info("Starting training...")
train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
collate_fn=lambda batch: collate_fn(batch, self.tokenizer),
            num_workers=0  # the lambda collate_fn cannot be pickled by DataLoader worker processes, so keep collation in-process
)
val_loader = DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=lambda batch: collate_fn(batch, self.tokenizer),
num_workers=0
)
optimizer = AdamW(
[p for p in self.model.parameters() if p.requires_grad],
lr=self.learning_rate,
weight_decay=0.01
)
total_steps = len(train_loader) * self.num_epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=self.warmup_steps,
num_training_steps=total_steps
)
start_epoch = 0
global_step = 0
best_val_loss = float('inf')
if self.resume_training:
checkpoint_info = self.training_state.get_latest_checkpoint()
if checkpoint_info is not None:
checkpoint_path, state = checkpoint_info
self.load_checkpoint(checkpoint_path)
                start_epoch = state["completed_epochs"] + 1  # resume after the last fully completed epoch
global_step = state["global_step"]
best_val_loss = state["best_val_loss"]
if "optimizer_state" in state and state["optimizer_state"]:
try:
optimizer.load_state_dict(state["optimizer_state"])
except Exception as e:
logger.warning(f"Failed to load optimizer state: {e}")
if "scheduler_state" in state and state["scheduler_state"]:
try:
scheduler.load_state_dict(state["scheduler_state"])
except Exception as e:
logger.warning(f"Failed to load scheduler state: {e}")
logger.info(f"Resumed training from epoch {start_epoch + 1}")
for epoch in range(start_epoch, self.num_epochs):
logger.info(f"Epoch {epoch + 1}/{self.num_epochs}")
epoch_train_loss = 0
train_steps = 0
progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")
for batch_idx, batch in enumerate(progress_bar):
try:
loss, outputs = self.training_step(batch, global_step)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
epoch_train_loss += loss.item()
train_steps += 1
global_step += 1
progress_bar.set_postfix({
'loss': f'{loss.item():.4f}',
'lr': f'{scheduler.get_last_lr()[0]:.2e}'
})
if global_step % self.logging_steps == 0:
avg_loss = epoch_train_loss / train_steps
logger.info(f"Step {global_step}: train_loss={avg_loss:.4f}, lr={scheduler.get_last_lr()[0]:.2e}")
if global_step % self.eval_steps == 0:
val_loss = self.validate(val_loader)
logger.info(f"Step {global_step}: val_loss={val_loss:.4f}")
if val_loss < best_val_loss:
best_val_loss = val_loss
best_checkpoint = self.save_epoch_checkpoint(epoch, global_step, is_best=True)
logger.info(f"New best validation loss: {val_loss:.4f}")
except Exception as e:
logger.error(f"Error in training step {global_step}: {e}")
continue
val_loss = self.validate(val_loader)
avg_train_loss = epoch_train_loss / train_steps if train_steps > 0 else 0
logger.info(f"Epoch {epoch + 1} completed:")
logger.info(f" Average train loss: {avg_train_loss:.4f}")
logger.info(f" Validation loss: {val_loss:.4f}")
is_best = val_loss < best_val_loss
if is_best:
best_val_loss = val_loss
checkpoint_dir = self.save_epoch_checkpoint(epoch, global_step, is_best=is_best)
self.training_state.save_state(
epoch=epoch,
global_step=global_step,
best_val_loss=best_val_loss,
optimizer_state=optimizer.state_dict(),
scheduler_state=scheduler.state_dict()
)
logger.info(f"Epoch {epoch + 1} checkpoint and state saved")
logger.info(f"Training completed! Best validation loss: {best_val_loss:.4f}")
final_checkpoint = self.save_epoch_checkpoint(self.num_epochs - 1, global_step, is_best=False)
logger.info(f"Final checkpoint saved: {final_checkpoint}")
def validate(self, val_loader):
self.model.eval()
total_loss = 0
num_batches = 0
with torch.no_grad():
for batch in tqdm(val_loader, desc="Validation"):
try:
loss = self.validation_step(batch)
total_loss += loss
num_batches += 1
except Exception as e:
logger.warning(f"Error in validation step: {e}")
continue
avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
self.model.train() # Set back to training mode
return avg_loss
def run(self):
try:
self.load_model_and_tokenizer()
self.load_dataset(
chunk_size=self.chunk_size,
max_samples=self.max_samples
)
self.train()
logger.info("Training pipeline completed successfully!")
self.cleanup_datasets()
except Exception as e:
logger.error(f"Training pipeline failed: {e}")
import traceback
logger.error(traceback.format_exc())
            try:
                self.cleanup_datasets()
            except Exception:
                pass
raise
def load_trained_model(checkpoint_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
logger.info(f"Loading trained model from {checkpoint_path}")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
config = Qwen3Config.from_pretrained(checkpoint_path)
model = CustomQwen3ForCausalLM(config)
    model_path = Path(checkpoint_path) / "pytorch_model.bin"
    if model_path.exists():
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        model_path = Path(checkpoint_path) / "model.safetensors"
        if not model_path.exists():
            raise FileNotFoundError(f"No model weights found in {checkpoint_path}")
        # torch.load cannot read safetensors files; use the safetensors loader
        from safetensors.torch import load_file
        state_dict = load_file(str(model_path))
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
logger.warning(f"Missing keys when loading model: {missing_keys}")
if unexpected_keys:
logger.warning(f"Unexpected keys when loading model: {unexpected_keys}")
model = model.eval().to(torch.float32)
return model, tokenizer
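# Two-phase generation: one prefill pass over the grouped embeddings builds the
# KV cache, then tokens are decoded one at a time (greedy or sampled). The EOS
# check uses .item(), so this path assumes a single sequence (batch size 1).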
def generate_with_grouped_input(
model,
tokenizer,
grouped_input: torch.Tensor,
max_length: int = 512,
temperature: float = 0.7,
do_sample: bool = True
):
device = model.device
model_dtype = next(model.parameters()).dtype
grouped_input = grouped_input.to(device=device, dtype=model_dtype)
if grouped_input.ndim == 2:
grouped_input = grouped_input.unsqueeze(0) # Add batch dimension
logger.debug(f"Grouped input shape: {grouped_input.shape}, dtype: {grouped_input.dtype}")
logger.debug(f"Model dtype: {model_dtype}, device: {device}")
with torch.no_grad():
try:
outputs = model(
grouped_inputs=grouped_input,
is_prefill=True,
use_cache=True,
return_dict=True
)
except Exception as e:
logger.error(f"Error in prefill phase: {e}")
raise
if hasattr(outputs, 'logits') and outputs.logits is not None:
next_token_logits = outputs.logits[:, -1, :]
elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
last_hidden_state = outputs.hidden_states[-1] if isinstance(outputs.hidden_states, (list, tuple)) else outputs.hidden_states
next_token_logits = model.lm_head(last_hidden_state[:, -1, :])
else:
raise RuntimeError("Could not extract logits from model output")
if do_sample:
next_token_logits = next_token_logits / temperature
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated_ids = next_token
past_key_values = getattr(outputs, 'past_key_values', None)
for step in range(max_length - 1):
with torch.no_grad():
try:
outputs = model(
input_ids=next_token,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
except Exception as e:
logger.error(f"Error in generation step {step}: {e}")
break
if hasattr(outputs, 'logits'):
next_token_logits = outputs.logits[:, -1, :]
else:
logger.warning("No logits in generation output, stopping generation")
break
if do_sample:
next_token_logits = next_token_logits / temperature
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated_ids = torch.cat([generated_ids, next_token], dim=1)
past_key_values = getattr(outputs, 'past_key_values', None)
if next_token.item() == tokenizer.eos_token_id:
break
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
return generated_text
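# Example usage (a sketch; the paths are illustrative and `embeds` is assumed
# to be a [seq_len, hidden_size] tensor taken from the processed dataset):
#
#   model, tokenizer = load_trained_model("./grouped_qwen3_checkpoint/epoch_2_best")
#   with open("./processed_qwen3_dataset/processed_dataset.pkl", "rb") as f:
#       embeds = pickle.load(f)[0]["inputs_embeds"]
#   print(generate_with_grouped_input(model, tokenizer, embeds, max_length=256))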
def main():
config = {
"model_name": "Qwen/Qwen3-0.6B",
"dataset_path": "./processed_qwen3_dataset/processed_dataset.pkl",
"output_dir": "./grouped_qwen3_checkpoint",
"batch_size": 12, # Optimized for L4 24GB VRAM
"learning_rate": 5e-4,
"num_epochs": 3,
"warmup_steps": 500, # Increased for larger batch
"max_grad_norm": 1.0,
"save_steps": 1000, # Less frequent saves due to larger batches
"eval_steps": 1000, # Less frequent evaluation
"logging_steps": 100,
"resume_training": True,
"debug": False, # Disable debug for performance
# Streaming parameters
"chunk_size": 2000, # Load 2000 samples per chunk
"max_samples": None, # Use full dataset (set to smaller number for testing)
}
logger.info("="*60)
logger.info("GROUPED QWEN3 TRAINING CONFIGURATION (STREAMING)")
logger.info("="*60)
for key, value in config.items():
logger.info(f"{key}: {value}")
logger.info("="*60)
if torch.cuda.is_available():
logger.info(f"GPU: {torch.cuda.get_device_name()}")
logger.info(f"VRAM Total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
import psutil
ram_usage = psutil.virtual_memory()
logger.info(f"System RAM: {ram_usage.used / 1024**3:.1f} GB / {ram_usage.total / 1024**3:.1f} GB ({ram_usage.percent:.1f}%)")
trainer = GroupedTrainer(**config)
trainer.run()
def inference_by_id(sample_id: int, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
"""Run inference on a specific sample ID from the dataset."""
logger.info(f"Running inference on sample ID: {sample_id}")
try:
model, tokenizer = load_trained_model(checkpoint_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
logger.info(f"Model loaded from {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
return None
try:
logger.info(f"Loading sample {sample_id} from dataset...")
with open(dataset_path, 'rb') as f:
processed_data = pickle.load(f)
if sample_id >= len(processed_data):
logger.error(f"Sample ID {sample_id} is out of range. Dataset has {len(processed_data)} samples.")
return None
sample = processed_data[sample_id]
if sample.get("error", False):
logger.error(f"Sample {sample_id} has errors and cannot be used for inference.")
return None
except Exception as e:
logger.error(f"Failed to load dataset: {e}")
return None
grouped_embeds_raw = sample["inputs_embeds"]
if isinstance(grouped_embeds_raw, torch.Tensor):
grouped_input = grouped_embeds_raw.detach().clone().float()
else:
grouped_input = torch.tensor(grouped_embeds_raw, dtype=torch.float32)
original_input = sample["input_text"]
expected_response = sample["response"]
print("\n" + "="*80)
print(f"INFERENCE ON SAMPLE ID: {sample_id}")
print("="*80)
print(f"πŸ“ ORIGINAL REQUEST:")
print(f"{original_input}")
print("\n" + "-"*80)
print(f"🎯 EXPECTED RESPONSE:")
print(f"{expected_response}")
print("\n" + "-"*80)
print(f"πŸ€– MODEL GENERATED RESPONSE:")
try:
generated_text = generate_with_grouped_input(
model=model,
tokenizer=tokenizer,
grouped_input=grouped_input,
max_length=max_length,
temperature=temperature,
do_sample=do_sample
)
print(f"{generated_text}")
print("\n" + "="*80)
expected_words = expected_response.split()
generated_words = generated_text.split()
print(f"πŸ“Š METRICS:")
print(f"Expected length: {len(expected_words)} words")
print(f"Generated length: {len(generated_words)} words")
print(f"Temperature: {temperature}")
print(f"Max length: {max_length}")
print("="*80)
return {
"sample_id": sample_id,
"original_input": original_input,
"expected_response": expected_response,
"generated_response": generated_text,
"expected_length": len(expected_words),
"generated_length": len(generated_words)
}
except Exception as e:
logger.error(f"Failed to generate response: {e}")
print(f"❌ GENERATION FAILED: {e}")
print("="*80)
return None
def batch_inference(sample_ids: List[int], checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
"""Run inference on multiple sample IDs."""
logger.info(f"Running batch inference on {len(sample_ids)} samples")
results = []
for sample_id in sample_ids:
result = inference_by_id(
sample_id=sample_id,
checkpoint_path=checkpoint_path,
dataset_path=dataset_path,
max_length=max_length,
temperature=temperature,
do_sample=do_sample
)
if result:
results.append(result)
print("\n" + "πŸ”„ " + "-"*78 + " πŸ”„\n") # Separator between samples
print("\n" + "="*80)
print(f"πŸ“‹ BATCH INFERENCE SUMMARY")
print("="*80)
print(f"Total samples processed: {len(results)}")
if results:
avg_expected_len = sum(r["expected_length"] for r in results) / len(results)
avg_generated_len = sum(r["generated_length"] for r in results) / len(results)
print(f"Average expected length: {avg_expected_len:.1f} words")
print(f"Average generated length: {avg_generated_len:.1f} words")
print("="*80)
return results
def random_inference(num_samples: int = 3, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
"""Run inference on random samples from the dataset."""
import random
try:
with open(dataset_path, 'rb') as f:
processed_data = pickle.load(f)
# Find valid samples
valid_indices = [i for i, item in enumerate(processed_data) if not item.get("error", False)]
if len(valid_indices) < num_samples:
logger.warning(f"Only {len(valid_indices)} valid samples available, using all of them")
num_samples = len(valid_indices)
# Select random samples
random_ids = random.sample(valid_indices, num_samples)
logger.info(f"Selected random sample IDs: {random_ids}")
# Run batch inference
return batch_inference(
sample_ids=random_ids,
checkpoint_path=checkpoint_path,
dataset_path=dataset_path,
max_length=max_length,
temperature=temperature,
do_sample=do_sample
)
except Exception as e:
logger.error(f"Failed to load dataset for random sampling: {e}")
return None
def interactive_inference(checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl"):
"""Interactive inference mode where user can input sample IDs."""
print("\n" + "="*80)
print("πŸ€– INTERACTIVE INFERENCE MODE")
print("="*80)
print("Commands:")
print(" <number> - Run inference on sample ID")
print(" random <n> - Run inference on n random samples (default: 3)")
print(" batch <ids> - Run inference on multiple IDs (e.g., 'batch 1,5,10')")
print(" quit - Exit")
print("="*80)
while True:
try:
user_input = input("\nπŸ” Enter command: ").strip().lower()
if user_input in ['quit', 'exit', 'q']:
print("πŸ‘‹ Goodbye!")
break
elif user_input.startswith('random'):
parts = user_input.split()
num_samples = int(parts[1]) if len(parts) > 1 else 3
random_inference(num_samples=num_samples, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
elif user_input.startswith('batch'):
parts = user_input.split(maxsplit=1)
if len(parts) > 1:
ids_str = parts[1]
sample_ids = [int(x.strip()) for x in ids_str.split(',')]
batch_inference(sample_ids=sample_ids, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
else:
print("❌ Please provide sample IDs: batch 1,5,10")
elif user_input.isdigit():
sample_id = int(user_input)
inference_by_id(sample_id=sample_id, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
else:
print("❌ Invalid command. Try a number, 'random', 'batch', or 'quit'")
except ValueError:
print("❌ Invalid input. Please enter a valid number or command.")
except KeyboardInterrupt:
print("\nπŸ‘‹ Goodbye!")
break
except Exception as e:
print(f"❌ Error: {e}")
def test_inference():
logger.info("Running inference tests...")
test_ids = [0, 1, 2, 100, 500] # Mix of early and later samples
print("\nπŸ§ͺ TESTING INFERENCE ON PREDEFINED SAMPLES")
results = batch_inference(
sample_ids=test_ids,
max_length=300,
temperature=0.7,
do_sample=True
)
return results
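# CLI entry points, for example:
#   python train_custom_qwen3.py --mode train
#   python train_custom_qwen3.py --mode inference --sample_id 42
#   python train_custom_qwen3.py --mode random --num_samples 5
#   python train_custom_qwen3.py --mode interactive --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best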
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Grouped Qwen3 Training and Inference")
parser.add_argument("--mode", choices=["train", "test", "inference", "interactive", "random"],
default="train", help="Mode to run")
parser.add_argument("--sample_id", type=int, help="Sample ID for inference mode")
parser.add_argument("--sample_ids", type=str, help="Comma-separated sample IDs for batch inference")
parser.add_argument("--num_samples", type=int, default=3, help="Number of random samples for random mode")
parser.add_argument("--checkpoint", type=str, default="./grouped_qwen3_checkpoint/epoch_2_best",
help="Path to model checkpoint")
parser.add_argument("--dataset", type=str, default="./processed_qwen3_dataset/processed_dataset.pkl",
help="Path to dataset")
parser.add_argument("--max_length", type=int, default=512, help="Maximum generation length")
parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature")
args = parser.parse_args()
if args.mode == "train":
main()
elif args.mode == "test":
test_inference()
elif args.mode == "inference":
if args.sample_id is not None:
inference_by_id(
sample_id=args.sample_id,
checkpoint_path=args.checkpoint,
dataset_path=args.dataset,
max_length=args.max_length,
temperature=args.temperature
)
elif args.sample_ids is not None:
sample_ids = [int(x.strip()) for x in args.sample_ids.split(',')]
batch_inference(
sample_ids=sample_ids,
checkpoint_path=args.checkpoint,
dataset_path=args.dataset,
max_length=args.max_length,
temperature=args.temperature
)
else:
print("❌ Please provide --sample_id or --sample_ids for inference mode")
elif args.mode == "interactive":
interactive_inference(
checkpoint_path=args.checkpoint,
dataset_path=args.dataset
)
elif args.mode == "random":
random_inference(
num_samples=args.num_samples,
checkpoint_path=args.checkpoint,
dataset_path=args.dataset,
max_length=args.max_length,
temperature=args.temperature
)