| | import argparse |
| | import math |
| | import os |
| | import sys |
| | import json |
| | import jsonlines |
| | import copy |
| | from typing import List, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader, Dataset, random_split |
| | from torch.cuda.amp import autocast, GradScaler |
| | from torch.utils.tensorboard import SummaryWriter |
| |
|
| | from datasets import load_dataset |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
| | from tqdm import tqdm |
| |
|
| | |
# Compute device for every model and batch in this script: the default CUDA
# device when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| | |
| | |
| | |
| | from lightbulb_custom import ( |
| | RotaryPositionalEncoding, |
| | MultiHeadAttention, |
| | MoE, |
| | TransformerBlock, |
| | Transformer, |
| | InfoNCE_Loss, |
| | CovarianceRegularization, |
| | DynamicsPerformanceLoss, |
| | ThoughtConsistencyLoss, |
| | PolicyValueJointLoss, |
| | ActionDiversityReward, |
| | ExpectedThoughtValueLoss, |
| | ExplorationRegularization, |
| | KL_DivergenceLoss, |
| | ActionEncoder, |
| | RepresentationNetwork, |
| | DynamicsNetwork, |
| | PredictionNetwork, |
| | ThoughtNode, |
| | MCTS, |
| | State |
| | ) |
| |
|
| | |
| | |
| | |
class CustomDataset(Dataset):
    """Map-style dataset pairing ``inputs[i]`` with ``labels[i]``.

    Both sequences are indexed with the same key; the dataset length is taken
    from ``inputs``.
    """

    def __init__(self, inputs, labels):
        self.inputs, self.labels = inputs, labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        sample = dict(input_ids=self.inputs[idx], labels=self.labels[idx])
        return sample
| |
|
| | |
| | |
| | |
def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
    """Load a Hugging Face dataset, optionally keeping only rows that mention a query.

    Args:
        dataset_name: Identifier passed to ``datasets.load_dataset``.
        config: Dataset configuration name (may be ``None``).
        queries: Case-insensitive substrings. A row is kept when its ``"text"``
            column contains at least one of them. ``None`` or an empty list keeps
            every row.

    Returns:
        The loaded dataset, filtered when ``queries`` is non-empty.
    """
    dataset = load_dataset(dataset_name, config)
    if not queries:
        return dataset

    # Lower-case the queries once, not once per (row, query) pair.
    lowered_queries = [query.lower() for query in queries]

    def filter_func(examples):
        keep = []
        for text in examples["text"]:
            # A missing/None text cell would crash ``.lower()``; treat it as empty.
            haystack = text.lower() if text else ""
            keep.append(any(query in haystack for query in lowered_queries))
        return keep

    return dataset.filter(filter_func, batched=True)
| |
|
def load_custom_data_from_files(file_paths):
    """Read records from JSON / JSON Lines files into one flat list.

    ``.json`` files may hold a list of records (each one is appended) or a single
    record. ``.jsonl`` files contribute one record per line. Extensions are matched
    case-insensitively. Files with any other extension are skipped with a message
    instead of being ignored silently.

    Args:
        file_paths: Iterable of file paths (``str`` or path-like). ``None`` is
            treated as an empty list.

    Returns:
        list: All records, in file order.
    """
    custom_data = []
    for file_path in file_paths or []:
        lowered = str(file_path).lower()
        if lowered.endswith('.json'):
            # Explicit UTF-8 so parsing does not depend on the platform locale.
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                custom_data.extend(data)
            else:
                custom_data.append(data)
        elif lowered.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                custom_data.extend(reader)
        else:
            print(f"Skipping unsupported file (expected .json or .jsonl): {file_path}")
    return custom_data
