Upload 3 files
Browse files- config_physics.py +49 -41
- continuous_learning_cumulative.py +20 -2
- inference_physics.py +2 -2
config_physics.py
CHANGED
|
@@ -1,41 +1,49 @@
|
|
| 1 |
-
|
| 2 |
-
import torch
|
| 3 |
-
|
| 4 |
-
class Config:
|
| 5 |
-
# Model
|
| 6 |
-
MODEL_ID = "unsloth/gemma-3-1b-it"
|
| 7 |
-
# MODEL_ID = "google/gemma-3-1b-it"
|
| 8 |
-
|
| 9 |
-
# Dimensions
|
| 10 |
-
HIDDEN_SIZE = 1152 # Gemma 3 1B hidden size
|
| 11 |
-
LATENT_DIM = 256 # Physics latent state dimension
|
| 12 |
-
PROJECTOR_HIDDEN = 1024
|
| 13 |
-
|
| 14 |
-
# Physics Controller
|
| 15 |
-
CONTROLLER_HIDDEN = 512
|
| 16 |
-
MODULATION_DIM = 64 # Rank of modulation (similar to LoRA rank)
|
| 17 |
-
|
| 18 |
-
# Training (OPTIMIZED for Contrastive Physics)
|
| 19 |
-
BATCH_SIZE = 1 # Back to 1: Contrastive needs 4 forward passes!
|
| 20 |
-
GRAD_ACCUMULATION = 64 # 1 * 64 = 64 effective batch
|
| 21 |
-
LEARNING_RATE = 5e-5 # Reduced: Controller needs stability with Lazy Tax
|
| 22 |
-
POLICY_LR = 5e-5 # Sync with LR
|
| 23 |
-
EPOCHS = 3 # Reduced: Contrastive Loss accelerates learning
|
| 24 |
-
|
| 25 |
-
# TTT / Inference
|
| 26 |
-
TTT_STEPS = 5
|
| 27 |
-
TTT_LR = 1e-4
|
| 28 |
-
|
| 29 |
-
# Data
|
| 30 |
-
MAX_LENGTH = 256 # Reduced: Math reasoning is shorter. Saves VRAM.
|
| 31 |
-
DTYPE = torch.float32
|
| 32 |
-
|
| 33 |
-
# Physics Dimensions
|
| 34 |
-
PHYSICS_DIMS = [
|
| 35 |
-
"Gravity", "Friction", "Elasticity", "Fragility", "Density",
|
| 36 |
-
"Temperature", "Conductivity", "Magnetism",
|
| 37 |
-
"Thermodynamics", "Fluid Dynamics", "Vacuum Physics", "Electromagnetism"
|
| 38 |
-
]
|
| 39 |
-
|
| 40 |
-
# Keys
|
| 41 |
-
API_KEY_ENV = "GEMINI_API_KEY"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class Config:
    """Central constants for the Flux physics-LLM pipeline.

    Plain class-level attributes only; never instantiated. Groups the base
    model id, network dimensions, training hyperparameters, test-time-training
    settings, data limits, the physics taxonomy, and prompting text.
    """

    # --- Base model checkpoint ---
    MODEL_ID = "unsloth/gemma-3-1b-it"
    # MODEL_ID = "google/gemma-3-1b-it"  # upstream alternative checkpoint

    # --- Network dimensions ---
    HIDDEN_SIZE = 1152        # Gemma 3 1B hidden size
    LATENT_DIM = 256          # physics latent state dimension
    PROJECTOR_HIDDEN = 1024

    # --- Physics controller ---
    CONTROLLER_HIDDEN = 512
    MODULATION_DIM = 64       # modulation rank (analogous to a LoRA rank)

    # --- Training (tuned for the contrastive physics objective) ---
    BATCH_SIZE = 1            # contrastive loss requires 4 forward passes per sample
    GRAD_ACCUMULATION = 64    # effective batch = BATCH_SIZE * GRAD_ACCUMULATION = 64
    LEARNING_RATE = 5e-5      # lowered: controller needs stability with the Lazy Tax
    POLICY_LR = 5e-5          # kept in sync with LEARNING_RATE
    EPOCHS = 3                # lowered: contrastive loss speeds up convergence

    # --- Test-time training / inference ---
    TTT_STEPS = 5
    TTT_LR = 1e-4

    # --- Data ---
    MAX_LENGTH = 256          # math reasoning is short; keeps VRAM usage down
    DTYPE = torch.float32

    # --- Physics taxonomy used by the curriculum/controller ---
    PHYSICS_DIMS = [
        "Gravity", "Friction", "Elasticity", "Fragility", "Density",
        "Temperature", "Conductivity", "Magnetism",
        "Thermodynamics", "Fluid Dynamics", "Vacuum Physics", "Electromagnetism"
    ]

    # --- Credentials ---
    API_KEY_ENV = "GEMINI_API_KEY"

    # --- Text generation / prompting ---
    SYSTEM_PROMPT = (
        "You are Flux, an advanced physics simulation engine. "
        "You answer questions based on precise physical laws, distinguishing between environments like "
        "Vacuum, Earth, Moon, and Zero-G. "
        "Think step-by-step using first principles (Newton's Laws, Gravity, Fluid Dynamics)."
    )
