# config.py (Updated to Disable PEFT)
import torch
# --- Model Configuration ---
# Base model from Hugging Face (ensure it's suitable for QA)
# Example: 'bert-base-uncased', 'roberta-base', 'bert-large-uncased-whole-word-masking-finetuned-squad'
BASE_MODEL_NAME = "bert-base-uncased"
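# This name is consumed elsewhere roughly as follows (a sketch of standard
# Hugging Face usage; the actual loading code lives outside this file):
#
#   from transformers import AutoModelForQuestionAnswering
#   model = AutoModelForQuestionAnswering.from_pretrained(BASE_MODEL_NAME)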
# --- RRN Specific Configuration ---
# Coherence loss weight
LAMBDA_COHERENCE = 0.1 # Hyperparameter to tune
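# Sketch of how this weight would typically enter the training objective
# (illustrative only -- `qa_loss` and `coherence_loss` are assumed names
# computed in the training loop, not defined in this file):
#
#   total_loss = qa_loss + LAMBDA_COHERENCE * coherence_loss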
# --- Delta Constraint Configuration ---
DELTA_TARGET_RATIO = 0.2 # Target ratio of delta norm to h0 norm
LAMBDA_DELTA_REG = 0.5 # Weight for delta regularization loss
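# One plausible form of the delta regularizer these settings parameterize
# (a sketch; the actual loss lives in the training code, and `delta`/`h0`
# are assumed tensor names):
#
#   ratio = delta.norm(dim=-1) / (h0.norm(dim=-1) + 1e-8)
#   delta_reg_loss = LAMBDA_DELTA_REG * ((ratio - DELTA_TARGET_RATIO) ** 2).mean()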
# --- Multi-step Reasoning Configuration ---
NUM_REASONING_STEPS = 3 # Default number of reasoning steps (used when dynamic steps disabled)
# --- Dynamic Reasoning Steps Configuration ---
USE_DYNAMIC_STEPS = True # Enable/disable dynamic reasoning steps
MAX_REASONING_STEPS = 5 # Maximum number of reasoning steps
MIN_REASONING_STEPS = 1 # Minimum number of reasoning steps
REASONING_STEP_TYPE = "learned" # Options: "fixed", "confidence", "learned"
EARLY_STOP_THRESHOLD = 0.01 # Delta magnitude threshold for early stopping (used with "confidence")
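# Sketch of a "confidence" loop using these settings (illustrative;
# `reasoning_cell` is an assumed name for the model's refinement module):
#
#   h = h0
#   for step in range(MAX_REASONING_STEPS):
#       delta = reasoning_cell(h)
#       h = h + delta
#       if (step + 1 >= MIN_REASONING_STEPS
#               and delta.norm(dim=-1).mean() < EARLY_STOP_THRESHOLD):
#           break  # update is small enough; stop reasoning early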
# --- Mixed Precision Configuration ---
USE_MIXED_PRECISION = False # Enable/disable mixed precision training
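# If enabled, the training loop would typically wrap forward passes in
# autocast and scale gradients (standard torch.cuda.amp usage on CUDA;
# a sketch, not the project's actual loop):
#
#   scaler = torch.cuda.amp.GradScaler(enabled=USE_MIXED_PRECISION)
#   with torch.cuda.amp.autocast(enabled=USE_MIXED_PRECISION):
#       loss = model(**batch).loss
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()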
# --- Memory Configuration ---
MEMORY_MAX_SIZE = 50 # Max number of entries in the memory
MEMORY_USE_DURING_TRAINING = False # Whether to use memory during training
MEMORY_RETRIEVAL_K = 3 # Number of examples to retrieve from memory
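# Sketch of a k-nearest retrieval over a bounded memory of
# (embedding, payload) pairs (illustrative; the real memory module may
# differ, and all names here are assumptions):
#
#   def retrieve(query_emb, memory, k=MEMORY_RETRIEVAL_K):
#       sims = [torch.cosine_similarity(query_emb, emb, dim=-1)
#               for emb, _ in memory]
#       top = sorted(range(len(sims)), key=sims.__getitem__, reverse=True)[:k]
#       return [memory[i][1] for i in top]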
# --- PEFT (LoRA) Configuration ---
USE_PEFT = False # Set to False to disable PEFT (full fine-tuning)
# --- LoRA-specific settings (safe to leave commented out while USE_PEFT is False) ---
# LORA_R = 8
# LORA_ALPHA = 16
# LORA_DROPOUT = 0.1
# LORA_TARGET_MODULES = ["query", "value"]
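# For reference, if USE_PEFT were re-enabled, the settings above would feed
# a standard peft LoraConfig roughly like this (a sketch; requires
# uncommenting the LORA_* constants first):
#
#   from peft import LoraConfig, get_peft_model
#   lora_cfg = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA,
#                         lora_dropout=LORA_DROPOUT,
#                         target_modules=LORA_TARGET_MODULES)
#   model = get_peft_model(model, lora_cfg)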
# --- Testing Configuration ---
BYPASS_DELTA_CALCULATION = False # Set to True to bypass delta calculation for testing
# --- Training Configuration ---
# <<< --- Device Detection (CUDA prioritized over MPS) --- >>>
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("CUDA GPU acceleration is available.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Apple Silicon MPS acceleration is available.")
else:
    DEVICE = "cpu"
    print("No GPU or MPS acceleration available, using CPU.")
# <<< --- End of Device Detection --- >>>
LEARNING_RATE = 1e-5 # Full fine-tuning often uses a smaller LR than PEFT
EPOCHS = 3
# --- Adjust Batch Size for Full Fine-tuning ---
# Full fine-tuning requires significantly more memory
BATCH_SIZE = 4 # Start small; increase based on available GPU memory
GRADIENT_ACCUMULATION_STEPS = 8 # Increase to compensate for smaller batch size
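# Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 4 * 8 = 32.
# Sketch of how the training loop would apply accumulation (illustrative;
# `dataloader`, `model`, and `optimizer` are assumed names):
#
#   for i, batch in enumerate(dataloader):
#       loss = model(**batch).loss / GRADIENT_ACCUMULATION_STEPS
#       loss.backward()
#       if (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
#           optimizer.step()
#           optimizer.zero_grad()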
# --- Dataset Configuration ---
# Example for SQuAD
MAX_SEQ_LENGTH = 320 # Max input sequence length for QA
DOC_STRIDE = 128 # Stride for overlapping chunks for long documents
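# These two settings map onto the standard Hugging Face tokenizer call for
# SQuAD-style preprocessing (a sketch; `questions`/`contexts` are assumed
# inputs prepared by the data pipeline):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
#   encodings = tokenizer(questions, contexts,
#                         max_length=MAX_SEQ_LENGTH, stride=DOC_STRIDE,
#                         truncation="only_second",
#                         return_overflowing_tokens=True,
#                         return_offsets_mapping=True,
#                         padding="max_length")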
print(f"Using device: {DEVICE}")
print(f"Base model: {BASE_MODEL_NAME}")
# Report PEFT status; PEFT disabled implies full fine-tuning
print(f"Using PEFT (LoRA): {USE_PEFT}" + ("" if USE_PEFT else " - full fine-tuning enabled"))