# config.py (Updated to Disable PEFT)
import torch

# --- Model Configuration ---
# Base model from Hugging Face (ensure it's suitable for QA)
# Example: 'bert-base-uncased', 'roberta-base', 'bert-large-uncased-whole-word-masking-finetuned-squad'
BASE_MODEL_NAME = "bert-base-uncased"

# --- RRN Specific Configuration ---
# Coherence loss weight
LAMBDA_COHERENCE = 0.1 # Hyperparameter to tune

# --- Delta Constraint Configuration ---
DELTA_TARGET_RATIO = 0.2  # Target ratio of delta norm to h0 norm
LAMBDA_DELTA_REG = 0.5    # Weight for delta regularization loss
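
# --- Illustrative sketch (commented out so importing this config has no side effects) ---
# How LAMBDA_COHERENCE and the two delta knobs above might enter the total loss;
# `qa_loss`, `coherence_loss`, `delta`, and `h0` are hypothetical names from the
# training loop, not defined here:
# delta_ratio = delta.norm() / h0.norm()
# delta_reg_loss = (delta_ratio - DELTA_TARGET_RATIO) ** 2
# total_loss = qa_loss + LAMBDA_COHERENCE * coherence_loss + LAMBDA_DELTA_REG * delta_reg_loss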

# --- Multi-step Reasoning Configuration ---
NUM_REASONING_STEPS = 3   # Default number of reasoning steps (used when dynamic steps disabled)

# --- Dynamic Reasoning Steps Configuration ---
USE_DYNAMIC_STEPS = True  # Enable/disable dynamic reasoning steps
MAX_REASONING_STEPS = 5  # Maximum number of reasoning steps
MIN_REASONING_STEPS = 1   # Minimum number of reasoning steps
REASONING_STEP_TYPE = "learned"  # Options: "fixed", "confidence", "learned"
EARLY_STOP_THRESHOLD = 0.01  # Delta magnitude threshold for early stopping (used with "confidence")
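
# --- Illustrative sketch (commented out): "confidence" mode early stopping ---
# Assumes a hypothetical `reasoning_step(h)` returning the residual update `delta`;
# iterates at least MIN_REASONING_STEPS times and stops once the update magnitude
# falls below EARLY_STOP_THRESHOLD:
# h = h0
# for step in range(MAX_REASONING_STEPS):
#     delta = reasoning_step(h)
#     h = h + delta
#     if step + 1 >= MIN_REASONING_STEPS and delta.norm().item() < EARLY_STOP_THRESHOLD:
#         break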

# --- Mixed Precision Configuration ---
USE_MIXED_PRECISION = False  # Enable/disable mixed precision training
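
# --- Illustrative sketch (commented out): standard torch.cuda.amp pattern this flag would gate ---
# `model`, `optimizer`, and `batch` are assumed from the training loop; note the
# GradScaler only applies on CUDA devices:
# scaler = torch.cuda.amp.GradScaler(enabled=USE_MIXED_PRECISION)
# with torch.cuda.amp.autocast(enabled=USE_MIXED_PRECISION):
#     loss = model(**batch).loss
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()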

# --- Memory Configuration ---
MEMORY_MAX_SIZE = 50     # Max number of entries in the memory
MEMORY_USE_DURING_TRAINING = False  # Whether to use memory during training
MEMORY_RETRIEVAL_K = 3    # Number of examples to retrieve from memory
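
# --- Illustrative sketch (commented out): one plausible bounded memory ---
# A FIFO store capped at MEMORY_MAX_SIZE with cosine-similarity retrieval of the
# top MEMORY_RETRIEVAL_K entries; each entry's "key" is an assumed 1-D embedding tensor:
# from collections import deque
# memory = deque(maxlen=MEMORY_MAX_SIZE)
# def retrieve(query_vec):
#     by_similarity = sorted(
#         memory,
#         key=lambda e: torch.cosine_similarity(query_vec, e["key"], dim=0).item(),
#         reverse=True)
#     return by_similarity[:MEMORY_RETRIEVAL_K]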

# --- PEFT (LoRA) Configuration ---
USE_PEFT = False # <--- SET TO False TO DISABLE PEFT ---

# --- LoRA-specific settings (unused while USE_PEFT is False; kept for reference) ---
# LORA_R = 8
# LORA_ALPHA = 16
# LORA_DROPOUT = 0.1
# LORA_TARGET_MODULES = ["query", "value"]
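
# --- Illustrative sketch (commented out): how USE_PEFT would gate LoRA wrapping ---
# Uses the Hugging Face `peft` package; requires uncommenting the LoRA settings above,
# and `model` is assumed from the training script:
# if USE_PEFT:
#     from peft import LoraConfig, get_peft_model
#     lora_config = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA,
#                              lora_dropout=LORA_DROPOUT,
#                              target_modules=LORA_TARGET_MODULES)
#     model = get_peft_model(model, lora_config)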

# --- Testing Configuration ---
BYPASS_DELTA_CALCULATION = False  # Set to True to bypass delta calculation for testing

# --- Training Configuration ---
# <<< --- Device Detection (CUDA prioritized over MPS) --- >>>
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("CUDA GPU acceleration is available.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Apple Silicon MPS acceleration is available.")
else:
    DEVICE = "cpu"
    print("No GPU or MPS acceleration available, using CPU.")
# <<< --- End of Device Detection --- >>>

LEARNING_RATE = 1e-5 # Full fine-tuning often uses a smaller LR than PEFT
EPOCHS = 3
# --- Adjust Batch Size for Full Fine-tuning ---
# Full fine-tuning requires significantly more memory
BATCH_SIZE = 4  # Start small and adjust to fit available GPU (CUDA/MPS) memory
GRADIENT_ACCUMULATION_STEPS = 8 # Increase to compensate for smaller batch size
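
# --- Illustrative sketch (commented out): gradient accumulation loop ---
# Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 4 * 8 = 32;
# `model`, `optimizer`, and `dataloader` are assumed from the training script:
# optimizer.zero_grad()
# for step, batch in enumerate(dataloader):
#     loss = model(**batch).loss / GRADIENT_ACCUMULATION_STEPS
#     loss.backward()
#     if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
#         optimizer.step()
#         optimizer.zero_grad()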

# --- Dataset Configuration ---
# Example for SQuAD
MAX_SEQ_LENGTH = 320 # Max input sequence length for QA
DOC_STRIDE = 128 # Stride for overlapping chunks for long documents
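
# --- Illustrative sketch (commented out): SQuAD-style sliding-window tokenization ---
# The standard Hugging Face tokenizer call these two values feed into
# (`tokenizer`, `questions`, and `contexts` are assumed from the preprocessing script):
# encodings = tokenizer(questions, contexts,
#                       max_length=MAX_SEQ_LENGTH, stride=DOC_STRIDE,
#                       truncation="only_second", return_overflowing_tokens=True,
#                       return_offsets_mapping=True, padding="max_length")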

print(f"Using device: {DEVICE}")
print(f"Base model: {BASE_MODEL_NAME}")
# Report PEFT status (falls back to full fine-tuning when PEFT is disabled)
print(f"Using PEFT (LoRA): {USE_PEFT}" + ("" if USE_PEFT else " - full fine-tuning enabled"))