# train.py (Updated for Full Fine-tuning)
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler  # For mixed precision training (updated import)
from transformers import AutoTokenizer, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm # Progress bar
import os
import evaluate # For metrics (used by the optional evaluation loop if enabled)
import logging # Optional: Better logging
import multiprocessing # For Windows multiprocessing support
import argparse # For command line arguments
import json # For writing the enhanced model config to disk

# Import our custom modules and config
import config
from model import EnhancedRRN_QA_Model

# Setup basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Train RRN QA Model")
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
    parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
    parser.add_argument(
        "--subset_percentage", 
        type=float, 
        default=100.0,
        help="Percentage of training data to use (1.0-100.0). Default: 100.0 (full dataset)"
    )
    parser.add_argument(
        "--bypass_delta", 
        action="store_true",
        help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))"
    )
    args = parser.parse_args()
    
    # Set bypass delta calculation flag if specified
    if args.bypass_delta:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")
        config.BYPASS_DELTA_CALCULATION = True
    else:
        config.BYPASS_DELTA_CALCULATION = False
    
    # --- 1. Load Tokenizer and Model ---
    if args.checkpoint:
        logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
        
        logger.info(f"Loading model from checkpoint: {args.checkpoint}")
        # Initialize the model with base architecture
        model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))
        
        # Check for enhanced model components
        gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
        is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
        
        # Load custom module weights
        logger.info("Loading model components...")
        model.qa_head.load_state_dict(torch.load(os.path.join(args.checkpoint, "qa_head.pth")))
        model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth")))
        
        # Load gating mechanism if available
        if is_enhanced_checkpoint:
            logger.info("Loading gating mechanism...")
            model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location=config.DEVICE))
            
        # Load step controller if available (for learned dynamic steps)
        step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(torch.load(step_controller_path, map_location=config.DEVICE))
    else:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)

        logger.info("Instantiating Enhanced RRN QA Model for Full Fine-tuning...")
        model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
    
    model.to(config.DEVICE)

    # --- 2. Load and Preprocess Dataset ---
    logger.info("Loading SQuAD dataset...")
    raw_datasets = load_dataset("squad")
    
    # Handle dataset subsetting
    subset_percentage = args.subset_percentage
    if subset_percentage < 100.0:
        original_train_size = len(raw_datasets["train"])
        
        # Calculate subset size and validate
        subset_percentage = max(0.1, min(100.0, subset_percentage))  # Clamp between 0.1% and 100%
        train_subset_size = int(original_train_size * subset_percentage / 100)
        train_subset_size = max(100, min(original_train_size, train_subset_size))  # Ensure reasonable bounds
        
        # Create reproducible subset with fixed seed for consistency
        subset_indices = torch.randperm(original_train_size, generator=torch.Generator().manual_seed(42))[:train_subset_size].tolist()
        raw_datasets["train"] = raw_datasets["train"].select(subset_indices)
        
        logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
    else:
        logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")

    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    pad_on_right = tokenizer.padding_side == "right"

    def prepare_train_features(examples):
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_examples.pop("offset_mapping")
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            sequence_ids = tokenized_examples.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]

            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
        return tokenized_examples

    logger.info("Preprocessing datasets...")
    # Use single process on Windows to avoid multiprocessing issues
    tokenized_datasets = raw_datasets.map(
        prepare_train_features,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=1 # Use single process to avoid Windows multiprocessing issues
    )

    data_collator = default_data_collator
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=config.BATCH_SIZE
    )
    # Consider adding validation dataloader setup here as well
    # eval_dataloader = DataLoader(...)
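    # A minimal sketch of such a validation dataloader, assuming the "validation"
    # split was tokenized by prepare_train_features above. This is sufficient for
    # tracking eval loss; computing SQuAD EM/F1 would additionally need a
    # preprocessing pass that keeps offset mappings and example ids for answer
    # postprocessing.
    # eval_dataloader = DataLoader(
    #     tokenized_datasets["validation"],
    #     shuffle=False,
    #     collate_fn=data_collator,
    #     batch_size=config.BATCH_SIZE,
    # )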

    # --- 3. Setup Optimizer ---
    logger.info("Setting up optimizer for FULL model fine-tuning...")
    # Optimize all parameters since PEFT is disabled
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)

    logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")
    # Calculate total steps considering gradient accumulation
    num_update_steps_per_epoch = len(train_dataloader) // config.GRADIENT_ACCUMULATION_STEPS
    num_training_steps = config.EPOCHS * num_update_steps_per_epoch
    logger.info(f"Total optimization steps: {num_training_steps}")


    # --- 4. Initialize Mixed Precision Training ---
    # Initialize gradient scaler for mixed precision training
    scaler = GradScaler('cuda', enabled=config.USE_MIXED_PRECISION)  # Updated to fix deprecation warning
    
    # Log mixed precision and dynamic steps status
    if config.USE_MIXED_PRECISION:
        logger.info("Mixed precision training (FP16) enabled")
    if config.USE_DYNAMIC_STEPS:
        logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
        logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")
    
    # Log bypass delta calculation status
    if config.BYPASS_DELTA_CALCULATION:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")

    # --- 5. Training Loop ---
    logger.info("***** Starting Training *****")
    logger.info(f"  Num examples = {len(tokenized_datasets['train'])}")
    logger.info(f"  Num Epochs = {config.EPOCHS}")
    logger.info(f"  Instantaneous batch size per device = {config.BATCH_SIZE}")
    logger.info(f"  Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
    logger.info(f"  Total optimization steps = {num_training_steps}")
    
    # Add note about subset training if applicable
    if subset_percentage < 100.0:
        logger.info(f"  NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")


    model.train() # Set model to training mode
    global_step = 0
    total_loss = 0.0 # Use float for accumulated loss

    # Start from specified epoch (default is 0 if not provided)
    start_epoch = args.start_epoch
    
    for epoch in range(start_epoch, config.EPOCHS):
        logger.info(f"\n--- Starting Epoch {epoch+1}/{config.EPOCHS} ---")
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", unit="batch")

        for step, batch in enumerate(progress_bar):
            # Move batch to device
            # Ensure only tensors are moved, handle potential non-tensor data if any
            batch_on_device = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch_on_device[k] = v.to(config.DEVICE)
                # else: # Handle or skip non-tensor items if necessary
                #     batch_on_device[k] = v

            try:
                # Forward pass with autocast for mixed precision
                with autocast('cuda', enabled=config.USE_MIXED_PRECISION):  # Updated to fix deprecation warning
                    outputs = model(
                        input_ids=batch_on_device.get("input_ids"),
                        attention_mask=batch_on_device.get("attention_mask"),
                        token_type_ids=batch_on_device.get("token_type_ids"),
                        start_positions=batch_on_device.get("start_positions"),
                        end_positions=batch_on_device.get("end_positions"),
                        use_memory=False # Disable memory during training steps
                    )
                    loss = outputs.loss

                    if loss is None:
                        logger.warning(f"Step {step}: Loss is None. Skipping batch.")
                        continue

                    # Scale loss for gradient accumulation
                    loss = loss / config.GRADIENT_ACCUMULATION_STEPS
                
                # Accumulate loss value for logging (before backward)
                total_loss += loss.item()
                
                # Scale loss and perform backward pass with AMP
                scaler.scale(loss).backward()

            except Exception as e:
                logger.error(f"Error during forward/backward pass at step {step}: {e}")
                # Optional: Add more detailed error handling or debugging info
                # logger.error(f"Batch keys: {batch.keys()}")
                # logger.error(f"Input IDs shape: {batch_on_device.get('input_ids').shape if batch_on_device.get('input_ids') is not None else 'None'}")
                raise e # Re-raise the exception to stop training

            # Optimizer step (perform step only after accumulating gradients)
            if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
                # Unscale before optimizer step (to check for infs/NaNs)
                scaler.unscale_(optimizer)
                
                # Clip gradients to avoid explosion (optional but recommended with mixed precision)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                # Step with scaler
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad() # Reset gradients for the next accumulation cycle
                global_step += 1

                # Log progress periodically
                if global_step % 50 == 0: # Log every 50 optimization steps
                    avg_loss = total_loss / 50 # Average loss over the last 50 steps
                    logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
                    total_loss = 0.0 # Reset loss accumulator

                # Update progress bar description with current step loss and steps info
                postfix = {
                    "Loss": f"{loss.item()*config.GRADIENT_ACCUMULATION_STEPS:.4f}", 
                    "Step": global_step
                }
                
                # Add steps info if using dynamic steps
                if config.USE_DYNAMIC_STEPS and hasattr(model, 'custom_outputs'):
                    if 'steps_taken' in model.custom_outputs:
                        postfix["Steps"] = model.custom_outputs['steps_taken']
                
                progress_bar.set_postfix(postfix)


        # --- (Optional) Evaluation at the end of each epoch ---
        # logger.info(f"\n--- Evaluating after Epoch {epoch+1} ---")
        # model.eval()
        # # Add evaluation loop here (requires validation dataloader, postprocessing, metrics)
        # model.train() # Set back to train mode
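        # A minimal sketch of an end-of-epoch eval-loss pass, assuming the hypothetical
        # eval_dataloader sketched above; full SQuAD EM/F1 scoring (via the `evaluate`
        # library) would also require answer-span postprocessing, which is omitted here.
        # model.eval()
        # eval_loss, eval_batches = 0.0, 0
        # with torch.no_grad():
        #     for eval_batch in eval_dataloader:
        #         eval_batch = {k: v.to(config.DEVICE) for k, v in eval_batch.items()
        #                       if isinstance(v, torch.Tensor)}
        #         with autocast('cuda', enabled=config.USE_MIXED_PRECISION):
        #             eval_outputs = model(**eval_batch, use_memory=False)
        #         eval_loss += eval_outputs.loss.item()
        #         eval_batches += 1
        # logger.info(f"Epoch {epoch+1} eval loss: {eval_loss / max(eval_batches, 1):.4f}")
        # model.train()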

        # --- Save Model Checkpoint ---
        output_dir = f"./rrn_qa_model_epoch_{epoch+1}"
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"--- Saving model checkpoint to {output_dir} ---")

        # --- Saving Logic for Enhanced Model ---
        try:
            logger.info(f"Saving enhanced model components to {output_dir}")
            # Save base model using its save_pretrained
            model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))
            
            # Save all custom modules' state dicts
            torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
            torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
            torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))
            
            # Save step controller if using learned dynamic steps
            if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
                torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
                logger.info("Saved step controller for learned dynamic steps")
            
            # Save tokenizer
            tokenizer.save_pretrained(output_dir)
            
            # Save configuration
            with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
                import json
                config_dict = {
                    "num_reasoning_steps": config.NUM_REASONING_STEPS,
                    "delta_target_ratio": config.DELTA_TARGET_RATIO,
                    "lambda_coherence": config.LAMBDA_COHERENCE,
                    "lambda_delta_reg": config.LAMBDA_DELTA_REG,
                    "memory_max_size": config.MEMORY_MAX_SIZE,
                    "memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
                    "use_mixed_precision": config.USE_MIXED_PRECISION,
                    "bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION
                }
                
                # Add dynamic steps configuration if enabled
                if config.USE_DYNAMIC_STEPS:
                    config_dict.update({
                        "use_dynamic_steps": config.USE_DYNAMIC_STEPS,
                        "max_reasoning_steps": config.MAX_REASONING_STEPS,
                        "min_reasoning_steps": config.MIN_REASONING_STEPS,
                        "reasoning_step_type": config.REASONING_STEP_TYPE,
                        "early_stop_threshold": config.EARLY_STOP_THRESHOLD
                    })
                
                json.dump(config_dict, f, indent=2)
            
            logger.info("Enhanced model checkpoint saved successfully.")
        except Exception as e:
            logger.error(f"Error saving checkpoint at epoch {epoch+1}: {e}")


    logger.info("\n***** Training finished *****")

if __name__ == "__main__":
    # This is required for Windows to properly handle multiprocessing
    multiprocessing.freeze_support()
    main()

# Example usage:
# Train on full dataset (default):
# python train.py

# Train on 10% of data for faster iterations:
# python train.py --subset_percentage 10.0

# Train on 1% for very quick testing:
# python train.py --subset_percentage 1.0

# Resume training from checkpoint with subset:
# python train.py --checkpoint ./rrn_qa_model_epoch_1 --start_epoch 1 --subset_percentage 25.0

# Test with bypassed delta calculation (sets delta = torch.zeros_like(h0)):
# python train.py --bypass_delta --subset_percentage 1.0