convaiinnovations commited on
Commit
2243b52
Β·
verified Β·
1 Parent(s): cd25d41

Upload continuous_learning_auto.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. continuous_learning_auto.py +163 -0
continuous_learning_auto.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import os
4
+ import json
5
+ import logging
6
+ import torch.optim as optim
7
+ from config_physics import Config
8
+ from modeling_physics_rl import PhysicsModel
9
+
10
+ # Setup Logging
11
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def run_auto_ttt():
15
+ print("\n" + "="*50)
16
+ print(" πŸ€– DATA CENTER MODE: Automated TTT (Test-Time Training)")
17
+ print("="*50)
18
+
19
+ # 1. Load Model
20
+ print("⏳ Loading Physics Model...")
21
+ model = PhysicsModel()
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ model.to(device)
24
+
25
+ # Load Adapters (Generic Path Logic)
26
+ search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]
27
+ for p in search_paths:
28
+ fpath = os.path.join(p, "final_flux_adapters.pt")
29
+ if os.path.exists(fpath):
30
+ print(f" Loading Flux Adapters from {fpath}...")
31
+ adapter_states = torch.load(fpath, map_location=device)
32
+ # Handle list vs dict (safe load)
33
+ if isinstance(adapter_states, dict):
34
+ # If it's a state_dict of the whole model (rare but possible)
35
+ pass
36
+ elif isinstance(adapter_states, list):
37
+ for layer, state in zip(model.flux_layers, adapter_states):
38
+ layer.load_state_dict(state)
39
+ break
40
+
41
+ # Load Controller
42
+ for p in search_paths:
43
+ fpath = os.path.join(p, "final_physics_controller.pt")
44
+ if os.path.exists(fpath):
45
+ print(f" Loading Controller from {fpath}...")
46
+ model.controller.load_state_dict(torch.load(fpath, map_location=device))
47
+ break
48
+
49
+ # 2. Setup Meta-Optimizer (AdamW)
50
+ # We optimize the Controller AND the Adapter Projections
51
+ params = list(model.controller.parameters())
52
+ for layer in model.flux_layers:
53
+ params.extend(list(layer.modulation_proj.parameters()))
54
+
55
+ optimizer = optim.AdamW(params, lr=1e-3) # High LR for fast adaptation
56
+
57
+ # 3. Define Test Cases (Scenario, Prompt, Correct Answer)
58
+ test_cases = [
59
+ {
60
+ "scenario": "Zero Gravity",
61
+ "prompt": "I drop a heavy hammer inside a space station. What happens?",
62
+ "correct_answer": "The hammer floats in place. Inside a space station in orbit, objects are in freefall and appear weightless (microgravity). It does not fall to the floor."
63
+ },
64
+ {
65
+ "scenario": "Moon Gravity",
66
+ "prompt": "I drop a feather and a hammer on the Moon. Which hits the ground first?",
67
+ "correct_answer": "They hit the ground at the same time. On the Moon, there is no air resistance, so gravity accelerates all objects at the same rate regardless of mass."
68
+ },
69
+ {
70
+ "scenario": "Underwater",
71
+ "prompt": "I release a helium balloon underwater. Which way does it go?",
72
+ "correct_answer": "The balloon floats UP. The buoyant force from the water is greater than the weight of the balloon."
73
+ }
74
+ ]
75
+
76
+ print(f"\nπŸš€ Starting Automation Loop ({len(test_cases)} scenarios)...")
77
+
78
+ for i, case in enumerate(test_cases):
79
+ print(f"\n--------------------------------------------------")
80
+ print(f"πŸ“ Scenario {i+1}: {case['scenario']}")
81
+ print(f" Question: \"{case['prompt']}\"")
82
+
83
+ # --- Step A: Initial Inference ---
84
+ inputs = model.tokenizer(f"User: {case['prompt']}\nModel:", return_tensors="pt").to(device)
85
+
86
+ # Thinking (Dynamics Pass)
87
+ h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
88
+ modulation = model.controller(h_init)
89
+ mod_norm = modulation.norm().item()
90
+
91
+ # Generate Text
92
+ model.set_active_modulation(modulation)
93
+ out = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
94
+ model.clear_modulation()
95
+
96
+ text_initial = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
97
+ print(f" πŸ€– Initial Answer: {text_initial}")
98
+ print(f" πŸ“Š Modulation Norm: {mod_norm:.4f}")
99
+
100
+ # --- Step B: "User" Correction (Simulated) ---
101
+ print(f" πŸ’‘ Teaching: \"{case['correct_answer']}\"")
102
+
103
+ # Prepare Training Data
104
+ full_text_correct = f"User: {case['prompt']}\nModel: {case['correct_answer']}"
105
+ inputs_correct = model.tokenizer(full_text_correct, return_tensors="pt").to(device)
106
+ labels = inputs_correct.input_ids.clone()
107
+
108
+ # --- Step C: Test-Time Update (The Learning) ---
109
+ model.train()
110
+ print(f" 🧠 Adapting Weights (30 steps)...")
111
+
112
+ for step in range(30): # INCREASED STEPS AGAIN
113
+ optimizer.zero_grad()
114
+
115
+ # ... (Forward/Backward logic remains same) ...
116
+
117
+ # 1. Controller sees Prompt
118
+ h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
119
+ mod_pred = model.controller(h_prompt)
120
+
121
+ # 2. LLM sees Full Sequence (forced by mod_pred)
122
+ logits = model(inputs_correct.input_ids, forced_modulation=mod_pred)
123
+
124
+ # 3. Loss
125
+ shift_logits = logits[..., :-1, :].contiguous()
126
+ shift_labels = labels[..., 1:].contiguous()
127
+ loss = torch.nn.functional.cross_entropy(
128
+ shift_logits.view(-1, shift_logits.size(-1)),
129
+ shift_labels.view(-1)
130
+ )
131
+
132
+ loss.backward()
133
+ optimizer.step()
134
+
135
+ # Logging convergence
136
+ if (step + 1) % 10 == 0:
137
+ print(f" Step {step+1}: Loss = {loss.item():.4f}")
138
+
139
+ # --- Step D: Verify Adaptation ---
140
+ model.eval()
141
+ with torch.no_grad():
142
+ h_new = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
143
+ mod_new = model.controller(h_new)
144
+ model.set_active_modulation(mod_new)
145
+ out_new = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
146
+ model.clear_modulation()
147
+
148
+ text_new = model.tokenizer.decode(out_new[0], skip_special_tokens=True).split("Model:")[-1].strip()
149
+
150
+ print(f" πŸŽ“ New Answer: {text_new}")
151
+ print(f" πŸ“ˆ New Mod Norm: {mod_new.norm().item():.4f}")
152
+
153
+ # 4. Save TTT Weights
154
+ print("\nπŸ’Ύ Saving Adapted Weights...")
155
+ torch.save(model.controller.state_dict(), "ttt_physics_controller.pt")
156
+
157
+ # Save Adapters
158
+ adapter_states = [layer.state_dict() for layer in model.flux_layers]
159
+ torch.save(adapter_states, "ttt_flux_adapters.pt")
160
+ print("βœ… Saved to 'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt'")
161
+
162
+ if __name__ == "__main__":
163
+ run_auto_ttt()