continuous_learning_cumulative.py
CHANGED
|
@@ -82,6 +82,18 @@ def run_cumulative_ttt():
|
|
| 82 |
|
| 83 |
# 3. Curriculum
|
| 84 |
curriculum = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
{
|
| 86 |
"id": "scenario_1",
|
| 87 |
"concept": "Zero Gravity Inertia",
|
|
@@ -142,6 +154,11 @@ def run_cumulative_ttt():
|
|
| 142 |
replay_buffer.add(task['id'], v['q'], v['a'])
|
| 143 |
replay_buffer.add(task['id'], task['prompt'], task['correction'])
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
print(" 🧠 Robust Learning (Current + Stratified History)...")
|
| 146 |
model.train()
|
| 147 |
|
|
@@ -204,7 +221,7 @@ def run_cumulative_ttt():
|
|
| 204 |
q = item['q']
|
| 205 |
target = item['a']
|
| 206 |
print(f" Q: \"{q}\"")
|
| 207 |
-
inputs = model.tokenizer(f"
|
| 208 |
with torch.no_grad():
|
| 209 |
h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
|
| 210 |
mod = model.controller(h_init)
|
|
@@ -250,7 +267,7 @@ def run_cumulative_ttt():
|
|
| 250 |
print("✅ Model Saved: final_physics_controller.pt, final_flux_adapters.pt")
|
| 251 |
|
| 252 |
def calculate_loss(model, prompt, answer, device):
|
| 253 |
-
full_text = f"
|
| 254 |
inputs = model.tokenizer(full_text, return_tensors="pt").to(device)
|
| 255 |
|
| 256 |
# Forward Pass
|
|
@@ -269,6 +286,7 @@ def check_answer(task_id, text):
|
|
| 269 |
if task_id == "scenario_2": return "same time" in text or "equal" in text or "identical" in text or "same rate" in text or "neither" in text or "instant" in text or "side-by-side" in text
|
| 270 |
if task_id == "scenario_3": return "up" in text or "rise" in text or "float" in text
|
| 271 |
if task_id == "scenario_4": return "coin" in text or "thrown" in text or "initial" in text or "toss" in text or "rock" in text or "bullet" in text or "volleyball" in text
|
|
|
|
| 272 |
return False
|
| 273 |
|
| 274 |
if __name__ == "__main__":
|
|
|
|
| 82 |
|
| 83 |
# 3. Curriculum
|
| 84 |
curriculum = [
|
| 85 |
+
{
|
| 86 |
+
"id": "scenario_general",
|
| 87 |
+
"concept": "General Language Grounding",
|
| 88 |
+
"prompt": "What is the capital of France?",
|
| 89 |
+
"correction": "The capital of France is Paris.",
|
| 90 |
+
"test_variations": [
|
| 91 |
+
{"q": "Summarize: The quick brown fox jumps over the lazy dog.", "a": "A fox jumps over a dog."},
|
| 92 |
+
{"q": "What is 2 + 2?", "a": "4"},
|
| 93 |
+
{"q": "Explain what a tree is.", "a": "A tree is a tall plant with a trunk and branches made of wood."},
|
| 94 |
+
{"q": "Who wrote Romeo and Juliet?", "a": "William Shakespeare."}
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
{
|
| 98 |
"id": "scenario_1",
|
| 99 |
"concept": "Zero Gravity Inertia",
|
|
|
|
| 154 |
replay_buffer.add(task['id'], v['q'], v['a'])
|
| 155 |
replay_buffer.add(task['id'], task['prompt'], task['correction'])
|
| 156 |
|
| 157 |
+
# SKIP TRAINING for General Scenario (Just use as Replay Anchor)
|
| 158 |
+
if task['id'] == "scenario_general":
|
| 159 |
+
print(" ⏩ Added General Anchors to Buffer. Skipping Training Step.")
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
print(" 🧠 Robust Learning (Current + Stratified History)...")
|
| 163 |
model.train()
|
| 164 |
|
|
|
|
| 221 |
q = item['q']
|
| 222 |
target = item['a']
|
| 223 |
print(f" Q: \"{q}\"")
|
| 224 |
+
inputs = model.tokenizer(f"{Config.SYSTEM_PROMPT}\nUser: {q}\nModel:", return_tensors="pt").to(device)
|
| 225 |
with torch.no_grad():
|
| 226 |
h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
|
| 227 |
mod = model.controller(h_init)
|
|
|
|
| 267 |
print("✅ Model Saved: final_physics_controller.pt, final_flux_adapters.pt")
|
| 268 |
|
| 269 |
def calculate_loss(model, prompt, answer, device):
|
| 270 |
+
full_text = f"{Config.SYSTEM_PROMPT}\nUser: {prompt}\nModel: {answer}"
|
| 271 |
inputs = model.tokenizer(full_text, return_tensors="pt").to(device)
|
| 272 |
|
| 273 |
# Forward Pass
|
|
|
|
| 286 |
if task_id == "scenario_2": return "same time" in text or "equal" in text or "identical" in text or "same rate" in text or "neither" in text or "instant" in text or "side-by-side" in text
|
| 287 |
if task_id == "scenario_3": return "up" in text or "rise" in text or "float" in text
|
| 288 |
if task_id == "scenario_4": return "coin" in text or "thrown" in text or "initial" in text or "toss" in text or "rock" in text or "bullet" in text or "volleyball" in text
|
| 289 |
+
if task_id == "scenario_general": return "paris" in text or "fox" in text or "4" in text or "plant" in text or "shakespeare" in text
|
| 290 |
return False
|
| 291 |
|
| 292 |
if __name__ == "__main__":
|
inference_physics.py
CHANGED
|
@@ -65,8 +65,8 @@ def interactive_session():
|
|
| 65 |
if not user_input.strip():
|
| 66 |
continue
|
| 67 |
|
| 68 |
-
# Format prompt EXACTLY like training
|
| 69 |
-
full_prompt = f"
|
| 70 |
|
| 71 |
inputs = model.tokenizer(full_prompt, return_tensors="pt").to(device)
|
| 72 |
|
|
|
|
| 65 |
if not user_input.strip():
|
| 66 |
continue
|
| 67 |
|
| 68 |
+
# Format prompt EXACTLY like training (System Prompt + Chat)
|
| 69 |
+
full_prompt = f"{Config.SYSTEM_PROMPT}\nUser: {user_input}\nModel:"
|
| 70 |
|
| 71 |
inputs = model.tokenizer(full_prompt, return_tensors="pt").to(device)
|
| 72 |
|