| |
|
def preprocess_custom_data(data_list):
    """Normalise raw RAG result records into flat training examples.

    Each item may be a dict or a JSON string that encodes one. Strings that are
    not valid JSON, and items that are not JSON objects once parsed, are reported
    and skipped. The placeholder content ``"RAG response generation failed."`` is
    blanked, and ``query``/``content`` are merged into a single ``text`` field.

    Args:
        data_list: Iterable of dicts and/or JSON strings. ``None`` is treated as
            empty.

    Returns:
        list[dict]: One dict per usable item, with keys ``text``,
        ``episode_reward``, ``loss``, ``cosine_similarity``, ``rag_performance``
        and ``ranking_model_performance``. Missing or ``null`` values fall back to
        the defaults (``''`` for text fields, ``0`` for metrics).
    """
    def field(record, key, default):
        # JSON ``null`` arrives as None. Treat it like a missing key so metrics
        # stay numeric for the ``torch.tensor`` conversion done later.
        value = record.get(key, default)
        return default if value is None else value

    processed_data = []
    for item in data_list or []:
        if isinstance(item, str):
            try:
                item = json.loads(item)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON: {item[:100]}...")
                continue

        # Valid JSON can still be a list/number/string; only objects have fields.
        if not isinstance(item, dict):
            print(f"Skipping non-object record: {str(item)[:100]}...")
            continue

        query = field(item, 'query', '')
        content = field(item, 'content', '')
        if content == "RAG response generation failed.":
            content = ""

        processed_data.append({
            'text': f"Query: {query} Content: {content}",
            'episode_reward': field(item, 'episode_reward', 0),
            'loss': field(item, 'loss', 0),
            'cosine_similarity': field(item, 'cosine_similarity', 0),
            'rag_performance': field(item, 'rag_performance', 0),
            'ranking_model_performance': field(item, 'ranking_model_performance', 0),
        })

    return processed_data
| |
|
class CustomDatasetProcessed(Dataset):
    """Lazily tokenise preprocessed RAG records, one item at a time.

    Defined at module level rather than inside ``load_custom_data`` so that
    ``DataLoader`` worker processes can pickle it (required under the ``spawn``
    multiprocessing start method).
    """

    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoded = self.tokenizer(
            item['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # squeeze(0) drops only the batch axis added by return_tensors='pt'. A
        # bare squeeze() would also collapse a length-1 sequence axis.
        input_ids = encoded['input_ids'].squeeze(0)
        return {
            'input_ids': input_ids,
            'attention_mask': encoded['attention_mask'].squeeze(0),
            # train_step/validate read batch["labels"]. As in prepare_data, the
            # labels are the input token ids themselves.
            'labels': input_ids.clone(),
            'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
            'loss': torch.tensor(item['loss'], dtype=torch.float),
            'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
            'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
            'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float),
        }


def load_custom_data(args, tokenizer, custom_data):
    """Preprocess custom RAG records and wrap them in 80/20 train/eval loaders.

    Args:
        args: Namespace providing ``max_length`` and ``batch_size``. An optional
            ``num_workers`` is also read (default 4).
        tokenizer: Hugging Face tokenizer used to encode each record's text.
        custom_data: Raw records as accepted by ``preprocess_custom_data``.

    Returns:
        tuple[DataLoader, DataLoader]: Shuffled training loader and an
        unshuffled evaluation loader.

    Raises:
        ValueError: If no usable records remain after preprocessing.
    """
    processed_data = preprocess_custom_data(custom_data)
    if not processed_data:
        # An empty dataset would otherwise fail deep inside RandomSampler.
        raise ValueError("No usable records found in custom data; cannot build data loaders.")

    dataset = CustomDatasetProcessed(processed_data, tokenizer, args.max_length)

    # Keep at least one training example so the shuffling sampler is non-empty.
    train_size = max(1, int(0.8 * len(dataset)))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

    num_workers = getattr(args, 'num_workers', 4)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, eval_loader
| |
|
def prepare_data(tokenizer, dataset, max_length, batch_size):
    """Tokenise the ``"train"`` split's ``"text"`` column and build 90/10 loaders.

    Args:
        tokenizer: Hugging Face tokenizer. It needs a pad token because
            ``padding=True`` is used.
        dataset: Dataset dict whose ``"train"`` split has a ``"text"`` column.
        max_length: Truncation length; sequences are padded to the longest one.
        batch_size: Batch size for both loaders.

    Returns:
        tuple[DataLoader, DataLoader]: Shuffled training loader and validation
        loader, both yielding ``{"input_ids", "labels"}`` batches where the labels
        equal the input ids.

    Raises:
        ValueError: If the training split contains no text.
    """
    texts = dataset["train"]["text"]
    if len(texts) == 0:
        raise ValueError("The 'train' split contains no text to distil on.")

    # Inputs and labels are the same token ids. Tokenise once and clone rather
    # than running the tokenizer over the whole corpus twice.
    tokenized = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    input_ids = tokenized["input_ids"]
    custom_dataset = CustomDataset(input_ids, input_ids.clone())

    # Keep at least one training example so the shuffling sampler is non-empty.
    train_size = max(1, int(0.9 * len(custom_dataset)))
    val_size = len(custom_dataset) - train_size
    train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])

    # Pinned host memory only speeds up CPU->GPU copies; skip it on CPU-only hosts.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=pin_memory,
    )
    return train_loader, val_loader
