| | import argparse |
| | import math |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader |
| | import copy |
| | from torch.optim.lr_scheduler import CosineAnnealingLR |
| | from torch.cuda.amp import autocast, GradScaler |
| | from datasets import load_dataset |
| | from transformers import AutoTokenizer |
| | from typing import List, Tuple |
| | import sys |
# Global compute device: the CUDA device when PyTorch reports an available
# GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
| |
|
def parse_args():
    """Parse command-line options for training or inference.

    ``parse_known_args`` is used, so flags the parser does not recognise
    (for example ones injected by a notebook kernel) are ignored instead of
    causing an error.  Every ``--custom_data_paths`` entry is joined onto a
    base directory and made absolute; entries that are already absolute are
    unaffected by the join.  The base directory is the PyInstaller
    extraction directory when running frozen, the directory of this file
    when ``__file__`` is defined, and the current working directory
    otherwise.

    Returns:
        argparse.Namespace: The parsed options.
    """
    # Directory that relative custom-data paths are resolved against.
    if hasattr(sys, 'frozen') and hasattr(sys, '_MEIPASS'):
        root = sys._MEIPASS  # PyInstaller's temporary unpack directory
    elif '__file__' in globals():
        root = os.path.dirname(os.path.abspath(__file__))
    else:
        root = os.getcwd()  # interactive session (Jupyter, Colab, REPL)

    cli = argparse.ArgumentParser(
        description='Train or Inference with World Model and Tree of Thought.'
    )
    add = cli.add_argument

    add('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
    add('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
    add('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    add('--batch_size', type=int, default=4, help='Batch size')
    add('--num_epochs', type=int, default=3, help='Number of epochs')
    add('--max_length', type=int, default=128, help='Maximum sequence length')
    add('--mcts_iterations', type=int, default=3, help='Number of MCTS Iterations')
    add('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS')
    add('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
    add('--learning_rate', type=float, default=1e-4, help='Learning rate')
    add('--weight_decay', type=float, default=1e-2, help='Weight decay')
    add('--alpha', type=float, default=0.1, help='Entropy regularization weight')
    add('--beta', type=float, default=0.1, help='Variance regularization weight')
    add('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
    add('--save_dir', type=str, default='./models', help='Directory to save the models')
    add('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
    add('--mode', type=str, choices=['train', 'inference'], default='train',
        help='Mode: train or inference')
    add('--inference_mode', type=str,
        choices=['world_model', 'without_world_model', 'world_model_tree_of_thought'],
        default='world_model_tree_of_thought', help='Inference mode')
    add('--query', type=str, default='', help='Input query for inference')
    add('--train_mode', type=str, choices=['world_model', 'language_model'],
        default='language_model', help='Train world model or language model only')
    add('--beam_size', type=int, default=5, help='Beam size for beam search')
    add('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict at each step')
    add('--load_model', type=str, default=None,
        help='Path to load saved model. If not provided, a new model will be initialized.')
    add('--use_custom_data', action='store_true', help='Use custom data for training')

    # Default data files (absolute, so the base-directory join leaves them as-is).
    bundled_defaults = [
        '/content/drive/MyDrive/lightbulb/knowledge_base.json',
        '/content/drive/MyDrive/lightbulb/rag_cache.json',
        '/content/drive/MyDrive/lightbulb/llm_training_data/llm_training_data.jsonl',
    ]
    add('--custom_data_paths', nargs='+', default=bundled_defaults,
        help='Paths to custom data files (relative to the script location or current working directory)')

    options, _unrecognised = cli.parse_known_args()
    options.custom_data_paths = [
        os.path.abspath(os.path.join(root, entry)) for entry in options.custom_data_paths
    ]
    return options
| |
|
| | import json |
| | import jsonlines |
| |
|
def load_custom_data_from_files(file_paths):
    """Load and concatenate records from ``.json`` and ``.jsonl`` files.

    A ``.json`` file contributes every element when its top-level value is a
    list, and otherwise its top-level value as a single record.  A ``.jsonl``
    file contributes one record per non-blank line.  Paths with any other
    extension are reported and skipped instead of being ignored silently.
    Files are read as UTF-8.

    Args:
        file_paths: Iterable of file path strings.

    Returns:
        list: All loaded records, in file order and then line order.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a ``.json`` file or a ``.jsonl`` line is not
            valid JSON.
    """
    custom_data = []
    for file_path in file_paths:
        if file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                custom_data.extend(data)
            else:
                custom_data.append(data)
        elif file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Blank lines (e.g. extra newlines at EOF) hold no record.
                    if line:
                        custom_data.append(json.loads(line))
        else:
            print(f"Skipping unsupported custom data file (expected .json or .jsonl): {file_path}")
    return custom_data
| |
|
def preprocess_custom_data(data_list):
    """Normalise raw custom records into flat training examples.

    Each record may be a dict or a JSON string encoding a dict.  Strings that
    fail to parse, and records that are not dicts after parsing, are reported
    and skipped.  ``query`` and ``content`` are combined into a single
    ``text`` field; a content of ``"RAG response generation failed."`` is
    replaced by an empty string.  Missing or ``null`` values fall back to
    ``''`` for the text fields and ``0`` for the numeric fields, so that the
    numeric fields can later be converted to tensors.

    Args:
        data_list: Iterable of dicts and/or JSON strings.

    Returns:
        list[dict]: One dict per usable record with keys ``text``,
        ``episode_reward``, ``loss``, ``cosine_similarity``,
        ``rag_performance`` and ``ranking_model_performance``.
    """
    numeric_fields = (
        'episode_reward',
        'loss',
        'cosine_similarity',
        'rag_performance',
        'ranking_model_performance',
    )
    processed_data = []
    for item in data_list:
        if isinstance(item, str):
            try:
                item = json.loads(item)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON: {item[:100]}...")  # first 100 chars for debugging
                continue
        # A JSON string may decode to a list/number/null; only dicts carry fields.
        if not isinstance(item, dict):
            print(f"Skipping custom data record of unsupported type: {type(item).__name__}")
            continue

        query = item.get('query')
        content = item.get('content')
        query = '' if query is None else query
        content = '' if content is None else content
        if content == "RAG response generation failed.":
            content = ""

        processed_item = {'text': f"Query: {query} Content: {content}"}
        for field in numeric_fields:
            value = item.get(field)
            # Treat explicit nulls like missing values: torch.tensor(None) fails downstream.
            processed_item[field] = 0 if value is None else value
        processed_data.append(processed_item)

    return processed_data
| |
|
class _CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes preprocessed custom records on access.

    Defined at module level rather than inside ``load_custom_data`` so that
    DataLoader worker processes can pickle it when the ``spawn`` start method
    is used (a locally defined class cannot be pickled).
    """

    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoded = self.tokenizer(
            item['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return {
            # squeeze(0) drops only the leading batch axis; a bare squeeze()
            # would also collapse the sequence axis when max_length == 1.
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
            'loss': torch.tensor(item['loss'], dtype=torch.float),
            'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
            'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
            'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float),
        }


def load_custom_data(args, tokenizer, custom_data):
    """Build train/eval DataLoaders over preprocessed custom records.

    Records are normalised with ``preprocess_custom_data`` and randomly split
    80/20 into train and eval subsets (at least one record goes to train).
    Each sample holds ``input_ids`` and ``attention_mask`` padded/truncated
    to ``args.max_length`` plus the float metrics ``episode_reward``,
    ``loss``, ``cosine_similarity``, ``rag_performance`` and
    ``ranking_model_performance``.

    Args:
        args: Namespace providing ``max_length`` and ``batch_size``.
        tokenizer: Hugging Face tokenizer.  If it has no pad token, its EOS
            token is assigned as the pad token (the tokenizer is mutated).
        custom_data: Raw records accepted by ``preprocess_custom_data``.

    Returns:
        tuple[DataLoader, DataLoader]: ``(train_loader, eval_loader)``.

    Raises:
        ValueError: If no usable record remains after preprocessing.
    """
    # GPT-2 style tokenizers have no pad token, and padding='max_length'
    # raises without one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    processed_data = preprocess_custom_data(custom_data)
    if not processed_data:
        raise ValueError("No usable custom data records were found.")

    dataset = _CustomDataset(processed_data, tokenizer, args.max_length)

    # max(1, ...) keeps the shuffled train split non-empty for a single record
    # (RandomSampler rejects empty datasets); it only matters when len == 1.
    train_size = max(1, int(0.8 * len(dataset)))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, eval_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
    )
    return train_loader, eval_loader
| |
|
| |
|
| |
|
def _lm_data_collator(data):
    """Stack grouped LM examples into ``input_ids``/``labels`` LongTensors.

    Module-level (not nested in ``load_data``) so DataLoader workers can
    pickle it under the ``spawn`` start method.
    """
    return {
        'input_ids': torch.tensor([f['input_ids'] for f in data], dtype=torch.long),
        'labels': torch.tensor([f['labels'] for f in data], dtype=torch.long),
    }


def load_data(args, tokenizer):
    """Build causal-LM train/eval DataLoaders from a Hugging Face dataset.

    Steps: load ``args.dataset_name``/``args.dataset_config``; tokenize the
    ``text`` column; within each mapped batch, concatenate the token streams
    and cut them into contiguous blocks of ``args.max_length`` tokens,
    dropping the tail that does not fill a whole block; copy ``input_ids``
    into ``labels`` unshifted (NOTE(review): presumably the training loss
    shifts them — verify against the caller).

    Args:
        args: Namespace providing ``dataset_name``, ``dataset_config``,
            ``max_length`` and ``batch_size``.
        tokenizer: Hugging Face tokenizer.  If it has no pad token, its EOS
            token is assigned as the pad token (the tokenizer is mutated).

    Returns:
        tuple[DataLoader, DataLoader]: ``(train_loader, eval_loader)``.  The
        eval split is ``validation`` when present, otherwise ``test``.
        Batches are dicts of ``input_ids`` and ``labels``, each
        ``(batch_size, max_length)`` ``torch.long`` tensors.
    """
    dataset = load_dataset(args.dataset_name, args.dataset_config)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        # Deliberately no truncation: group_texts re-chunks the concatenated
        # stream into max_length blocks, so truncating each line first would
        # silently discard every token past max_length in long lines.
        return tokenizer(examples['text'])

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=dataset['train'].column_names,
    )

    block_size = args.max_length

    def group_texts(examples):
        # Flatten each column's list of token lists into one stream
        # (linear time, unlike the quadratic sum(lists, [])).
        concatenated = {
            k: [tok for seq in examples[k] for tok in seq] for k in examples.keys()
        }
        # Keep only whole blocks; the small remainder is dropped.
        total_length = (len(concatenated['input_ids']) // block_size) * block_size
        result = {
            k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result['labels'] = result['input_ids'].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
    )

    train_dataset = lm_datasets['train']
    eval_dataset = lm_datasets['validation'] if 'validation' in lm_datasets else lm_datasets['test']

    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=args.batch_size,
        collate_fn=_lm_data_collator,
        pin_memory=pin_memory,
        num_workers=4,
    )
    eval_loader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        collate_fn=_lm_data_collator,
        pin_memory=pin_memory,
        num_workers=4,
    )
    return train_loader, eval_loader
| |
|
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
    """Write the state dict of every model component for one epoch.

    ``save_dir`` is created (including parents) if missing.  Each component
    is saved with ``torch.save`` as ``<component>_epoch_<epoch>.pt``.

    Args:
        transformer_model (nn.Module): Transformer model.
        representation_network (nn.Module): Representation network.
        dynamics_network (nn.Module): Dynamics network.
        prediction_network (nn.Module): Prediction network.
        action_encoder (nn.Module): Action encoder.
        save_dir (str): Output directory.
        epoch (int): Epoch number embedded in each file name.
    """
    os.makedirs(save_dir, exist_ok=True)

    components = (
        ('transformer_model', transformer_model),
        ('representation_network', representation_network),
        ('dynamics_network', dynamics_network),
        ('prediction_network', prediction_network),
        ('action_encoder', action_encoder),
    )
    for stem, module in components:
        checkpoint_path = os.path.join(save_dir, f'{stem}_epoch_{epoch}.pt')
        torch.save(module.state_dict(), checkpoint_path)

    print(f"All models saved for epoch {epoch}.")
| |
|
class RotaryPositionalEncoding(nn.Module):
    """Rotate interleaved feature pairs by a position-dependent angle.

    The feature pair ``(2i, 2i + 1)`` at sequence position ``p`` is rotated
    by the angle ``p * 10000 ** (-2i / d_model)``.  Input and output are
    laid out as ``(seq_len, batch_size, d_model)``; ``d_model`` must be even
    for the pair split to line up.  The output has the input's dtype.
    """

    def __init__(self, d_model):
        super().__init__()
        exponents = torch.arange(0, d_model, 2).float() / d_model
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))

    def forward(self, x):
        seq_len, _, _ = x.size()  # also enforces a 3-D input
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)  # (seq_len, d_model/2)
        sin = angles.sin()[:, None, :]  # broadcast over the batch axis
        cos = angles.cos()[:, None, :]

        even, odd = x[..., 0::2], x[..., 1::2]
        # Stack each rotated pair on a trailing axis, then flatten so the
        # results land back on the even/odd interleaved positions.
        pairs = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        return pairs.flatten(start_dim=-2).to(x.dtype)
| |
|
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` heads.

    Inputs are ``(batch, seq_len, d_model)``.  Queries, keys and values are
    linearly projected, split into ``num_heads`` slices of width
    ``d_model // num_heads``, attended per head, concatenated, and projected
    back to ``d_model``.  Parameter shapes (and thus state dicts) do not
    depend on the head split.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Per-head width: the heads partition d_model.  Using d_model here
        # would make the reshape in forward() fold tokens into the head axis
        # and fail at linear_out whenever num_heads > 1.
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: ``(batch, q_len, d_model)``.
            key: ``(batch, k_len, d_model)``.
            value: ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``; positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            torch.Tensor: ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)

        def split_heads(t):
            # (batch, len, d_model) -> (batch, num_heads, len, d_k)
            return t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.linear_q(query))
        k = split_heads(self.linear_k(key))
        v = split_heads(self.linear_v(value))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A large finite negative (not -inf) stays representable in fp16 and
            # yields a uniform row rather than NaN if a whole row is masked.
            scores = scores.masked_fill(mask == 0, -1e4)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)

        # (batch, num_heads, q_len, d_k) -> (batch, q_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linear_out(output)
| |
|
class MoE(nn.Module):
    """Token-level mixture of experts with top-k softmax gating.

    A linear gate scores every expert for every token.  The ``top_k`` highest
    scores are softmax-normalised among themselves and weight the outputs of
    the corresponding experts.  Each expert is ``Linear -> activation ->
    Linear`` with GELU on even-indexed and SiLU on odd-indexed experts.
    Dropout is applied to the combined output.  Input and output are
    ``(batch, seq_len, d_model)``.
    """

    def __init__(self, d_model, num_experts, d_ff, top_k=2, dropout=0.1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        expert_list = []
        for idx in range(num_experts):
            activation = nn.GELU() if idx % 2 == 0 else nn.SiLU()
            expert_list.append(
                nn.Sequential(nn.Linear(d_model, d_ff), activation, nn.Linear(d_ff, d_model))
            )
        self.experts = nn.ModuleList(expert_list)
        self.gate = nn.Linear(d_model, num_experts)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        bsz, seq_len, width = x.size()

        routing_logits = self.gate(x)  # (bsz, seq_len, num_experts)
        slot_weights, slot_experts = torch.topk(routing_logits, self.top_k, dim=-1)
        slot_weights = F.softmax(slot_weights, dim=-1)  # normalise over the chosen k only

        # Work on flattened token rows: (bsz * seq_len, ...)
        tokens = x.view(-1, width)
        combined = torch.zeros_like(x).view(-1, width)
        slot_experts = slot_experts.view(-1, self.top_k)
        slot_weights = slot_weights.view(-1, self.top_k)

        # Slot-major, then expert-major accumulation order.
        for slot in range(self.top_k):
            chosen = slot_experts[:, slot]
            weight = slot_weights[:, slot]
            for expert_id, expert in enumerate(self.experts):
                routed = chosen == expert_id
                if not routed.any():
                    continue
                combined[routed] += weight[routed].unsqueeze(-1) * expert(tokens[routed])

        return self.dropout(combined.view(bsz, seq_len, width))
| |
|
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, optional cross-attention, MoE.

    Each sub-layer is applied as ``norm(h + sublayer(h))``.  The
    cross-attention sub-layer (with ``norm2``) runs only when ``enc_output``
    is given, i.e. when the block is used as a decoder layer; its parameters
    exist either way.
    """

    def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1, top_k=2):
        super().__init__()
        # Attribute names and creation order define the state_dict layout and
        # parameter-initialisation order; keep both stable.
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = MoE(d_model, num_experts, d_ff, top_k, dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, enc_output=None, enc_mask=None):
        hidden = self.norm1(x + self.self_attention(x, x, x, mask))
        if enc_output is not None:
            attended = self.cross_attention(hidden, enc_output, enc_output, enc_mask)
            hidden = self.norm2(hidden + attended)
        return self.norm3(hidden + self.moe(hidden))
| |
|
class Transformer(nn.Module):
    """Encoder-decoder transformer with rotary position encoding and MoE blocks.

    Encoder and decoder share one token embedding table.  ``forward`` returns
    decoder logits of shape ``(batch, tgt_len, output_dim)``;
    ``generate_with_beam_search`` provides inference.
    """

    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout=0.1, top_k=2):
        super(Transformer, self).__init__()
        # NOTE(review): padding_idx is fixed to the last vocabulary id; confirm
        # this matches the tokenizer's pad token id.
        self.embedding = nn.Embedding(input_dim, d_model, padding_idx=input_dim - 1)
        self.rotary_positional_encoding = RotaryPositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.output_layer = nn.Linear(d_model, output_dim)
        self.d_model = d_model

    def _embed(self, tokens):
        """Embed ``(batch, seq)`` token ids to ``(batch, seq, d_model)``.

        Embeddings are scaled by ``sqrt(d_model)`` and passed through the
        rotary encoding, which expects ``(seq, batch, d_model)`` -- hence the
        transposes around it.
        """
        emb = self.embedding(tokens) * math.sqrt(self.d_model)
        emb = self.rotary_positional_encoding(emb.transpose(0, 1))
        return emb.transpose(0, 1)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src`` and decode ``tgt`` against it.

        ``src_mask`` is used both for encoder self-attention and as the
        cross-attention mask in every decoder layer.
        """
        memory = self._embed(src)
        for layer in self.encoder_layers:
            memory = layer(memory, src_mask)

        hidden = self._embed(tgt)
        for layer in self.decoder_layers:
            hidden = layer(hidden, tgt_mask, memory, src_mask)
        return self.output_layer(hidden)

    def generate_with_beam_search(self, src, tokenizer, beam_size=5, max_length=20, n_tokens_predict=3, temperature=1.0):
        """Generate sequences using beam search with multi-token prediction.

        Runs ``max_length // n_tokens_predict`` expansion steps.  Each step
        appends ``n_tokens_predict`` tokens to every unfinished beam, trying
        every combination of the per-position top-``beam_size`` tokens.
        Candidates are ranked by ``score - 0.1 * entropy + 0.05 * variance``,
        where ``score`` is the summed log-probability and entropy/variance are
        accumulated from the predicted distributions (all reduced over the
        batch).  A beam is finished once every row's last token is EOS.
        Runs under ``torch.no_grad()``; the model's train/eval mode is not
        changed, so MoE dropout stays active unless the caller calls
        ``eval()``.

        NOTE(review): ``predict_next_n_tokens`` repeats a single next-token
        distribution ``n_tokens_predict`` times, so the tokens within one
        chunk are not conditioned on each other.

        Args:
            src (torch.Tensor): Source token ids, ``(batch_size, seq_len)``.
            tokenizer: Provides ``bos_token_id`` and ``eos_token_id``.
            beam_size (int): Beams kept per step; also the per-position top-k
                (capped at the vocabulary size).
            max_length (int): Token budget for the generated continuation.
            n_tokens_predict (int): Tokens appended per step.
            temperature (float): Softmax temperature.

        Returns:
            List[Tuple[torch.Tensor, float | torch.Tensor]]: ``(sequence,
            score)`` pairs, best first.
        """
        batch_size = src.size(0)
        device = src.device
        vocab_size = self.output_layer.out_features
        eos_id = tokenizer.eos_token_id
        k = min(beam_size, vocab_size)  # topk cannot exceed the vocabulary

        def finished(seq):
            return bool((seq[:, -1] == eos_id).all())

        def rank(candidate):
            _, score, entropy, variance = candidate
            # score is already summed over the batch; reduce the per-row
            # penalties the same way so the key is a scalar for any batch size.
            return float(score - 0.1 * entropy.sum() + 0.05 * variance.sum())

        with torch.no_grad():
            src_enc = self.encode(src)

            beam = [(torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=device),
                     0.0,  # log probability
                     torch.zeros(batch_size, device=device),  # cumulative entropy
                     torch.zeros(batch_size, device=device))]  # cumulative variance

            positions = torch.arange(n_tokens_predict, device=device)
            # Choice i picks top-k rank (i // k**j) % k at position j.
            choices = [
                torch.tensor([i // (k ** j) % k for j in range(n_tokens_predict)], device=device)
                for i in range(k ** n_tokens_predict)
            ]

            for _ in range(max_length // n_tokens_predict):
                all_candidates = []
                for seq, score, cum_entropy, cum_variance in beam:
                    if finished(seq):
                        all_candidates.append((seq, score, cum_entropy, cum_variance))
                        continue

                    logits = self.predict_next_n_tokens(src_enc, seq, n_tokens_predict)
                    probs = F.softmax(logits / temperature, dim=-1)  # (batch, n, vocab)

                    # Entropy/variance belong to each predicted position and do
                    # not depend on which token is picked: sum over positions.
                    step_entropy = (-torch.sum(probs * torch.log(probs + 1e-9), dim=-1)).sum(dim=-1)
                    step_variance = torch.var(probs, dim=-1).sum(dim=-1)
                    new_entropy = cum_entropy + step_entropy
                    new_variance = cum_variance + step_variance

                    topk_probs, topk_indices = torch.topk(probs, k=k, dim=-1)  # (batch, n, k)
                    for choice in choices:
                        new_tokens = topk_indices[:, positions, choice]  # (batch, n)
                        new_seq = torch.cat([seq, new_tokens], dim=-1)
                        new_score = score + torch.sum(torch.log(topk_probs[:, positions, choice]))
                        all_candidates.append((new_seq, new_score, new_entropy, new_variance))

                beam = sorted(all_candidates, key=rank, reverse=True)[:beam_size]

                if all(finished(seq) for seq, _, _, _ in beam):
                    break

        return [(seq, score) for seq, score, _, _ in beam]

    def encode(self, src):
        """Run the encoder stack without a mask; returns ``(batch, src_len, d_model)``."""
        src_enc = self._embed(src)
        for layer in self.encoder_layers:
            src_enc = layer(src_enc)
        return src_enc

    def predict_next_n_tokens(self, src_enc, tgt_seq, n_tokens):
        """Return next-token logits of shape ``(batch, n_tokens, output_dim)``.

        The decoder runs unmasked over ``tgt_seq``; only the last position is
        projected, and its logits are repeated ``n_tokens`` times.
        """
        tgt_dec = self._embed(tgt_seq)
        for layer in self.decoder_layers:
            tgt_dec = layer(tgt_dec, None, src_enc, None)
        output = self.output_layer(tgt_dec[:, -1:])
        return output.repeat(1, n_tokens, 1)
| |
|
| | # Objective Functions |
| |
|
class InfoNCE_Loss(nn.Module):
    """Symmetric InfoNCE (NT-Xent style) contrastive loss for two paired views."""

    def __init__(self, temperature=0.07):
        super(InfoNCE_Loss, self).__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """Contrast ``z_i`` against ``z_j``.

        Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair.
        Every other row of either view is a negative.  Both views serve as
        anchors, and rows are L2-normalised before the cosine similarities
        are taken.

        Args:
            z_i (torch.Tensor): View i, shape ``(n, embed_dim)``.
            z_j (torch.Tensor): View j, shape ``(n, embed_dim)``.

        Returns:
            torch.Tensor: Scalar loss, the mean over all ``2n`` anchors.
        """
        n = z_i.size(0)
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # (2n, embed_dim)

        # Scale first, then mask.  Masking first let 1/temperature amplify the
        # sentinel (overflowing fp16 for small temperatures) or shrink it so it
        # no longer masked anything (for large temperatures).
        logits = torch.matmul(z, z.T) / self.temperature  # (2n, 2n)
        self_pairs = torch.eye(2 * n, device=z.device, dtype=torch.bool)
        logits = logits.masked_fill(self_pairs, -1e4)

        # Anchor k (view i) targets k + n (view j) and vice versa.
        labels = torch.arange(n, device=z.device)
        labels = torch.cat([labels + n, labels], dim=0)  # (2n,)

        return self.cross_entropy(logits, labels)
| |
|
class CovarianceRegularization(nn.Module):
    """Penalise off-diagonal covariance between embedding dimensions.

    The loss is ``lambda_reg`` times the sum of squared off-diagonal entries
    of the unbiased (``batch_size - 1``) covariance matrix of the batch.
    """

    def __init__(self, lambda_reg=1e-3):
        super(CovarianceRegularization, self).__init__()
        self.lambda_reg = lambda_reg

    def forward(self, embeddings):
        """
        Args:
            embeddings (torch.Tensor): Shape ``(batch_size, embed_dim)``.

        Returns:
            torch.Tensor: Scalar covariance regularization loss.  It is 0 for
            a batch of one, where the unbiased covariance is undefined.
        """
        batch_size, _ = embeddings.size()
        if batch_size < 2:
            # Dividing by (batch_size - 1) would give 0/0 = NaN and poison training.
            return embeddings.new_zeros(())
        centered = embeddings - embeddings.mean(dim=0)
        cov = (centered.T @ centered) / (batch_size - 1)
        cov_loss = torch.sum(cov ** 2) - torch.sum(torch.diag(cov) ** 2)
        return self.lambda_reg * cov_loss
| |
|
class DynamicsPerformanceLoss(nn.Module):
    """MSE between predicted and true next states plus a prediction-variance term."""

    def __init__(self, lambda_var=1e-3):
        super(DynamicsPerformanceLoss, self).__init__()
        self.lambda_var = lambda_var

    def forward(self, true_next_state, predicted_next_state):
        """
        Args:
            true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
            predicted_next_state (torch.Tensor): Predicted next state, shape (batch_size, state_dim)

        Returns:
            torch.Tensor: ``mse + lambda_var * mean(var_over_batch(pred))``.
            With a batch of one the variance term is skipped, because the
            unbiased variance of a single sample is NaN.
        """
        mse_loss = F.mse_loss(predicted_next_state, true_next_state)
        if predicted_next_state.size(0) < 2:
            return mse_loss
        variance_loss = torch.var(predicted_next_state, dim=0).mean()
        return mse_loss + self.lambda_var * variance_loss
| |
|
class ThoughtConsistencyLoss(nn.Module):
    """Mean squared error between a next state and its perturbed counterpart."""

    def forward(self, true_next_state, perturbed_next_state):
        """
        Args:
            true_next_state (torch.Tensor): Reference next state, shape (batch_size, state_dim)
            perturbed_next_state (torch.Tensor): Perturbed next state, same shape

        Returns:
            torch.Tensor: Scalar mean-squared difference.
        """
        consistency = F.mse_loss(true_next_state, perturbed_next_state, reduction='mean')
        return consistency
| |
|
class PolicyValueJointLoss(nn.Module):
    """Cross-entropy policy loss plus ``lambda_value``-weighted MSE value loss.

    The policy target for each row is the arg-max of ``true_policy``; all
    leading dimensions are flattened before the losses are computed.
    """

    def __init__(self, lambda_value=0.5):
        super().__init__()
        self.lambda_value = lambda_value
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, policy_logits, true_policy, value_pred, true_value):
        """
        Args:
            policy_logits (torch.Tensor): (..., num_actions) policy logits.
            true_policy (torch.Tensor): (..., num_actions) target policy.
            value_pred (torch.Tensor): Predicted values, any shape.
            true_value (torch.Tensor): Target values with as many elements.

        Returns:
            torch.Tensor: Scalar ``policy_loss + lambda_value * value_loss``.
        """
        flat_logits = policy_logits.reshape(-1, policy_logits.size(-1))
        target_actions = true_policy.reshape(-1, true_policy.size(-1)).argmax(dim=1)
        policy_term = self.cross_entropy(flat_logits, target_actions)

        value_term = self.mse_loss(value_pred.reshape(-1), true_value.reshape(-1))
        return policy_term + self.lambda_value * value_term
| |
|
class ActionDiversityReward(nn.Module):
    """Penalise pairwise cosine similarity between action embeddings."""

    def __init__(self, lambda_div=1e-3):
        super(ActionDiversityReward, self).__init__()
        self.lambda_div = lambda_div

    def forward(self, action_embeddings):
        """
        Args:
            action_embeddings (torch.Tensor): Embeddings of actions, shape (batch_size, embed_dim)

        Returns:
            torch.Tensor: ``lambda_div`` times the sum of squared cosine
            similarities over all ordered pairs of distinct rows.
        """
        similarity_matrix = F.cosine_similarity(
            action_embeddings.unsqueeze(1), action_embeddings.unsqueeze(0), dim=2
        )
        # Zero the diagonal exactly.  Subtracting an identity matrix assumed
        # every self-similarity is 1, which is false for all-zero rows (cosine
        # 0 -> -1 -> penalty 1) and leaves rounding residue otherwise.
        self_pairs = torch.eye(
            similarity_matrix.size(0), dtype=torch.bool, device=similarity_matrix.device
        )
        similarity_matrix = similarity_matrix.masked_fill(self_pairs, 0.0)
        diversity_loss = torch.sum(similarity_matrix ** 2)
        return self.lambda_div * diversity_loss
| |
|
class ExpectedThoughtValueLoss(nn.Module):
    """Negated mean of the best MCTS values; minimising it raises their mean."""

    def forward(self, mcts_best_values):
        """
        Args:
            mcts_best_values (torch.Tensor): Best values from MCTS, shape (batch_size)

        Returns:
            torch.Tensor: Scalar ETV loss.
        """
        return torch.neg(torch.mean(mcts_best_values))
| |
|
class ExplorationRegularization(nn.Module):
    """Bonus term that grows when actions have few visits.

    Each row contributes ``sum_a 1 / (visits_a + 1)``; the result is the
    batch mean scaled by ``lambda_expl``.
    """

    def __init__(self, lambda_expl=1e-3):
        super().__init__()
        self.lambda_expl = lambda_expl

    def forward(self, visit_counts):
        """
        Args:
            visit_counts (torch.Tensor): Visit counts for actions, shape (batch_size, num_actions)

        Returns:
            torch.Tensor: Scalar exploration regularization term.
        """
        inverse_visits = 1.0 / (visit_counts + 1)
        per_row = inverse_visits.sum(dim=-1)
        return self.lambda_expl * per_row.mean()
| |
|
class KL_DivergenceLoss(nn.Module):
    """KL(old_policy || new_policy) between probability tensors, averaged over the batch."""

    def __init__(self, eps=1e-12):
        super(KL_DivergenceLoss, self).__init__()
        # Floor for new_policy before taking the log.  Without it a zero
        # probability gives log(0) = -inf, and where old_policy is also 0 the
        # term 0 * -inf becomes NaN.
        self.eps = eps

    def forward(self, old_policy, new_policy):
        """
        Args:
            old_policy (torch.Tensor): Old policy probabilities, shape (batch_size, num_actions)
            new_policy (torch.Tensor): New policy probabilities, shape (batch_size, num_actions)

        Returns:
            torch.Tensor: Scalar KL divergence, summed over actions and
            divided by ``batch_size`` (``reduction='batchmean'``).
        """
        log_new = new_policy.clamp_min(self.eps).log()
        return F.kl_div(log_new, old_policy, reduction='batchmean')
| |
|
| | # MuZero Components |
| |
|
class ActionEncoder(nn.Module):
    """Learned lookup-table embedding for discrete action ids."""

    def __init__(self, action_vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=action_vocab_size, embedding_dim=embed_dim)

    def forward(self, action_indices):
        """
        Args:
            action_indices (torch.Tensor): Integer action ids, shape (batch_size, seq_len)

        Returns:
            torch.Tensor: Action embeddings, shape (batch_size, seq_len, embed_dim)
        """
        encoded = self.embedding(action_indices)
        return encoded
| |
|
class RepresentationNetwork(nn.Module):
    """Map transformer outputs to a layer-normalised latent state.

    Two linear projections are applied with no activation between them,
    ``vocab_dim -> d_model -> state_dim``, followed by ``LayerNorm``.
    """

    def __init__(self, vocab_dim, d_model, state_dim):
        super().__init__()
        # Creation order fixes parameter-initialisation order; keep it.
        self.proj = nn.Linear(vocab_dim, d_model)
        self.linear = nn.Linear(d_model, state_dim)
        self.norm = nn.LayerNorm(state_dim)

    def forward(self, transformer_output):
        """
        Args:
            transformer_output (torch.Tensor): Shape (batch_size, seq_len, vocab_dim)

        Returns:
            torch.Tensor: Encoded state, shape (batch_size, seq_len, state_dim)
        """
        hidden = self.proj(transformer_output)
        return self.norm(self.linear(hidden))
| |
|
| |
|
class DynamicsNetwork(nn.Module):
    """Predict the next latent state from the current state and an action.

    The state is normalised, concatenated with the action embedding, and
    passed through a two-layer GELU MLP.  Note that ``rms_norm`` is a
    ``LayerNorm`` despite its name; the attribute name is kept for
    state_dict compatibility.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        self.rms_norm = nn.LayerNorm(state_dim)
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, state_dim)

    def forward(self, state, action):
        """
        Args:
            state (torch.Tensor): Current state, shape (batch_size, seq_len, state_dim)
            action (torch.Tensor): Action embedding, shape (batch_size, seq_len, action_dim)

        Returns:
            torch.Tensor: Predicted next state, shape (batch_size, seq_len, state_dim)
        """
        features = torch.cat((self.rms_norm(state), action), dim=-1)
        return self.fc2(self.activation(self.fc1(features)))
| |
|
class PredictionNetwork(nn.Module):
    """Policy and value heads over a normalised latent state.

    ``rms_norm`` is a ``LayerNorm`` (name kept for state_dict compatibility).
    The value output has its trailing axis squeezed, which removes it only
    when ``value_dim == 1``.
    """

    def __init__(self, state_dim, action_vocab_size, value_dim):
        super().__init__()
        self.state_dim = state_dim
        self.rms_norm = nn.LayerNorm(state_dim)
        self.policy_head = nn.Linear(state_dim, action_vocab_size)
        self.value_head = nn.Linear(state_dim, value_dim)

    def forward(self, state):
        """
        Args:
            state (torch.Tensor): State representation, shape (batch_size, state_dim)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Policy logits of shape
            (batch_size, action_vocab_size) and value estimates of shape
            (batch_size) when value_dim == 1.
        """
        normed = self.rms_norm(state)
        policy_logits = self.policy_head(normed)
        value_estimates = self.value_head(normed).squeeze(-1)
        return policy_logits, value_estimates
| |
|
| |
|
| |
|
| |
|
class MCTSNode:
    """
    A search-tree node that shadows one node of the Tree of Thought.

    Holds visit statistics, the prior assigned when the parent was expanded,
    and the policy entropy/variance recorded when this node was evaluated.
    """

    __slots__ = (
        'state',
        'parent',
        'action',
        'children',
        'visit_count',
        'value_sum',
        'prior',
        'cached_policy',
        'cached_value',
        'thought_node',
        'entropy',
        'variance',
    )

    def __init__(self, state, thought_node, parent=None, action=None):
        self.state = state
        self.thought_node = thought_node
        self.parent = parent
        self.action = action
        self.children = {}  # action name -> MCTSNode
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = 0.0
        self.cached_policy = None
        self.cached_value = None
        self.entropy = 0.0
        self.variance = 0.0

    def expand(self, priors):
        """
        Add a child for every thought-tree child not yet present.

        Args:
            priors (dict): action name -> prior; names missing from it receive
                a uniform ``1 / number_of_thought_children`` prior.
        """
        thought_children = self.thought_node.children
        for thought_child in thought_children:
            name = thought_child.name
            if name in self.children:
                continue
            child = MCTSNode(
                state=self.state.apply_action(name),
                thought_node=thought_child,
                parent=self,
                action=name,
            )
            child.prior = priors.get(name, 1.0 / len(thought_children))
            self.children[name] = child

    def is_leaf(self):
        """True while no children have been added to this node."""
        return not self.children

    def ucb_score(self, total_visits, exploration_constant=math.sqrt(2)):
        """
        PUCT-style score; unvisited nodes score +inf so they are tried first.

        Adds a small penalty for policy entropy and a small bonus for policy variance.
        """
        if self.visit_count == 0:
            return float('inf')
        mean_value = self.value_sum / self.visit_count
        explore = exploration_constant * self.prior * math.sqrt(total_visits) / (1 + self.visit_count)
        return mean_value + explore - 0.1 * self.entropy + 0.05 * self.variance
| |
|
| |
|
class MCTS:
    """
    Monte Carlo Tree Search over the Tree of Thought.

    Expansion follows ``state.thought_node.children``; the prediction network
    supplies child priors and a scalar value for every evaluated node.
    """

    def __init__(self, prediction_network, dynamics_network, action_encoder, num_iterations=10, exploration_constant=math.sqrt(2), beam_size=5, n_tokens_predict=3):
        self.prediction_network = prediction_network
        self.dynamics_network = dynamics_network
        self.action_encoder = action_encoder
        self.num_iterations = num_iterations
        self.exploration_constant = exploration_constant
        self.beam_size = beam_size
        self.n_tokens_predict = n_tokens_predict
        self.cache = {}  # NOTE(review): not read anywhere in this class; kept for compatibility

    @staticmethod
    def _debug(msg, *args):
        """Emit a DEBUG-level diagnostic (replaces the former unconditional prints)."""
        import logging
        logging.getLogger(__name__).debug(msg, *args)

    def _ucb_key(self, node):
        """Return a key ranking ``(action, child)`` items of ``node`` by UCB score."""
        total_visits = sum(child.visit_count for child in node.children.values())
        return lambda item: item[1].ucb_score(total_visits, self.exploration_constant)

    @staticmethod
    def _beam_rank(entry):
        """Rank a beam entry: score, penalized by entropy, rewarded by variance."""
        return entry[1] - 0.1 * entry[2] + 0.05 * entry[3]

    def search_with_beam(self, root_state):
        """
        Run a UCB-guided beam search starting at ``root_state``.

        Each round, every beam entry is extended by each of its top ``beam_size``
        children (by UCB), and each of those by up to ``n_tokens_predict``
        further greedy UCB steps. ``score`` accumulates the mean value of the
        greedily chosen nodes; candidates are ranked by
        ``score - 0.1 * entropy + 0.05 * variance``.

        Args:
            root_state (State): Starting state; its ``thought_node`` roots the search.

        Returns:
            list: Action names of the best beam entry, or ``[]`` if the beam is empty.
        """
        root_node = MCTSNode(state=root_state, thought_node=root_state.thought_node)

        # Evaluate (which also expands) the root before searching.
        self.backpropagate(root_node, self.evaluate(root_node))

        # Each entry: (node, score, cum_entropy, cum_variance, action_sequence)
        beam = [(root_node, 0.0, 0.0, 0.0, [])]

        for iteration in range(self.num_iterations):
            all_candidates = []
            for node, score, cum_entropy, cum_variance, action_sequence in beam:
                if node.is_leaf():
                    self.backpropagate(node, self.evaluate(node))
                    if node.is_leaf():
                        continue  # terminal thought: nothing to extend

                top_children = sorted(
                    node.children.items(),
                    key=self._ucb_key(node),
                    reverse=True,
                )[:self.beam_size]

                for selected_action, selected_node in top_children:
                    current_node = selected_node
                    current_sequence = action_sequence + [selected_action]
                    current_score = score
                    current_entropy = cum_entropy + selected_node.entropy
                    current_variance = cum_variance + selected_node.variance

                    # Greedily extend by up to n_tokens_predict further actions.
                    for _ in range(self.n_tokens_predict):
                        if current_node.is_leaf():
                            self.backpropagate(current_node, self.evaluate(current_node))
                            if current_node.is_leaf():
                                break  # no more actions
                        next_action, next_node = max(
                            current_node.children.items(),
                            key=self._ucb_key(current_node),
                        )
                        current_sequence.append(next_action)
                        if next_node.visit_count > 0:  # avoid dividing by zero
                            current_score += next_node.value_sum / next_node.visit_count
                        current_entropy += next_node.entropy
                        current_variance += next_node.variance
                        current_node = next_node

                    all_candidates.append((current_node, current_score, current_entropy, current_variance, current_sequence))

            if not all_candidates:
                break  # every beam entry is terminal

            beam = sorted(all_candidates, key=self._beam_rank, reverse=True)[:self.beam_size]
            self._debug("Iteration %d: Beam size after sorting: %d", iteration + 1, len(beam))

        return beam[0][4] if beam else []

    def search(self, root_state):
        """Run ``num_iterations`` select/evaluate/backpropagate rounds; return the best action sequence."""
        root_node = MCTSNode(state=root_state, thought_node=root_state.thought_node)

        for _ in range(self.num_iterations):
            node = self.select(root_node)
            value = self.evaluate(node)
            self.backpropagate(node, value)

        return self.best_action_sequence(root_node)

    def select(self, node):
        """Descend from ``node`` by maximum UCB score until reaching a node with no children."""
        while not node.is_leaf():
            _, node = max(node.children.items(), key=self._ucb_key(node))
        return node

    def evaluate(self, node):
        """
        Evaluate ``node`` with the prediction network and expand it.

        The last time step of ``node.state.representation`` is fed to the
        prediction network. Each child of ``node.thought_node`` gets its prior
        from the softmax policy at ``action_to_index[child.name]``, or a uniform
        ``1 / number_of_children`` when the name is unknown or out of range.
        The policy's entropy and variance are stored on ``node``.

        Returns:
            float: The value estimate, for backpropagation (requires the value
            output to hold a single element, i.e. batch_size=1).
        """
        state_representation = node.state.representation[:, -1, :]
        self._debug("Evaluating node with state_representation shape: %s", tuple(state_representation.shape))
        # Only Python floats leave this method, so no autograd graph is needed here.
        with torch.no_grad():
            policy_logits, value_estimate = self.prediction_network(state_representation)
        self._debug(
            "Policy logits shape: %s, Value estimate shape: %s",
            tuple(policy_logits.shape),
            tuple(value_estimate.shape),
        )
        value_estimate = value_estimate.item()

        policy_probs = F.softmax(policy_logits, dim=-1).squeeze(0)
        self._debug("Policy probabilities shape: %s", tuple(policy_probs.shape))

        children = node.thought_node.children
        priors = {}
        for child in children:
            action_name = child.name
            action_idx = action_to_index.get(action_name, None)
            if action_idx is not None and action_idx < policy_probs.size(0):
                priors[action_name] = policy_probs[action_idx].item()
            else:
                priors[action_name] = 1.0 / len(children)

        node.expand(priors)

        node.entropy = (-torch.sum(policy_probs * torch.log(policy_probs + 1e-9))).item()
        node.variance = torch.var(policy_probs).item()
        self._debug("Node entropy: %s, variance: %s", node.entropy, node.variance)

        return value_estimate

    def backpropagate(self, node, value):
        """Add one visit and ``value`` to ``node`` and each of its ancestors."""
        while node is not None:
            node.visit_count += 1
            node.value_sum += value
            node = node.parent

    def best_action_sequence(self, root_node):
        """
        Return the actions along the best path enumerated from ``root_node``.

        Paths are scored by total visit count - 0.1 * summed entropy
        + 0.05 * summed variance; the root is excluded and at most
        ``n_tokens_predict`` actions are returned.
        """
        sequences = []
        self._generate_sequences(root_node, [], sequences)

        def adjusted_score(seq):
            visits = sum(n.visit_count for n in seq)
            entropy = sum(n.entropy for n in seq)
            variance = sum(n.variance for n in seq)
            return visits - 0.1 * entropy + 0.05 * variance

        # max() keeps the first of equally scored paths, like a stable descending sort.
        best_sequence = max(sequences, key=adjusted_score)
        return [n.action for n in best_sequence[1:self.n_tokens_predict + 1]]

    def _generate_sequences(self, node, current_sequence, sequences):
        """Collect node paths that stop after n_tokens_predict + 1 nodes or at a childless node."""
        current_sequence.append(node)
        if len(current_sequence) > self.n_tokens_predict or not node.children:
            sequences.append(current_sequence)
        else:
            for child in node.children.values():
                self._generate_sequences(child, current_sequence.copy(), sequences)
| |
|
class State:
    """
    A latent world-model state paired with its position in the Tree of Thought.

    ``representation`` holds the history of latent states along its second axis;
    applying an action appends one predicted step and moves to the matching
    thought-tree child.
    """

    def __init__(self, representation, dynamics_network, action_encoder, thought_node):
        self.representation = representation
        self.dynamics_network = dynamics_network
        self.action_encoder = action_encoder
        self.thought_node = thought_node

    def apply_action(self, action):
        """
        Return the successor State reached by taking ``action``.

        Args:
            action (str): Name of a child of the current thought node.

        Raises:
            ValueError: If no child of the current thought node has that name.
        """
        target = next((c for c in self.thought_node.children if c.name == action), None)
        if target is None:
            raise ValueError(f"Action '{action}' is not valid from the current thought node.")

        index_tensor = torch.tensor([action_to_index[action]], device=self.representation.device)
        embedded_action = self.action_encoder(index_tensor)

        # Predict the next latent step from the most recent one and append it.
        latest = self.representation[:, -1, :]
        predicted = self.dynamics_network(latest, embedded_action)
        extended = torch.cat([self.representation, predicted.unsqueeze(1)], dim=1)

        return State(
            representation=extended,
            dynamics_network=self.dynamics_network,
            action_encoder=self.action_encoder,
            thought_node=target,
        )
| |
|
class PPOAgent:
    """Clipped-surrogate PPO loss for a network returning (policy_logits, value_estimates)."""

    def __init__(self, policy_network, optimizer, clip_epsilon=0.2, entropy_coef=0.01, value_coef=0.5):
        self.policy_network = policy_network
        self.optimizer = optimizer
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef

    def compute_loss(self, states, old_log_probs, actions, returns, advantages):
        """
        Compute the PPO objective: clipped policy loss + value_coef * value MSE
        - entropy_coef * mean policy entropy.

        All inputs are flattened, so any leading shape works as long as the
        flattened sizes agree with the number of policy rows.

        Args:
            states: Input passed straight to ``policy_network``.
            old_log_probs (torch.Tensor): Log-probs of ``actions`` under the rollout policy.
            actions (torch.Tensor): Chosen action indices (cast to long).
            returns (torch.Tensor): Value targets.
            advantages (torch.Tensor): Advantage estimates.

        Returns:
            torch.Tensor: Scalar total loss.

        Raises:
            ValueError: If the flattened tensors have mismatched lengths.
        """
        policy_logits, value_estimates = self.policy_network(states)

        policy_logits = policy_logits.reshape(-1, policy_logits.size(-1))
        value_estimates = value_estimates.reshape(-1)
        actions = actions.reshape(-1).long()  # gather() needs integer indices
        # Rollout-time quantities are constants for this update; never backprop into them.
        old_log_probs = old_log_probs.reshape(-1).detach()
        returns = returns.reshape(-1).detach()
        advantages = advantages.reshape(-1).detach()

        n = policy_logits.size(0)
        other_sizes = (
            value_estimates.size(0),
            actions.size(0),
            old_log_probs.size(0),
            returns.size(0),
            advantages.size(0),
        )
        if any(size != n for size in other_sizes):
            raise ValueError("Tensor sizes mismatch")

        log_probs_all = F.log_softmax(policy_logits, dim=-1)
        new_log_probs = log_probs_all.gather(1, actions.unsqueeze(-1)).squeeze(-1)

        ratios = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        value_loss = F.mse_loss(value_estimates, returns)

        # Entropy of the full action distribution (not just the taken action), averaged over samples.
        entropy = -(log_probs_all.exp() * log_probs_all).sum(dim=-1).mean()

        return policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
| |
|
| | # Tree of Thought Components |
| |
|
class ThoughtNode:
    """A named node of the Tree of Thought; children keep their insertion order."""

    def __init__(self, name):
        self.name = name
        self.children = []
        self.parent = None

    def add_child(self, child_node):
        """Append ``child_node`` as this node's last child and make this node its parent."""
        self.children.append(child_node)
        child_node.parent = self
| |
|
| | # Function to build the Tree of Thought from your detailed structure |
def build_tree_of_thought():
    """
    Build and return the root ThoughtNode of the fixed problem-solving Tree of Thought.

    The tree is declared as nested data: a bare string is a leaf, and a
    ``(name, [children...])`` pair is an internal node. List order is the
    order in which children are attached.
    """
    spec = ('Problem-Solving Process', [
        ('Problem Identification', [
            ('Define the Problem', ['Problem Statement Formulation', 'Scope Definition', 'Objective Setting']),
            ('Identify Stakeholders', ['Stakeholder Mapping', 'Interest and Influence Analysis', 'Engagement Strategy']),
            ('Determine Constraints', ['Resource Limitations', 'Time Constraints', 'Legal and Regulatory Constraints']),
            ('Recognize Problem Type', ['Simple vs Complex', 'Known vs Unknown', 'Tame vs Wicked Problems']),
            ('Historical Context', ['Previous Attempts', 'Lessons Learned', 'Environmental Factors']),
        ]),
        ('Problem Analysis', [
            ('Root Cause Analysis', ['5 Whys Technique', 'Fishbone Diagram', 'Pareto Analysis']),
            ('System Mapping', ['Causal Loop Diagrams', 'Stock and Flow Models', 'Network Analysis']),
            ('Data Collection', [
                ('Quantitative Data', ['Surveys and Questionnaires', 'Experimental Data', 'Big Data Analytics']),
                ('Qualitative Data', ['Interviews', 'Focus Groups', 'Observational Studies']),
                ('Data Validation', ['Statistical Validation', 'Cross-Validation', 'Expert Review']),
            ]),
            ('Impact Assessment', ['Environmental Impact', 'Social Impact', 'Economic Impact']),
            ('Theoretical Framework', ['Literature Review', 'Conceptual Modeling', 'Hypothesis Formation']),
        ]),
        ('Solution Generation', [
            'Creative Problem Solving',
            'Analytical Approach',
            'Mathematical Computation',
            'Decision Making',
        ]),
        # Implementation phase
        ('Implementation', ['Action Planning', 'Resource Allocation', 'Change Management']),
        # Evaluation phase
        ('Evaluation and Adjustment', [
            'Verification',
            'Performance Metrics',
            'Feedback Loops',
            'Continuous Improvement',
        ]),
        ('Cross-Cutting Considerations', [
            ('Ethical Framework', [
                ('Value-based Decision Making', ['Ethical Theories Application', 'Moral Dilemma Resolution']),
                ('Long-term Consequences', ['Sustainability Assessment', 'Intergenerational Impact']),
            ]),
            ('Stakeholder Management', [
                'Direct Stakeholders',
                'Indirect Stakeholders',
                ('Conflicting Interests', ['Negotiation Strategies', 'Conflict Resolution Techniques']),
            ]),
            ('Interdisciplinary Connections', [
                ('Related Fields', ['Cross-domain Knowledge Transfer', 'Interdisciplinary Collaboration']),
                ('Cross-disciplinary Impact', ['Synergy Identification', 'Holistic Impact Assessment']),
            ]),
            ('Technological Integration', [
                ('AI-assisted Problem Solving', ['Machine Learning Models', 'Natural Language Processing']),
                ('Data-driven Insights', ['Big Data Analytics', 'Predictive Modeling']),
                ('Digital Collaboration Tools', ['Project Management Platforms', 'Virtual Reality Collaboration']),
            ]),
            ('Emotional Intelligence', [
                ('Self-Awareness', ['Emotional Recognition', 'Personal Bias Identification']),
                ('Empathy', ['Perspective Taking', 'Active Listening']),
                ('Stress Management', ['Mindfulness Techniques', 'Resilience Building']),
            ]),
            ('Collaborative Problem Solving', [
                ('Team Dynamics', ['Team Formation Strategies', 'Role Assignment']),
                ('Communication Strategies', ['Clear Messaging', 'Feedback Mechanisms']),
                ('Conflict Resolution', ['Mediation Techniques', 'Consensus Building']),
            ]),
            ('Computational Considerations', [
                ('CPU Operations', ['Instruction Set Architecture', 'Pipelining and Parallelism']),
                ('GPU Parallelization', ['CUDA Programming', 'OpenCL Framework']),
                ('Floating-Point Precision', ['IEEE 754 Standard', 'Error Propagation Analysis']),
            ]),
            ('Order of Operations', [
                'Parentheses',
                'Exponents',
                'Multiplication and Division',
                'Addition and Subtraction',
            ]),
            ('Critical Thinking', [
                ('Assumptions Questioning', ['Socratic Questioning', "Devil's Advocate Approach"]),
                ('Bias Recognition', ['Cognitive Bias Identification', 'Debiasing Techniques']),
            ]),
            ('Future Perspective', [
                ('Short-term Projections', ['Trend Analysis', 'Scenario Planning']),
                ('Long-term Scenarios', ['Futures Wheel', 'Backcasting']),
                ('Potential Impacts', ['Risk Assessment', 'Opportunity Identification']),
            ]),
            ('Learning and Adaptation', [
                ('Reflective Practice', ['After Action Review', 'Learning Journals']),
                ('Knowledge Transfer', ['Best Practice Documentation', 'Mentoring Programs']),
                ('Adaptive Problem Solving', ['Iterative Approaches', 'Flexibility in Methodology']),
            ]),
        ]),
    ])

    def build(entry):
        # Leaves are plain strings; internal nodes are (name, children) pairs.
        if isinstance(entry, str):
            return ThoughtNode(entry)
        name, children = entry
        node = ThoughtNode(name)
        for child in children:
            node.add_child(build(child))
        return node

    return build(spec)
| |
|
def traverse_tree(node, action_list):
    """
    Append every distinct node name under ``node`` to ``action_list`` in pre-order.

    Names already present in ``action_list``, including any it held before the
    call, are not appended again, but their subtrees are still visited.
    ``action_list`` is modified in place and nothing is returned.

    Uses an explicit stack and a set of seen names. This keeps it linear in the
    tree size and avoids Python's recursion limit on deep trees.
    """
    seen = set(action_list)
    stack = [node]
    while stack:
        current = stack.pop()
        if current.name not in seen:
            seen.add(current.name)
            action_list.append(current.name)
        # Push children reversed so they are popped (visited) left-to-right.
        stack.extend(reversed(current.children))
| |
|
| |
|
| |
|
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414, temperature=None):
    """
    Perform inference given a query, utilizing the Tree of Thought and MCTS with multi-token beam search.

    Args:
        query (str): The input query or prompt.
        world_model_components (tuple): (representation_network, dynamics_network,
            prediction_network, action_encoder, ppo_agent, model_transformer).
        root_thought_node (ThoughtNode): The root node of the Tree of Thought.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
        max_length (int): Generation length for 'without_world_model'; maximum
            number of search rounds for the world-model modes.
        inference_mode (str): 'world_model', 'without_world_model' or
            'world_model_tree_of_thought'.
        beam_size (int): Beam width.
        n_tokens_predict (int): Number of tokens/actions predicted per step.
        mcts_iterations (int): MCTS iterations per search ('world_model_tree_of_thought').
        exploration_constant (float): UCB exploration constant for MCTS.
        temperature (float, optional): Softmax temperature for the
            'without_world_model' and 'world_model' modes. When None, the
            module-level ``args.temperature`` is used (previous behavior).

    Returns:
        List[str] or str: The sequence of actions (thoughts) selected, or the generated text.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components

    input_ids = tokenizer.encode(query, return_tensors='pt').to(device)

    if inference_mode == 'without_world_model':
        if temperature is None:
            temperature = args.temperature
        # Directly use the transformer model to generate text with beam search.
        with torch.no_grad():
            generated_sequences = model_transformer.generate_with_beam_search(
                src=input_ids,
                tokenizer=tokenizer,
                beam_size=beam_size,
                max_length=max_length,
                n_tokens_predict=n_tokens_predict,
                temperature=temperature,
            )
        best_sequence, _best_score = generated_sequences[0]
        return tokenizer.decode(best_sequence[0], skip_special_tokens=True)

    # World-model modes: inference only, so no autograd graphs are needed.
    with torch.no_grad():
        transformer_output = model_transformer(input_ids, input_ids)
        initial_representation = representation_network(transformer_output)
        # Start from the final position only: shape (batch, 1, state_dim).
        initial_representation = initial_representation[:, -1, :].unsqueeze(1)
        initial_state = State(
            representation=initial_representation,
            dynamics_network=dynamics_network,
            action_encoder=action_encoder,
            thought_node=root_thought_node,
        )

        if inference_mode == 'world_model_tree_of_thought':
            mcts = MCTS(
                prediction_network,
                dynamics_network,
                action_encoder,
                num_iterations=mcts_iterations,
                exploration_constant=exploration_constant,
                beam_size=beam_size,
                n_tokens_predict=n_tokens_predict,
            )
            return _infer_tot_search(mcts, initial_state, max_length)

        if temperature is None:
            temperature = args.temperature
        return _infer_beam_search(initial_state, prediction_network, max_length, beam_size, n_tokens_predict, temperature)


def _infer_tot_search(mcts, initial_state, max_rounds):
    """Repeat MCTS beam search and follow its actions until a thought-tree leaf is reached."""
    current_state = initial_state
    thought_sequence = []
    for _ in range(max_rounds):
        if not current_state.thought_node.children:
            break  # reached a leaf: no further actions exist
        best_actions = mcts.search_with_beam(current_state)
        if not best_actions:
            break  # nothing was selected; repeating the search would not progress
        thought_sequence.extend(best_actions)
        for action in best_actions:
            current_state = current_state.apply_action(action)
    return thought_sequence


def _infer_policy_candidates(state, prediction_network, k, temperature):
    """
    Return up to ``k`` (action, log_prob, entropy, variance) options from ``state``.

    Only children of ``state.thought_node`` are considered, since
    ``State.apply_action`` rejects any other action. A child's probability is
    read from the temperature-scaled policy at ``action_to_index[child.name]``,
    or is ``1 / number_of_children`` when that name is unknown or out of range.
    ``entropy``/``variance`` describe the full policy distribution at ``state``.
    Returns ``[]`` when the thought node has no children.
    """
    children = state.thought_node.children
    if not children:
        return []
    policy_logits, _ = prediction_network(state.representation[:, -1, :])
    probs = F.softmax(policy_logits / temperature, dim=-1).squeeze(0)
    entropy = (-torch.sum(probs * torch.log(probs + 1e-9))).item()
    variance = torch.var(probs).item()
    fallback = 1.0 / len(children)
    scored = []
    for child in children:
        idx = action_to_index.get(child.name)
        prob = probs[idx].item() if idx is not None and idx < probs.size(0) else fallback
        scored.append((child.name, prob))
    scored.sort(key=lambda item: item[1], reverse=True)
    return [(name, math.log(prob + 1e-9), entropy, variance) for name, prob in scored[:k]]


def _infer_beam_rank(entry):
    """Rank a beam entry: log-prob score, penalized by entropy, rewarded by variance."""
    return entry[1] - 0.1 * entry[2] + 0.05 * entry[3]


def _infer_beam_search(initial_state, prediction_network, max_rounds, beam_size, n_tokens_predict, temperature):
    """
    Multi-token beam search over thought-tree actions scored by the policy network.

    Each round expands every beam entry ``n_tokens_predict`` levels deep,
    branching over the top ``beam_size`` valid actions at each level. It then
    keeps the ``beam_size`` best candidates. Entries whose thought node has no
    children are carried forward unchanged. The search stops early once no
    entry can be extended.

    Returns:
        list: Actions of the best-ranked candidate.
    """
    # Each entry: (state, log_prob_score, cum_entropy, cum_variance, actions)
    beam = [(initial_state, 0.0, 0.0, 0.0, [])]
    for _ in range(max_rounds):
        candidates = []
        extended = False
        for entry in beam:
            frontier = [entry]
            for _ in range(n_tokens_predict):
                next_frontier = []
                for state, score, cum_entropy, cum_variance, actions in frontier:
                    options = _infer_policy_candidates(state, prediction_network, beam_size, temperature)
                    if not options:
                        next_frontier.append((state, score, cum_entropy, cum_variance, actions))
                        continue
                    extended = True
                    for action, log_prob, step_entropy, step_variance in options:
                        next_frontier.append((
                            state.apply_action(action),
                            score + log_prob,
                            cum_entropy + step_entropy,
                            cum_variance + step_variance,
                            actions + [action],
                        ))
                frontier = next_frontier
            candidates.extend(frontier)
        beam = sorted(candidates, key=_infer_beam_rank, reverse=True)[:beam_size]
        if not extended:
            break  # every entry is at a thought-tree leaf
    return beam[0][4]
| |
|
| |
|
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    """Train the world-model components for one epoch on an (input_ids, labels) dataset.

    The Transformer forward pass runs under ``torch.no_grad()``, so it acts as a
    frozen feature extractor; only the parameters held by ``optimizer`` are
    updated. Several loss terms are fed placeholder targets (zeros / ones)
    instead of real supervision -- see the inline notes.

    Args:
        world_model_components: Tuple ``(representation_network, dynamics_network,
            prediction_network, action_encoder, ppo_agent, _)``; the sixth element
            is ignored.
        train_loader: Sized iterable of batches holding ``'input_ids'`` and
            ``'labels'`` tensors.
        optimizer: Optimizer whose parameter groups are gradient-clipped and stepped.
        scheduler: LR scheduler; stepped once per optimizer step (every
            ``args.accumulation_steps`` batches and after the last batch).
        scaler: Gradient scaler used for the backward pass and optimizer step.
        args: Parsed arguments; ``accumulation_steps`` and ``max_grad_norm`` are read.
        model_transformer: Transformer called as ``model_transformer(src, tgt)``.
        state_dim: Last dimension used to flatten the state representation.
        embed_dim: Last dimension used to flatten the action embeddings.
        input_dim: Number of classes of the one-hot target policy.

    Returns:
        float: Average per-batch total loss (before division by
        ``args.accumulation_steps``); ``0.0`` for an empty loader.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, _ = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()

    num_batches = len(train_loader)
    # The optimizer's parameter groups do not change during the epoch.
    clip_params = [param for group in optimizer.param_groups for param in group['params']]
    # Requesting device_type='cuda' autocast on a CPU-only host only warns and
    # disables itself; make that explicit instead of warning every batch.
    use_amp = device.type == 'cuda'

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {num_batches} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{num_batches}...")

        # Move batches to the device
        src_batch = batch['input_ids'].to(device)
        tgt_batch = batch['labels'].to(device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            # World Model - Representation
            state_representation = representation_network(transformer_output)

            # For simplicity, the target tokens are used as the "true" actions.
            true_actions = tgt_batch[:, :-1]

            # Get action embeddings
            action_embeddings = action_encoder(true_actions)

            # Apply dynamics network
            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Targets: one-hot true actions; the value target is a zero placeholder.
            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = torch.zeros_like(value_estimates)

            # NOTE(review): the 2nd, 4th and 5th arguments are zero placeholders --
            # confirm what PPOAgent.compute_loss expects in those positions.
            ppo_loss = ppo_agent.compute_loss(
                state_representation,
                torch.zeros_like(true_actions, dtype=torch.float32),
                true_actions,
                torch.zeros_like(value_estimates, dtype=torch.float32),
                torch.zeros_like(value_estimates, dtype=torch.float32)
            )

            # Contrastive term between the representation and a dropout-perturbed copy.
            info_nce = InfoNCE_Loss()(
                state_representation.reshape(-1, state_dim),
                F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True)
            )

            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            # Consistency between the predicted state and a slightly noised copy.
            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            # Placeholder inputs: no MCTS values / visit counts exist during training.
            mcts_best_values = torch.zeros(true_actions.size(0), device=device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1), device=device)
            exploration = ExplorationRegularization()(visit_counts)

            # NOTE(review): old_policy is a detached copy of new_policy, so this KL
            # term compares a distribution with itself -- confirm this is intended.
            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Total loss before gradient-accumulation scaling (used for logging).
            batch_loss = (
                ppo_loss +
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss
            )
            # Divide so gradients accumulated over several batches are averaged.
            loss = batch_loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            print("Gradient clipping...")
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(clip_params, args.max_grad_norm)

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        total_loss += batch_loss.item()

        # Print individual losses and total loss for this batch
        print(f"Batch {i+1} completed. Losses:")
        print(f" PPO Loss: {ppo_loss.item():.4f}")
        print(f" InfoNCE Loss: {info_nce.item():.4f}")
        print(f" Covariance Loss: {covariance.item():.4f}")
        print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f" Policy-Value Loss: {pv_loss.item():.4f}")
        print(f" Action Diversity Loss: {action_diversity.item():.4f}")
        print(f" Expected Thought Value Loss: {etv.item():.4f}")
        print(f" Exploration Loss: {exploration.item():.4f}")
        print(f" KL Divergence Loss: {kl_loss.item():.4f}")
        # Reported on the same (unscaled) footing as the individual terms above.
        print(f" Total Loss: {batch_loss.item():.4f}")

    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
