# config.py (Updated to Disable PEFT)

import torch

# --- Model Configuration ---
# Base model from Hugging Face (ensure it's suitable for QA)
# Examples: 'bert-base-uncased', 'roberta-base',
#           'bert-large-uncased-whole-word-masking-finetuned-squad'
BASE_MODEL_NAME = "bert-base-uncased"

# --- RRN Specific Configuration ---
# Coherence loss weight
LAMBDA_COHERENCE = 0.1  # Hyperparameter to tune

# --- Delta Constraint Configuration ---
DELTA_TARGET_RATIO = 0.2  # Target ratio of delta norm to h0 norm
LAMBDA_DELTA_REG = 0.5    # Weight for the delta regularization loss

# --- Multi-step Reasoning Configuration ---
NUM_REASONING_STEPS = 3  # Default number of reasoning steps (used when dynamic steps are disabled)

# --- Dynamic Reasoning Steps Configuration ---
USE_DYNAMIC_STEPS = True         # Enable/disable dynamic reasoning steps
MAX_REASONING_STEPS = 5          # Maximum number of reasoning steps
MIN_REASONING_STEPS = 1          # Minimum number of reasoning steps
REASONING_STEP_TYPE = "learned"  # Options: "fixed", "confidence", "learned"
EARLY_STOP_THRESHOLD = 0.01      # Delta-magnitude threshold for early stopping (used with "confidence")

# --- Mixed Precision Configuration ---
USE_MIXED_PRECISION = False  # Enable/disable mixed precision training

# --- Memory Configuration ---
MEMORY_MAX_SIZE = 50                # Max number of entries in the memory
MEMORY_USE_DURING_TRAINING = False  # Whether to use memory during training
MEMORY_RETRIEVAL_K = 3              # Number of examples to retrieve from memory

# --- PEFT (LoRA) Configuration ---
USE_PEFT = False  # <--- SET TO False TO DISABLE PEFT --->
# --- Optional: comment out or leave the LoRA-specific settings (unused while USE_PEFT is False) ---
# LORA_R = 8
# LORA_ALPHA = 16
# LORA_DROPOUT = 0.1
# LORA_TARGET_MODULES = ["query", "value"]

# --- Testing Configuration ---
BYPASS_DELTA_CALCULATION = False  # Set to True to bypass delta calculation for testing

# --- Training Configuration ---
# <<< --- Device Detection (CUDA prioritized over MPS) --- >>>
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("CUDA GPU acceleration is available.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Apple Silicon MPS acceleration is available.")
else:
    DEVICE = "cpu"
    print("No GPU or MPS acceleration available; using CPU.")
# <<< --- End of Device Detection --- >>>

LEARNING_RATE = 1e-5  # Full fine-tuning often uses a smaller LR than PEFT
EPOCHS = 3

# --- Adjust Batch Size for Full Fine-tuning ---
# Full fine-tuning requires significantly more memory than PEFT
BATCH_SIZE = 4                   # Start small; adjust based on your available GPU memory
GRADIENT_ACCUMULATION_STEPS = 8  # Increase to compensate for the smaller batch size

# --- Dataset Configuration ---
# Example values for SQuAD
MAX_SEQ_LENGTH = 320  # Max input sequence length for QA
DOC_STRIDE = 128      # Stride for overlapping chunks of long documents

print(f"Using device: {DEVICE}")
print(f"Base model: {BASE_MODEL_NAME}")
# Report the PEFT status (full fine-tuning when PEFT is disabled)
if USE_PEFT:
    print("Using PEFT (LoRA): True")
else:
    print("Using PEFT (LoRA): False - full fine-tuning enabled")
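
# --- Illustrative derived value (added for clarity; not part of the original config) ---
# With gradient accumulation, the optimizer steps once every
# GRADIENT_ACCUMULATION_STEPS batches, so the effective batch size per
# optimizer step is BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS.
# EFFECTIVE_BATCH_SIZE is a hypothetical constant for documentation; the
# training code is not assumed to read it.
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS  # 4 * 8 = 32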
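
# --- Usage sketch: mixed precision + gradient accumulation (assumption, illustration only) ---
# A minimal sketch of how a training loop elsewhere in the repo might consume
# USE_MIXED_PRECISION and GRADIENT_ACCUMULATION_STEPS. The names `model`,
# `optimizer`, and `dataloader` are hypothetical and not defined in this
# project's config. Note that GradScaler only applies loss scaling on CUDA;
# with enabled=False it passes values through unchanged.
#
#   import torch
#   import config
#
#   scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION)
#   for step, batch in enumerate(dataloader):
#       with torch.autocast(device_type=config.DEVICE,
#                           enabled=config.USE_MIXED_PRECISION):
#           # Scale the loss down so accumulated gradients average correctly
#           loss = model(**batch).loss / config.GRADIENT_ACCUMULATION_STEPS
#       scaler.scale(loss).backward()
#       if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
#           scaler.step(optimizer)
#           scaler.update()
#           optimizer.zero_grad()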
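
# --- Usage sketch: "confidence" early stopping (assumption, illustration only) ---
# One plausible reading of EARLY_STOP_THRESHOLD in "confidence" mode: stop
# iterating once the reasoning update's magnitude drops below the threshold,
# while respecting MIN_REASONING_STEPS and MAX_REASONING_STEPS. The names
# `reasoning_cell` and `h` are hypothetical; the actual RRN update rule may
# differ.
#
#   for step in range(config.MAX_REASONING_STEPS):
#       delta = reasoning_cell(h)  # hypothetical per-step update to the hidden state
#       h = h + delta
#       if (step + 1 >= config.MIN_REASONING_STEPS
#               and delta.norm() < config.EARLY_STOP_THRESHOLD):
#           break  # update is small enough that further steps add little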