| |
|
| | |
| | |
| | |
| |
|
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
    """
    Write the state dict of every world-model component for one epoch.

    Each module is stored as ``<save_dir>/<component>_epoch_<epoch>.pt``. The
    directory is created first if it does not exist.

    Args:
        transformer_model (nn.Module): Transformer model.
        representation_network (nn.Module): Representation network.
        dynamics_network (nn.Module): Dynamics network.
        prediction_network (nn.Module): Prediction network.
        action_encoder (nn.Module): Action encoder.
        save_dir (str): Target directory.
        epoch (int): Epoch number embedded in every file name.
    """
    os.makedirs(save_dir, exist_ok=True)

    components = (
        ('transformer_model', transformer_model),
        ('representation_network', representation_network),
        ('dynamics_network', dynamics_network),
        ('prediction_network', prediction_network),
        ('action_encoder', action_encoder),
    )
    for stem, module in components:
        target_path = os.path.join(save_dir, f'{stem}_epoch_{epoch}.pt')
        torch.save(module.state_dict(), target_path)

    print(f"All models saved for epoch {epoch}.")
| |
|
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    """Train the world-model components for one epoch on token batches.

    A frozen transformer encodes each batch. The representation, dynamics,
    prediction and action-encoder networks (and the PPO policy network) are then
    trained on the sum of the auxiliary losses below, using mixed precision and
    gradient accumulation.

    Args:
        world_model_components: Tuple ``(representation_network, dynamics_network,
            prediction_network, action_encoder, ppo_agent, model_transformer)``.
            NOTE: the ``model_transformer`` unpacked from this tuple replaces the
            ``model_transformer`` argument, so that argument is ignored.
        train_loader: Sized iterable of dicts with ``"input_ids"`` and ``"labels"``
            tensors.
        optimizer: Optimizer over the trainable components.
        scheduler: LR scheduler, stepped once per optimizer step.
        scaler: ``GradScaler`` for the mixed-precision backward/step.
        args: Namespace providing ``accumulation_steps`` and ``max_grad_norm``.
        model_transformer: Ignored (see above); kept for interface compatibility.
        state_dim: Last dimension used to flatten state representations for InfoNCE.
        embed_dim: Last dimension used to flatten action embeddings.
        input_dim: Number of classes for the one-hot target policy.

    Returns:
        float: Mean per-batch loss (before division by ``accumulation_steps``), or
        0.0 if ``train_loader`` is empty.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {len(train_loader)} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{len(train_loader)}...")

        src_batch = batch['input_ids'].to(device)
        tgt_batch = batch['labels'].to(device)

        with torch.cuda.amp.autocast():
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            state_representation = representation_network(transformer_output)

            # Teacher forcing: the ground-truth next tokens act as the actions.
            true_actions = tgt_batch[:, :-1]
            print(f"True actions shape: {true_actions.shape}")
            action_sequences = true_actions

            action_embeddings = action_encoder(action_sequences)
            print(f"Action embeddings shape: {action_embeddings.shape}")

            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
            print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")

            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = torch.zeros_like(value_estimates).to(device)

            ppo_loss = ppo_agent.compute_loss(
                state_representation,
                torch.zeros_like(true_actions, dtype=torch.float32).to(device),
                true_actions,
                torch.zeros_like(value_estimates, dtype=torch.float32).to(device),
                torch.zeros_like(value_estimates, dtype=torch.float32).to(device)
            )

            # Contrastive pair: each state vs. a dropout-perturbed copy of itself.
            info_nce = InfoNCE_Loss()(state_representation.reshape(-1, state_dim),
                                      F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True))

            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            # Placeholder MCTS statistics (zeros / uniform visit counts).
            mcts_best_values = torch.zeros(true_actions.size(0)).to(device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1)).to(device)
            exploration = ExplorationRegularization()(visit_counts)

            # NOTE(review): old_policy is a detached copy of new_policy, so this
            # KL term compares a distribution with itself.
            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            loss = (
                ppo_loss +
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss
            )
            # Scale down so accumulated gradients average over the accumulation window.
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
            print("Gradient clipping...")
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm
            )

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        # Undo the accumulation scaling so totals and prints use the true batch loss.
        batch_loss = loss.item() * args.accumulation_steps
        total_loss += batch_loss

        print(f"Batch {i+1} completed. Losses:")
        print(f"  PPO Loss: {ppo_loss.item():.4f}")
        print(f"  InfoNCE Loss: {info_nce.item():.4f}")
        print(f"  Covariance Loss: {covariance.item():.4f}")
        print(f"  Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f"  Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f"  Policy-Value Loss: {pv_loss.item():.4f}")
        print(f"  Action Diversity Loss: {action_diversity.item():.4f}")
        print(f"  Expected Thought Value Loss: {etv.item():.4f}")
        print(f"  Exploration Loss: {exploration.item():.4f}")
        print(f"  KL Divergence Loss: {kl_loss.item():.4f}")
        print(f"  Total Loss: {batch_loss:.4f}")

    num_batches = len(train_loader)
    # An empty loader would otherwise raise ZeroDivisionError here.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
| |
|
def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
    """Run one epoch of temperature-scaled logit distillation.

    The teacher runs without gradients. The student is trained, under mixed
    precision, to match the teacher's softened per-token distribution.

    Args:
        teacher: Frozen model whose output exposes ``.logits``.
        student: Model being trained; its output exposes ``.logits``.
        data_loader: Iterable of batches containing an ``"input_ids"`` tensor.
        optimizer: Optimizer over the student's parameters.
        criterion: Loss called as ``criterion(student_log_probs, teacher_probs)``
            (``nn.KLDivLoss(reduction="batchmean")`` elsewhere in this file).
        scaler: ``GradScaler`` used for the mixed-precision backward/step.
        temperature: Softmax temperature; the loss is rescaled by ``T ** 2``.

    Returns:
        float: Mean per-batch loss, or 0.0 if ``data_loader`` yields no batches.
    """
    teacher.eval()
    student.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(data_loader, desc="Training"):
        inputs = batch["input_ids"].to(device)

        with autocast():
            with torch.no_grad():
                teacher_logits = teacher(inputs).logits / temperature

            student_logits = student(inputs).logits / temperature

            # Flatten (batch, seq, vocab) -> (batch * seq, vocab). KLDivLoss's
            # "batchmean" divides by input.size(0) only, so on 3-D logits the loss
            # would be summed over positions and grow with sequence length.
            loss = criterion(
                F.log_softmax(student_logits.reshape(-1, student_logits.size(-1)), dim=-1),
                F.softmax(teacher_logits.reshape(-1, teacher_logits.size(-1)), dim=-1),
            )
            # The T^2 factor keeps gradient magnitudes comparable across temperatures.
            loss = loss * (temperature ** 2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0
| |
|
def validate(teacher, student, data_loader, criterion, temperature=2.0):
    """Compute the mean distillation loss of ``student`` against ``teacher``.

    Both models run in eval mode under ``torch.no_grad``. The loss matches
    ``train_step``: per-token KL between the temperature-softened distributions,
    rescaled by ``T ** 2``.

    Args:
        teacher: Frozen model whose output exposes ``.logits``.
        student: Model being evaluated; its output exposes ``.logits``.
        data_loader: Iterable of batches containing an ``"input_ids"`` tensor.
        criterion: Loss called as ``criterion(student_log_probs, teacher_probs)``.
        temperature: Softmax temperature.

    Returns:
        float: Mean per-batch loss, or 0.0 if ``data_loader`` yields no batches.
    """
    teacher.eval()
    student.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation"):
            inputs = batch["input_ids"].to(device)

            teacher_logits = teacher(inputs).logits / temperature
            student_logits = student(inputs).logits / temperature

            # Flatten to (batch * seq, vocab) so "batchmean" averages per token
            # instead of summing over sequence positions.
            loss = criterion(
                F.log_softmax(student_logits.reshape(-1, student_logits.size(-1)), dim=-1),
                F.softmax(teacher_logits.reshape(-1, teacher_logits.size(-1)), dim=-1),
            )
            loss = loss * (temperature ** 2)

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else 0.0
| |
|
def save_checkpoint(state, save_dir, epoch):
    """Serialise ``state`` to ``<save_dir>/checkpoint_epoch_<epoch>.pt``.

    The directory is created if needed, and the written path is printed.
    """
    os.makedirs(save_dir, exist_ok=True)
    destination = os.path.join(save_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
    torch.save(state, destination)
    print(f"Checkpoint saved at {destination}")
| |
|
| | |
| | |
| | |
| |
|
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414):
    """
    Perform inference for a query using the Tree of Thought and MCTS with multi-token beam search.

    Args:
        query (str): The input query or prompt.
        world_model_components (tuple): ``(representation_network, dynamics_network,
            prediction_network, action_encoder, ppo_agent, model_transformer)``.
            ``ppo_agent`` is not used here.
        root_thought_node (ThoughtNode): Root of the Tree of Thought. The search
            stops once it reaches a node with no children.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used to encode ``query``.
        max_length (int): Search budget; at most ``max_length // n_tokens_predict``
            MCTS searches are run.
        inference_mode (str): Only ``'world_model'`` is implemented. Any other value
            prints a notice and returns ``""``.
        beam_size (int): Currently unused (not forwarded to ``MCTS``).
        n_tokens_predict (int): Number of tokens expected per search step. Must
            be positive.
        mcts_iterations (int): Number of MCTS iterations per search.
        exploration_constant (float): Exploration constant for MCTS.

    Returns:
        list or str: The actions chosen by MCTS, in order, or ``""`` for an
        unimplemented inference mode.

    Raises:
        ValueError: If ``n_tokens_predict`` is not positive.

    Note:
        The networks are switched to eval mode (dropout off) as a side effect.
    """
    if inference_mode != 'world_model':
        print("Inference mode other than 'world_model' not implemented yet.")
        return ""
    if n_tokens_predict <= 0:
        raise ValueError(f"n_tokens_predict must be positive, got {n_tokens_predict}")

    (representation_network, dynamics_network, prediction_network,
     action_encoder, _ppo_agent, model_transformer) = world_model_components

    # Inference must not run with dropout / batch-norm in training mode.
    for module in (model_transformer, representation_network, dynamics_network,
                   prediction_network, action_encoder):
        module.eval()

    input_ids = tokenizer.encode(query, return_tensors='pt').to(device)

    # Pure inference: no autograd graph is needed for the encoder or the search.
    with torch.no_grad():
        transformer_output = model_transformer(input_ids, input_ids)
        # Use only the last position's representation as the root state.
        initial_representation = representation_network(transformer_output)[:, -1, :].unsqueeze(1)
        current_state = State(
            representation=initial_representation,
            dynamics_network=dynamics_network,
            action_encoder=action_encoder,
            thought_node=root_thought_node
        )
        mcts = MCTS(prediction_network, dynamics_network, action_encoder,
                    num_iterations=mcts_iterations, exploration_constant=exploration_constant)

        thought_sequence = []
        for _ in range(max_length // n_tokens_predict):
            best_actions = mcts.search_with_beam(current_state)
            # No proposed actions: the state cannot advance, so stop searching.
            if len(best_actions) == 0:
                break

            thought_sequence.extend(best_actions)
            for action in best_actions:
                current_state = current_state.apply_action(action)

            # Reached a leaf of the Tree of Thought.
            if len(current_state.thought_node.children) == 0:
                break

    return thought_sequence
| |
|
| | |
| | |
| | |
| |
|
def _load_or_init_student(student_model_name, tokenizer, save_path):
    """Load the student model, falling back to a fresh ``distilgpt2`` copy.

    The fallback model and the tokenizer are saved to ``save_path``. The process
    exits with status 1 if even the fallback cannot be created.
    """
    print(f"Attempting to load student model '{student_model_name}'...")
    try:
        student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
        print(f"Student model '{student_model_name}' loaded successfully.")
        return student
    except (OSError, ValueError) as e:
        print(f"Student model '{student_model_name}' not found. Instantiating a new student model.")
        print(f"  Reason: {e}")

    # NOTE(review): the fallback is always distilgpt2 paired with the teacher's
    # tokenizer, which only lines up when the teacher uses a GPT-2 vocabulary.
    try:
        student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
        student.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"New student model '{student_model_name}' instantiated and saved to '{save_path}'.")
    except Exception as inst_e:
        print(f"Failed to instantiate and save student model: {inst_e}")
        sys.exit(1)
    return student


def _run_distillation_loop(teacher, student, tokenizer, train_loader, val_loader, writer, *,
                           num_epochs, learning_rate, weight_decay, temperature,
                           save_path, checkpoint_dir, early_stopping_patience):
    """Distil ``teacher`` into ``student`` epoch by epoch and keep the best epoch.

    Train/validation losses are logged to ``writer``. Whenever validation
    improves, a full training checkpoint is written to ``checkpoint_dir`` and the
    student and tokenizer are saved to ``save_path``. Training stops after
    ``early_stopping_patience`` consecutive epochs without improvement.
    """
    optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Stepped once per epoch, so T_max=num_epochs spans the whole run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler()
    criterion = nn.KLDivLoss(reduction="batchmean")

    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 20)

        train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
        print(f"Training Loss: {train_loss:.4f}")
        writer.add_scalar("Loss/Train", train_loss, epoch)

        val_loss = validate(teacher, student, val_loader, criterion, temperature)
        print(f"Validation Loss: {val_loss:.4f}")
        writer.add_scalar("Loss/Validation", val_loss, epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': student.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss
            }, checkpoint_dir, epoch)
            student.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Best model saved at epoch {epoch}")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered")
                break

        scheduler.step()


def distill_model(
    teacher_model_name: str,
    student_model_name: str,
    dataset_name: str,
    config: str,
    distill_full_model: bool = True,
    query_terms: Optional[List[str]] = None,
    num_epochs: int = 3,
    batch_size: int = 4,
    max_length: int = 128,
    learning_rate: float = 5e-5,
    temperature: float = 2.0,
    save_path: str = "./distilled_model",
    log_dir: str = "./logs",
    checkpoint_dir: str = "./checkpoints",
    early_stopping_patience: int = 3,
    accumulation_steps: int = 1,
    max_grad_norm: float = 1.0,
    weight_decay: float = 0.01
):
    """Distil a teacher causal LM into a student causal LM.

    Two data sources are supported:

    * ``distill_full_model=True``: the ``"train"`` split of the Hugging Face
      dataset ``dataset_name`` / ``config``.
    * ``distill_full_model=False``: local JSON/JSONL files whose paths are given
      in ``query_terms``. ``dataset_name`` and ``config`` are unused in this mode.

    ``accumulation_steps`` and ``max_grad_norm`` are accepted for CLI
    compatibility but are not used by the current training loop.

    Exits the process with status 1 if ``query_terms`` is missing in standard
    mode, or if no student model can be created.
    """
    mode_name = "Full World Model" if distill_full_model else "Standard Language Model"

    # Check cheap preconditions before loading any (large) model.
    if not distill_full_model and not query_terms:
        print("Error: --query_terms must be provided for standard language model distillation.")
        sys.exit(1)

    writer = SummaryWriter(log_dir=log_dir)
    try:
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
        if tokenizer.pad_token is None:
            # Padding is required for batching; reuse EOS as the pad token.
            tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded successfully.")

        print("Loading teacher model...")
        teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
        print("Teacher model loaded successfully.")

        print(f"Starting {mode_name} Distillation into '{student_model_name}'.")
        student = _load_or_init_student(student_model_name, tokenizer, save_path)

        # The teacher is frozen; only the student receives gradients.
        for param in teacher.parameters():
            param.requires_grad = False

        if distill_full_model:
            print(f"Loading full dataset '{dataset_name}' with config '{config}'...")
            dataset = load_dataset(dataset_name, config)
            train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
            print("Data loaded and preprocessed successfully.")
        else:
            print(f"Loading custom data files: {query_terms}")
            custom_data = load_custom_data_from_files(query_terms)
            # load_custom_data reads both max_length and batch_size from args.
            train_loader, val_loader = load_custom_data(
                args=argparse.Namespace(max_length=max_length, batch_size=batch_size),
                tokenizer=tokenizer,
                custom_data=custom_data
            )
            print("Custom data loaded and preprocessed successfully.")

        _run_distillation_loop(
            teacher, student, tokenizer, train_loader, val_loader, writer,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            temperature=temperature,
            save_path=save_path,
            checkpoint_dir=checkpoint_dir,
            early_stopping_patience=early_stopping_patience,
        )
    finally:
        # Flush TensorBoard events even on sys.exit or an exception.
        writer.close()

    print(f"\n{mode_name} Distillation completed.")
| |
|
| | |
| | |
| | |
| |
|
def parse_args():
    """Build the distillation command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (model names, data source, training
        hyper-parameters and output locations).
    """
    parser = argparse.ArgumentParser(
        description="Distill a large LLM into a smaller one or a full language world model."
    )

    # Required identifiers: teacher, student and dataset.
    for flag, help_text in (
        ("--teacher_model_name", "Name of the teacher model"),
        ("--student_model_name", "Name of the student model"),
        ("--dataset_name", "Name of the dataset"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument("--config", type=str, default=None,
                        help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
    parser.add_argument("--distill_full_model", action="store_true",
                        help="Whether to distill into the full language world model")
    parser.add_argument("--query_terms", type=str, nargs="+",
                        help="Paths to custom data files for standard language model distillation")

    # Hyper-parameters and output locations, as (flag, type, default, help).
    tunables = (
        ("--num_epochs", int, 3, "Number of epochs"),
        ("--batch_size", int, 4, "Batch size"),
        ("--max_length", int, 128, "Maximum sequence length"),
        ("--learning_rate", float, 5e-5, "Learning rate"),
        ("--temperature", float, 2.0, "Distillation temperature"),
        ("--save_path", str, "./distilled_model", "Path to save the distilled model"),
        ("--log_dir", str, "./logs", "Directory for TensorBoard logs"),
        ("--checkpoint_dir", str, "./checkpoints", "Directory to save checkpoints"),
        ("--early_stopping_patience", int, 3, "Early stopping patience"),
        ("--accumulation_steps", int, 1, "Gradient accumulation steps"),
        ("--max_grad_norm", float, 1.0, "Maximum gradient norm for clipping"),
        ("--weight_decay", float, 0.01, "Weight decay for optimizer"),
    )
    for flag, value_type, default, help_text in tunables:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    return parser.parse_args()
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Command-line entry point: parse arguments, create output dirs, run distillation."""
    args = parse_args()
    print("Arguments parsed successfully.")

    for directory in (args.save_path, args.log_dir, args.checkpoint_dir):
        os.makedirs(directory, exist_ok=True)
    print(f"Save directory created: {args.save_path}")
    print(f"Log directory created: {args.log_dir}")
    print(f"Checkpoint directory created: {args.checkpoint_dir}")

    # One call serves both modes; distill_full_model selects the data source
    # inside distill_model.
    distill_model(
        teacher_model_name=args.teacher_model_name,
        student_model_name=args.student_model_name,
        dataset_name=args.dataset_name,
        config=args.config,
        distill_full_model=args.distill_full_model,
        query_terms=args.query_terms,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        save_path=args.save_path,
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        early_stopping_patience=args.early_stopping_patience,
        accumulation_steps=args.accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        weight_decay=args.weight_decay
    )
| |
|
| |
|
| |
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|
| |
|
| |
|