"""

Train a new Wildnerve model with parameters loaded from config.json.

"""
import os
import sys
import torch
import logging
import argparse
from typing import Optional

# Import configuration
from config import app_config, get_model_architecture_params

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def train_model(
    specialization: str,
    dataset_path: str,
    output_dir: str,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    learning_rate: Optional[float] = None,
    device: Optional[str] = None
):
    """Train a model with parameters from config.json"""
    # Get model architecture parameters from config.json
    arch_params = get_model_architecture_params()
    logger.info(f"Loaded architecture parameters from config: {arch_params}")
    
    # Get training parameters from config.json
    if hasattr(app_config, "TRAINING_CONFIG"):
        training_config = app_config.TRAINING_CONFIG
        num_epochs = num_epochs or getattr(training_config, "NUM_EPOCHS", 10)
        learning_rate = learning_rate or getattr(training_config, "LEARNING_RATE", 1e-4)
    elif hasattr(app_config, "TRANSFORMER_CONFIG"):
        transformer_config = app_config.TRANSFORMER_CONFIG
        num_epochs = num_epochs or getattr(transformer_config, "NUM_EPOCHS", 10)
        learning_rate = learning_rate or getattr(transformer_config, "LEARNING_RATE", 1e-4)
    
    # Get data loader parameters from config.json
    if hasattr(app_config, "DATA_LOADER_CONFIG"):
        data_loader_config = app_config.DATA_LOADER_CONFIG
        batch_size = batch_size or getattr(data_loader_config, "BATCH_SIZE", 32)
    
    # Fall back to hard-coded defaults if neither the CLI arguments nor config.json provided values
    num_epochs = num_epochs or 10
    batch_size = batch_size or 32
    learning_rate = learning_rate or 1e-4
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Resolve the device: accept a string such as "cuda" or "cpu", or auto-detect when None
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    logger.info(f"Using device: {device}")
    
    try:
        # Import necessary modules
        from model_Custm import Wildnerve_tlm01
        from transformers import AutoTokenizer
        from torch.utils.data import DataLoader, Dataset
        import json
        
        # Get model name from config
        model_name = getattr(app_config.TRANSFORMER_CONFIG, "MODEL_NAME", "gpt2") if hasattr(app_config, "TRANSFORMER_CONFIG") else "gpt2"
        
        # Initialize the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
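        # Causal LM tokenizers such as GPT-2's ship without a pad token, so fall back to EOS for padding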
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load dataset
        logger.info(f"Loading dataset from {dataset_path}")
        with open(dataset_path, 'r') as f:
            data = json.load(f)
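        # The dataset file is expected to be a JSON array of objects that each carry a
        # "text" field (see the list comprehension below); any other fields are ignored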
        
        # Create a simple dataset class
        class TextDataset(Dataset):
            def __init__(self, texts, tokenizer, max_length):
                self.encodings = tokenizer(texts, truncation=True, padding="max_length", 
                                         max_length=max_length, return_tensors="pt")
                
            def __getitem__(self, idx):
                item = {key: val[idx] for key, val in self.encodings.items()}
                item["labels"] = item["input_ids"].clone()
                return item
            
            def __len__(self):
                return len(self.encodings["input_ids"])
        
        # Extract the text field from each dataset record
        texts = [item["text"] for item in data]
        
        # Create dataset and dataloader
        train_dataset = TextDataset(texts, tokenizer, arch_params["max_seq_length"])
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        
        # Log key parameters
        logger.info(f"Training with parameters:")
        logger.info(f"- specialization: {specialization}")
        logger.info(f"- model_name: {model_name}")
        logger.info(f"- embedding_dim: {arch_params['embedding_dim']}")
        logger.info(f"- hidden_dim: {arch_params['hidden_dim']}")
        logger.info(f"- num_heads: {arch_params['num_heads']}")
        logger.info(f"- num_layers: {arch_params['num_layers']}")
        logger.info(f"- vocab_size: {arch_params['vocab_size']}")
        logger.info(f"- num_epochs: {num_epochs}")
        logger.info(f"- batch_size: {batch_size}")
        logger.info(f"- learning_rate: {learning_rate}")
        
        # Initialize the model with architecture parameters from config
        model = Wildnerve_tlm01(
            vocab_size=arch_params["vocab_size"],
            specialization=specialization,
            dataset_path=dataset_path,
            model_name=model_name,
            embedding_dim=arch_params["embedding_dim"],
            num_heads=arch_params["num_heads"],
            hidden_dim=arch_params["hidden_dim"],
            num_layers=arch_params["num_layers"],
            output_size=arch_params["vocab_size"],
            dropout=arch_params.get("dropout", 0.1),
            max_seq_length=arch_params["max_seq_length"],
            tokenizer=tokenizer
        )
        
        # Move model to the device
        model.to(device)
        
        # Set up optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        
        # Training loop
        logger.info(f"Starting training for {num_epochs} epochs")
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            
            for batch_idx, batch in enumerate(train_dataloader):
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                
                # Forward pass
                outputs = model(batch["input_ids"], 
                              attention_mask=batch.get("attention_mask"))
                
                # Calculate loss
                loss = torch.nn.functional.cross_entropy(
                    outputs.view(-1, outputs.size(-1)), 
                    batch["labels"].view(-1)
                )
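                # Note: labels are the unshifted input_ids, so every position (including padding)
                # contributes to this loss; pass ignore_index=tokenizer.pad_token_id to
                # cross_entropy if padded positions should be excluded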
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Track loss
                total_loss += loss.item()
                
                if (batch_idx + 1) % 10 == 0:
                    logger.info(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, "
                              f"Loss: {loss.item():.4f}")
            
            avg_loss = total_loss / len(train_dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Average loss: {avg_loss:.4f}")
            
            # Save checkpoint
            checkpoint_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.bin")
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "loss": avg_loss,
                "config": {
                    "embedding_dim": arch_params["embedding_dim"],
                    "hidden_dim": arch_params["hidden_dim"],
                    "num_heads": arch_params["num_heads"],
                    "num_layers": arch_params["num_layers"],
                    "vocab_size": arch_params["vocab_size"]
                }
            }, checkpoint_path)
            logger.info(f"Saved checkpoint to {checkpoint_path}")
        
        # Save final model
        final_model_path = os.path.join(output_dir, f"{specialization}_final_model.bin")
        torch.save({
            "model_state_dict": model.state_dict(),
            "config": {
                "embedding_dim": arch_params["embedding_dim"],
                "hidden_dim": arch_params["hidden_dim"],
                "num_heads": arch_params["num_heads"],
                "num_layers": arch_params["num_layers"],
                "vocab_size": arch_params["vocab_size"]
            }
        }, final_model_path)
        logger.info(f"Training completed. Final model saved to {final_model_path}")
        
        return final_model_path
    
    except Exception as e:
        logger.error(f"Error during training: {e}", exc_info=True)
        return None

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Wildnerve model")
    parser.add_argument("--specialization", type=str, default="general", help="Model specialization")
    parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset file")
    parser.add_argument("--output", type=str, default="./checkpoints", help="Output directory")
    parser.add_argument("--epochs", type=int, help="Number of training epochs (overrides config)")
    parser.add_argument("--batch-size", type=int, help="Batch size (overrides config)")
    parser.add_argument("--learning-rate", type=float, help="Learning rate (overrides config)")
    parser.add_argument("--device", type=str, help="Device to use (cuda or cpu)")
    
    args = parser.parse_args()
    
    final_model_path = train_model(
        specialization=args.specialization,
        dataset_path=args.dataset,
        output_dir=args.output,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        device=args.device
    )

    # train_model returns None on failure; propagate that as a non-zero exit code
    if final_model_path is None:
        sys.exit(1)