# config.py (Updated to Disable PEFT)
import torch
# --- Model Configuration ---
# Base model from Hugging Face (ensure it's suitable for QA)
# Examples: 'bert-base-uncased', 'roberta-base', 'bert-large-uncased-whole-word-masking-finetuned-squad'
BASE_MODEL_NAME = "bert-base-uncased"
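# Illustrative sketch (assumption, not part of this config): the base model
# is presumably loaded elsewhere with the standard `transformers` API, e.g.:
#   from transformers import AutoModelForQuestionAnswering, AutoTokenizer
#   model = AutoModelForQuestionAnswering.from_pretrained(BASE_MODEL_NAME)
#   tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)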
# --- RRN Specific Configuration ---
# Coherence loss weight
LAMBDA_COHERENCE = 0.1 # Hyperparameter to tune
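# Illustrative sketch (assumed usage; `qa_loss` and `coherence_loss` are
# hypothetical names computed in the training loop):
#   total_loss = qa_loss + LAMBDA_COHERENCE * coherence_loss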
# --- Delta Constraint Configuration ---
DELTA_TARGET_RATIO = 0.2 # Target ratio of delta norm to h0 norm
LAMBDA_DELTA_REG = 0.5 # Weight for delta regularization loss
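# Illustrative sketch (assumed form of the regularizer; `delta` and `h0`
# are hypothetical tensors from the model's forward pass):
#   ratio = delta.norm() / h0.norm()
#   delta_reg_loss = LAMBDA_DELTA_REG * (ratio - DELTA_TARGET_RATIO) ** 2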
# --- Multi-step Reasoning Configuration ---
NUM_REASONING_STEPS = 3 # Default number of reasoning steps (used when dynamic steps are disabled)
# --- Dynamic Reasoning Steps Configuration ---
USE_DYNAMIC_STEPS = True # Enable/disable dynamic reasoning steps
MAX_REASONING_STEPS = 5 # Maximum number of reasoning steps
MIN_REASONING_STEPS = 1 # Minimum number of reasoning steps
REASONING_STEP_TYPE = "learned" # Options: "fixed", "confidence", "learned"
EARLY_STOP_THRESHOLD = 0.01 # Delta magnitude threshold for early stopping (used with "confidence")
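# Illustrative sketch of the "confidence" stopping rule (assumed semantics;
# `compute_delta` and `h` are hypothetical):
#   for step in range(MAX_REASONING_STEPS):
#       delta = compute_delta(h)
#       h = h + delta
#       if step + 1 >= MIN_REASONING_STEPS and \
#          delta.norm() / h.norm() < EARLY_STOP_THRESHOLD:
#           break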
# --- Mixed Precision Configuration ---
USE_MIXED_PRECISION = False # Enable/disable mixed precision training
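# Illustrative sketch (standard torch AMP usage; assumed to live in the
# training loop, not here):
#   scaler = torch.cuda.amp.GradScaler(enabled=USE_MIXED_PRECISION)
#   with torch.autocast(device_type=DEVICE, enabled=USE_MIXED_PRECISION):
#       loss = model(**batch).loss
#   scaler.scale(loss).backward()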
# --- Memory Configuration ---
MEMORY_MAX_SIZE = 50 # Max number of entries in the memory
MEMORY_USE_DURING_TRAINING = False # Whether to use memory during training
MEMORY_RETRIEVAL_K = 3 # Number of examples to retrieve from memory
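# Illustrative sketch (hypothetical `memory` object; the retrieval API is
# an assumption):
#   neighbors = memory.retrieve(query_embedding, k=MEMORY_RETRIEVAL_K)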
# --- PEFT (LoRA) Configuration ---
USE_PEFT = False # <--- SET TO False TO DISABLE PEFT ---
# The LoRA-specific settings below are unused while USE_PEFT is False;
# uncomment them if PEFT is re-enabled.
# LORA_R = 8
# LORA_ALPHA = 16
# LORA_DROPOUT = 0.1
# LORA_TARGET_MODULES = ["query", "value"]
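# Illustrative sketch of how the settings above would feed the `peft`
# library if USE_PEFT were True (standard LoraConfig fields):
#   from peft import LoraConfig, get_peft_model
#   lora_cfg = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA,
#                         lora_dropout=LORA_DROPOUT,
#                         target_modules=LORA_TARGET_MODULES)
#   model = get_peft_model(model, lora_cfg)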
# --- Testing Configuration ---
BYPASS_DELTA_CALCULATION = False # Set to True to bypass delta calculation for testing
# --- Training Configuration ---
# <<< --- Device Detection (CUDA prioritized over MPS) --- >>>
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("CUDA GPU acceleration is available.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Apple Silicon MPS acceleration is available.")
else:
    DEVICE = "cpu"
    print("No GPU or MPS acceleration available, using CPU.")
# <<< --- End of Device Detection --- >>>
LEARNING_RATE = 1e-5 # Full fine-tuning often uses a smaller LR than PEFT
EPOCHS = 3
# --- Adjust Batch Size for Full Fine-tuning ---
# Full fine-tuning requires significantly more memory than PEFT
BATCH_SIZE = 4 # Start small; adjust based on available GPU memory
GRADIENT_ACCUMULATION_STEPS = 8 # Increase to compensate for the smaller batch size
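# Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS (4 * 8 = 32).
# Illustrative accumulation pattern (sketch; assumed to live in the
# training loop):
#   loss = model(**batch).loss / GRADIENT_ACCUMULATION_STEPS
#   loss.backward()
#   if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
#       optimizer.step()
#       optimizer.zero_grad()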
# --- Dataset Configuration ---
# Example settings for SQuAD
MAX_SEQ_LENGTH = 320 # Max input sequence length for QA
DOC_STRIDE = 128 # Stride between overlapping chunks of long documents
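# Illustrative sketch (standard Hugging Face QA preprocessing; `tokenizer`,
# `questions`, and `contexts` are hypothetical):
#   encoded = tokenizer(questions, contexts,
#                       max_length=MAX_SEQ_LENGTH,
#                       stride=DOC_STRIDE,
#                       truncation="only_second",
#                       return_overflowing_tokens=True)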
| print(f"Using device: {DEVICE}") | |
| print(f"Base model: {BASE_MODEL_NAME}") | |
| # Update print statement to reflect PEFT status | |
| print(f"Using PEFT (LoRA): {USE_PEFT} - Full Fine-tuning Enabled") | |