| |
|
def train_epoch_language_model(model, train_loader, optimizer, scheduler, scaler, args):
    """Train ``model`` for one epoch with a token-level cross-entropy objective.

    Args:
        model: Sequence model called as ``model(input_ids, input_ids)`` whose
            output's last dimension holds the class logits. Targets equal to
            ``model.embedding.padding_idx`` are ignored (``-100`` when unset).
        train_loader: Sized iterable of batches holding ``'input_ids'`` and
            ``'labels'`` tensors.
        optimizer: Optimizer whose parameter groups are clipped and stepped.
        scheduler: LR scheduler, stepped once per optimizer step.
        scaler: Gradient scaler for the backward pass and optimizer step.
        args: Parsed arguments; ``accumulation_steps`` and ``max_grad_norm`` are read.

    Returns:
        float: Average per-batch cross-entropy (before division by
        ``args.accumulation_steps``); ``0.0`` for an empty loader.
    """
    model.train()
    num_batches = len(train_loader)
    clip_params = [param for group in optimizer.param_groups for param in group['params']]
    # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast; it is
    # only enabled on CUDA, matching the CUDA-only behaviour of the old call.
    use_amp = device.type == 'cuda'
    # cross_entropy rejects ignore_index=None, so fall back to PyTorch's default
    # when the embedding was built without a padding index.
    pad_idx = getattr(model.embedding, 'padding_idx', None)
    ignore_index = -100 if pad_idx is None else pad_idx

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting Language Model training epoch with {num_batches} batches...")

    for i, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(input_ids, input_ids)
            # reshape (unlike view) also works on non-contiguous outputs.
            logits = outputs.reshape(-1, outputs.size(-1))
            batch_loss = F.cross_entropy(logits, labels.reshape(-1), ignore_index=ignore_index)
            # Divide so gradients accumulated over several batches are averaged.
            loss = batch_loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(clip_params, args.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        total_loss += batch_loss.item()
        print(f"Batch {i + 1} completed. Current loss: {batch_loss.item():.4f}")

    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Language Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
| |
|
| |
|
def train_custom_data_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    """Train the world-model components for one epoch on the custom dataset.

    The Transformer forward pass runs under ``torch.no_grad()`` (frozen feature
    extractor). Next-token targets are derived from ``input_ids`` shifted by one,
    and the per-episode metrics in the batch supply value / regression targets.

    Args:
        world_model_components: Tuple ``(representation_network, dynamics_network,
            prediction_network, action_encoder, ppo_agent, _)``; ``ppo_agent`` is
            only switched to train mode and the sixth element is ignored.
        train_loader: Sized iterable of batches holding ``'input_ids'``,
            ``'episode_reward'``, ``'cosine_similarity'``, ``'rag_performance'``
            and ``'ranking_model_performance'`` tensors.
        optimizer: Optimizer whose parameter groups are gradient-clipped and stepped.
        scheduler: LR scheduler, stepped once per optimizer step.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        args: Parsed arguments; ``accumulation_steps`` and ``max_grad_norm`` are read.
        model_transformer: Transformer called as ``model_transformer(src, tgt)``.
        state_dim: Last dimension used to flatten the state representation.
        embed_dim: Last dimension used to flatten the action embeddings.
        input_dim: Number of classes of the one-hot target policy.

    Returns:
        float: Average per-batch total loss (before division by
        ``args.accumulation_steps``); ``0.0`` for an empty loader.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, _ = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()

    num_batches = len(train_loader)
    clip_params = [param for group in optimizer.param_groups for param in group['params']]
    # Autocast only on CUDA (device_type='cuda' on a CPU host just warns and disables).
    use_amp = device.type == 'cuda'

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {num_batches} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{num_batches}...")

        # Move batches to the device.
        # NOTE(review): the batch's attention mask is not used, so padded
        # positions contribute to every loss term -- confirm this is acceptable.
        input_ids = batch['input_ids'].to(device)
        episode_reward = batch['episode_reward'].to(device)
        cosine_similarity = batch['cosine_similarity'].to(device)
        rag_performance = batch['rag_performance'].to(device)
        ranking_model_performance = batch['ranking_model_performance'].to(device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(input_ids, input_ids)

            # World Model - Representation
            state_representation = representation_network(transformer_output)
            print(f"State representation shape: {state_representation.shape}")

            # "True" actions are the next tokens: input_ids shifted left by one.
            true_actions = input_ids[:, 1:]
            print(f"True actions shape: {true_actions.shape}")

            # Get action embeddings
            action_embeddings = action_encoder(true_actions)
            print(f"Action embeddings shape: {action_embeddings.shape}")

            # Truncate states and actions to a common sequence length.
            min_seq_len = min(state_representation.size(1), action_embeddings.size(1))
            state_representation = state_representation[:, :min_seq_len, :]
            action_embeddings = action_embeddings[:, :min_seq_len, :]

            print(f"Adjusted state representation shape: {state_representation.shape}")
            print(f"Adjusted action embeddings shape: {action_embeddings.shape}")

            # Apply dynamics network
            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
            print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Keep the targets aligned with the truncated sequence length.
            true_actions = true_actions[:, :min_seq_len]

            # Every position shares its episode's reward as the value target.
            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = episode_reward.unsqueeze(1).expand(-1, min_seq_len)

            # Contrastive term between the representation and a dropout-perturbed copy.
            info_nce = InfoNCE_Loss()(
                state_representation.reshape(-1, state_dim),
                F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True)
            )

            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            # Consistency between the predicted state and a slightly noised copy.
            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            # Placeholder inputs: no MCTS values / visit counts exist during training.
            mcts_best_values = torch.zeros(true_actions.size(0), device=device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1), device=device)
            exploration = ExplorationRegularization()(visit_counts)

            # NOTE(review): old_policy is a detached copy of new_policy, so this KL
            # term compares a distribution with itself -- confirm this is intended.
            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Mean value estimate over the sequence -> one value per sample.
            value_estimates_mean = value_estimates.squeeze(-1).mean(dim=1)

            # Regress the mean value estimate onto the per-episode metrics.
            rag_loss = F.mse_loss(value_estimates_mean, rag_performance)
            ranking_loss = F.mse_loss(value_estimates_mean, ranking_model_performance)
            # NOTE(review): cosine_similarity is read from the batch, not produced
            # by the model, so this term carries no gradient.
            cosine_similarity_loss = 1 - cosine_similarity.mean()

            # Total loss before gradient-accumulation scaling (used for logging).
            batch_loss = (
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss +
                rag_loss +
                ranking_loss +
                cosine_similarity_loss
            )
            # Divide so gradients accumulated over several batches are averaged.
            loss = batch_loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            print("Gradient clipping...")
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(clip_params, args.max_grad_norm)

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        # Previously missing: without this the epoch average was always 0.
        total_loss += batch_loss.item()

        # Print individual losses and total loss for this batch
        print(f"Batch {i+1} completed. Losses:")
        print(f" InfoNCE Loss: {info_nce.item():.4f}")
        print(f" Covariance Loss: {covariance.item():.4f}")
        print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f" Policy-Value Loss: {pv_loss.item():.4f}")
        print(f" Action Diversity Loss: {action_diversity.item():.4f}")
        print(f" Expected Thought Value Loss: {etv.item():.4f}")
        print(f" Exploration Loss: {exploration.item():.4f}")
        print(f" KL Divergence Loss: {kl_loss.item():.4f}")
        print(f" RAG Loss: {rag_loss.item():.4f}")
        print(f" Ranking Loss: {ranking_loss.item():.4f}")
        print(f" Cosine Similarity Loss: {cosine_similarity_loss.item():.4f}")
        print(f" Total Loss: {batch_loss.item():.4f}")

    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
| |
|
| |
|
def main():
    """Build the Transformer and world-model networks, then train or run inference.

    ``args.mode`` selects ``'train'`` or ``'inference'``. In training,
    ``args.train_mode == 'world_model'`` trains the representation, dynamics,
    prediction and action-encoder networks; any other value trains the
    Transformer as a language model. Inference builds the Tree of Thought,
    optionally loads saved weights from ``args.load_model`` and answers queries
    until ``'exit'`` is entered or stdin is closed.
    """
    args = parse_args()
    print("Arguments parsed successfully.")

    # Create save directory
    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Save directory created: {args.save_dir}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully.")

    # Vocabulary size drives the Transformer and prediction-network output sizes.
    input_dim = len(tokenizer)

    # Initialize the Transformer model on the selected device
    print("Initializing Transformer model...")
    model_transformer = Transformer(
        input_dim=input_dim,
        d_model=128,
        num_heads=4,
        num_layers=4,
        d_ff=256,
        num_experts=2,
        output_dim=input_dim,
        dropout=0.1,
        top_k=2
    ).to(device)
    model_transformer.train()
    print("Transformer model initialized on device.")

    # Define model parameters (adjusted for speed)
    d_model = 32
    state_dim = 32
    action_dim = d_model
    hidden_dim = 64
    vocab_dim = input_dim
    embed_dim = d_model

    # Define World Model components
    representation_network = RepresentationNetwork(vocab_dim, d_model, state_dim).to(device)
    dynamics_network = DynamicsNetwork(state_dim, action_dim, hidden_dim).to(device)
    prediction_network = PredictionNetwork(state_dim, input_dim, 1).to(device)
    action_encoder = ActionEncoder(input_dim, action_dim).to(device)

    # Initialize PPO Agent (its policy network is the prediction network)
    ppo_agent = PPOAgent(
        policy_network=prediction_network,
        optimizer=optim.AdamW(prediction_network.parameters(), lr=args.learning_rate),
        clip_epsilon=0.2,
        entropy_coef=0.01,
        value_coef=0.5
    )

    # Bundle World Model components
    world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)

    print(f"Current mode: {args.mode}")
    if args.mode == 'train':
        print("Loading and preprocessing data...")
        if args.use_custom_data:
            custom_data = load_custom_data_from_files(args.custom_data_paths)
            processed_data = preprocess_custom_data(custom_data)
            train_loader, eval_loader = load_custom_data(args, tokenizer, processed_data)
            print("Custom data loaded and preprocessed successfully.")
        else:
            train_loader, eval_loader = load_data(args, tokenizer)
            print("Default data loaded and preprocessed successfully.")

        # Optimizer: world-model networks or the Transformer, depending on train_mode.
        if args.train_mode == 'world_model':
            trainable_params = (
                list(representation_network.parameters()) +
                list(dynamics_network.parameters()) +
                list(prediction_network.parameters()) +
                list(action_encoder.parameters())
            )
        else:
            trainable_params = model_transformer.parameters()
        # --weight_decay is honoured in both modes (AdamW's own default is 1e-2).
        optimizer = optim.AdamW(trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay)

        # The epoch functions call scheduler.step() once per optimizer step, so the
        # cosine period must cover every optimizer step of the run. T_max=num_epochs
        # made the learning rate oscillate every num_epochs optimizer steps.
        steps_per_epoch = math.ceil(len(train_loader) / args.accumulation_steps)
        scheduler = CosineAnnealingLR(optimizer, T_max=max(1, args.num_epochs * steps_per_epoch))
        # Loss scaling only applies on CUDA; disable it explicitly elsewhere.
        scaler = GradScaler(enabled=device.type == 'cuda')

        print(f"Starting {args.train_mode} training...")

        for epoch in range(args.num_epochs):
            if args.train_mode == 'world_model':
                epoch_fn = train_custom_data_epoch_world_model if args.use_custom_data else train_epoch_world_model
                avg_loss = epoch_fn(
                    world_model_components,
                    train_loader,
                    optimizer,
                    scheduler,
                    scaler,
                    args,
                    model_transformer,
                    state_dim,
                    embed_dim,
                    input_dim
                )
            else:
                avg_loss = train_epoch_language_model(
                    model_transformer,
                    train_loader,
                    optimizer,
                    scheduler,
                    scaler,
                    args
                )

            print(f"{args.train_mode.capitalize()} training epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

            # Save models
            if args.train_mode == 'world_model':
                save_all_models(model_transformer, representation_network, dynamics_network, prediction_network, action_encoder, args.save_dir, epoch + 1)
                print(f"Models saved for epoch {epoch + 1}")
            else:
                torch.save(model_transformer.state_dict(), os.path.join(args.save_dir, f'language_model_epoch_{epoch + 1}.pt'))
                print(f"Language model saved for epoch {epoch + 1}")

        print("Training completed.")

    elif args.mode == 'inference':
        print("Entering inference mode...")
        # Build Tree of Thought if needed
        print("Building Tree of Thought...")
        tree_root = build_tree_of_thought()
        print("Tree of Thought built successfully.")

        # Generate action list
        print("Generating action list...")
        action_list = []
        traverse_tree(tree_root, action_list)
        print(f"Action list generated. Total actions: {len(action_list)}")

        # Action <-> index mappings are published as module globals for other code.
        global action_to_index, index_to_action
        action_to_index = {action: idx for idx, action in enumerate(action_list)}
        index_to_action = {idx: action for action, idx in action_to_index.items()}
        action_vocab_size = len(action_list)
        print(f"Action mappings created. Vocabulary size: {action_vocab_size}")

        # Initialize or load models based on the load_model argument
        if args.load_model:
            print(f"Loading saved model from {args.load_model}")

            def load_state(filename):
                # map_location lets checkpoints saved on a GPU load on a CPU-only host.
                return torch.load(os.path.join(args.load_model, filename), map_location=device)

            model_transformer.load_state_dict(load_state('transformer_model.pt'))
            representation_network.load_state_dict(load_state('representation_network.pt'))
            dynamics_network.load_state_dict(load_state('dynamics_network.pt'))

            # Load prediction network and adjust its policy size if necessary
            saved_state_dict = load_state('prediction_network.pt')
            saved_vocab_size = saved_state_dict['policy_head.weight'].size(0)
            if saved_vocab_size != action_vocab_size:
                print(f"Adjusting prediction network size from {saved_vocab_size} to {action_vocab_size}")
                prediction_network = PredictionNetwork(state_dim, saved_vocab_size, 1).to(device)
                prediction_network.load_state_dict(saved_state_dict)
                # NOTE(review): the replacement policy head is freshly initialised,
                # so the loaded policy-head weights are discarded.
                prediction_network.policy_head = nn.Linear(prediction_network.state_dim, action_vocab_size).to(device)
            else:
                prediction_network = PredictionNetwork(state_dim, action_vocab_size, 1).to(device)
                prediction_network.load_state_dict(saved_state_dict)

            action_encoder.load_state_dict(load_state('action_encoder.pt'))
        else:
            print("Using newly initialized models")

        # Re-bundle: prediction_network may have been replaced above.
        world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)

        print("Starting inference loop...")
        while True:
            if args.query:
                query = args.query
                args.query = None  # Only use the CLI query for the first iteration
            else:
                try:
                    query = input("Please enter your query (or type 'exit' to quit): ")
                except EOFError:
                    # stdin closed (non-interactive run): stop instead of crashing.
                    break
            if query.strip().lower() == 'exit':
                break

            print(f"Processing query: {query}")
            result = infer(query, world_model_components, tree_root, tokenizer,
                           max_length=args.max_length,
                           inference_mode=args.inference_mode,
                           beam_size=args.beam_size,
                           n_tokens_predict=args.n_tokens_predict,
                           mcts_iterations=args.mcts_iterations,
                           exploration_constant=args.mcts_exploration_constant)

            if args.inference_mode == 'without_world_model':
                print("Generated Text:")
                print(result)
            else:
                print("Generated Thought Sequence:")
                for thought in result:
                    print(thought)

            print("\n")  # Add a newline for better readability between queries

        print("Inference completed.")

    else:
        print(f"Invalid mode: {args.mode}. Please choose 'train' or 'inference'.")
if __name__ == '__main__':
    # Built-in default configuration. It is applied when the script is started
    # without CLI arguments, or from a notebook kernel (ipykernel passes its own
    # flags such as "-f <connection-file>", which argparse would reject).
    # Arguments given on a real command line are no longer silently overwritten.
    if len(sys.argv) <= 1 or 'ipykernel' in sys.modules:
        sys.argv = [
            'lightbulb_2.py',
            '--mode', 'inference',
            '--train_mode', 'world_model',  # 'world_model' or 'language_model'
            '--dataset_name', 'wikitext',  # Hugging Face dataset (e.g. 'wikitext')
            '--dataset_config', 'wikitext-2-raw-v1',  # Dataset configuration, if needed
            '--num_epochs', '10',
            '--batch_size', '4',
            '--accumulation_steps', '1',
            '--max_grad_norm', '1.0',
            '--weight_decay', '0.01',
            '--learning_rate', '1e-4',
            '--max_length', '512',
            '--save_dir', './trained_models',
            # Uncomment the following line to use custom data instead of a Hugging Face dataset
            # '--use_custom_data',
            '--custom_data_paths', '/content/drive/MyDrive/lightbulb/knowledge_base.json',
            '--custom_data_paths', '/content/drive/MyDrive/lightbulb/rag_cache.json',
            '--custom_data_paths', '/content/drive/MyDrive/lightbulb/llm_training_data/llm_training_data.jsonl'
        ]

    # main() parses the arguments again; this module-level `args` is also read as
    # a global by code that does not receive it as a parameter (e.g. the
    # inference code's use of args.temperature).
    args = parse_args()

    # Report the data source only when it will actually be used for training.
    if args.mode == 'train':
        if args.use_custom_data:
            print("Training with custom data from paths:")
            for path in args.custom_data_paths:
                print(f" - {path}")
        else:
            print(f"Training with dataset '{args.dataset_name}' from Hugging Face Datasets")

    main()
| |
|
| |
|