File size: 6,888 Bytes
2243b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164

import torch
import os
import json
import logging
import torch.optim as optim
from config_physics import Config
from modeling_physics_rl import PhysicsModel

# Setup Logging
# NOTE(review): the logger is configured here but the script body below reports
# progress via print(); kept for library callers that import this module.
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

def run_auto_ttt(adapt_steps=30, lr=1e-3, max_new_tokens=60):
    """Run the automated Test-Time Training (TTT) loop over physics scenarios.

    Loads the physics model and its Flux-adapter / controller checkpoints,
    then for each built-in scenario:
      A. generates an initial answer under the controller's modulation,
      B. builds a "corrected" training example from the ground-truth answer,
      C. fine-tunes the controller and adapter projections on it, and
      D. re-generates the answer to verify the adaptation took hold.
    Finally saves the adapted weights to the current working directory.

    Args:
        adapt_steps: Gradient steps per scenario during the test-time update.
        lr: AdamW learning rate (intentionally high for fast adaptation).
        max_new_tokens: Generation budget for initial and verification answers.
    """
    print("\n" + "="*50)
    print(" πŸ€– DATA CENTER MODE: Automated TTT (Test-Time Training)")
    print("="*50)

    # 1. Load Model
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Checkpoints may live next to the script or in Kaggle input/working dirs.
    search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]

    def _find_checkpoint(fname):
        """Return the first existing path to *fname* among search_paths, else None."""
        for p in search_paths:
            fpath = os.path.join(p, fname)
            if os.path.exists(fpath):
                return fpath
        return None

    # Load Flux adapters (if a checkpoint exists anywhere on the search path).
    adapter_path = _find_checkpoint("final_flux_adapters.pt")
    if adapter_path is not None:
        print(f"   Loading Flux Adapters from {adapter_path}...")
        # NOTE(review): torch.load on a pickle executes arbitrary code; fine for
        # trusted checkpoints, but do not point this at untrusted files.
        adapter_states = torch.load(adapter_path, map_location=device)
        if isinstance(adapter_states, list):
            # Expected format: one state_dict per flux layer, in layer order.
            for layer, state in zip(model.flux_layers, adapter_states):
                layer.load_state_dict(state)
        elif isinstance(adapter_states, dict):
            # Whole-model state_dict (rare). Previously ignored silently; warn
            # so a format mismatch is visible instead of a mystery no-op.
            logger.warning("Adapter checkpoint is a dict, not a per-layer list; skipping adapter load.")
    else:
        logger.warning("No flux adapter checkpoint found; adapters keep their initial weights.")

    # Load Controller
    controller_path = _find_checkpoint("final_physics_controller.pt")
    if controller_path is not None:
        print(f"   Loading Controller from {controller_path}...")
        model.controller.load_state_dict(torch.load(controller_path, map_location=device))
    else:
        logger.warning("No controller checkpoint found; controller keeps its initial weights.")

    # 2. Setup Meta-Optimizer (AdamW)
    # We optimize the Controller AND the Adapter Projections.
    params = list(model.controller.parameters())
    for layer in model.flux_layers:
        params.extend(layer.modulation_proj.parameters())

    optimizer = optim.AdamW(params, lr=lr)  # High LR for fast adaptation

    # 3. Define Test Cases (Scenario, Prompt, Correct Answer)
    test_cases = [
        {
            "scenario": "Zero Gravity",
            "prompt": "I drop a heavy hammer inside a space station. What happens?",
            "correct_answer": "The hammer floats in place. Inside a space station in orbit, objects are in freefall and appear weightless (microgravity). It does not fall to the floor."
        },
        {
            "scenario": "Moon Gravity",
            "prompt": "I drop a feather and a hammer on the Moon. Which hits the ground first?",
            "correct_answer": "They hit the ground at the same time. On the Moon, there is no air resistance, so gravity accelerates all objects at the same rate regardless of mass."
        },
        {
            "scenario": "Underwater",
            "prompt": "I release a helium balloon underwater. Which way does it go?",
            "correct_answer": "The balloon floats UP. The buoyant force from the water is greater than the weight of the balloon."
        }
    ]

    print(f"\nπŸš€ Starting Automation Loop ({len(test_cases)} scenarios)...")

    def _think_and_generate(inputs):
        """Controller 'thinking' pass + modulated generation.

        Derives a modulation vector from the prompt embeddings, forces it
        during generation, clears model state, and returns
        (answer_text, modulation). Used for both Step A and Step D.
        """
        h = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
        modulation = model.controller(h)
        model.set_active_modulation(modulation)
        out = model.llm.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        model.clear_modulation()
        text = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
        return text, modulation

    for i, case in enumerate(test_cases):
        print(f"\n--------------------------------------------------")
        print(f"πŸ“ Scenario {i+1}: {case['scenario']}")
        print(f"   Question: \"{case['prompt']}\"")

        # --- Step A: Initial Inference ---
        inputs = model.tokenizer(f"User: {case['prompt']}\nModel:", return_tensors="pt").to(device)
        text_initial, modulation = _think_and_generate(inputs)
        print(f"   πŸ€– Initial Answer: {text_initial}")
        print(f"   πŸ“Š Modulation Norm: {modulation.norm().item():.4f}")

        # --- Step B: "User" Correction (Simulated) ---
        print(f"   πŸ’‘ Teaching: \"{case['correct_answer']}\"")
        full_text_correct = f"User: {case['prompt']}\nModel: {case['correct_answer']}"
        inputs_correct = model.tokenizer(full_text_correct, return_tensors="pt").to(device)
        labels = inputs_correct.input_ids.clone()

        # --- Step C: Test-Time Update (The Learning) ---
        model.train()
        print(f"   🧠 Adapting Weights ({adapt_steps} steps)...")
        for step in range(adapt_steps):
            optimizer.zero_grad()

            # 1. Controller sees only the prompt...
            h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_pred = model.controller(h_prompt)

            # 2. ...while the LLM scores the full corrected sequence under
            #    that predicted modulation.
            logits = model(inputs_correct.input_ids, forced_modulation=mod_pred)

            # 3. Standard next-token cross-entropy. NOTE(review): labels cover
            #    the prompt tokens too, not just the answer — confirm intended.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            loss.backward()
            optimizer.step()

            # Logging convergence
            if (step + 1) % 10 == 0:
                print(f"      Step {step+1}: Loss = {loss.item():.4f}")

        # --- Step D: Verify Adaptation ---
        model.eval()
        with torch.no_grad():
            text_new, mod_new = _think_and_generate(inputs)
            print(f"   πŸŽ“ New Answer: {text_new}")
            print(f"   πŸ“ˆ New Mod Norm: {mod_new.norm().item():.4f}")

    # 4. Save TTT Weights
    print("\nπŸ’Ύ Saving Adapted Weights...")
    torch.save(model.controller.state_dict(), "ttt_physics_controller.pt")
    torch.save([layer.state_dict() for layer in model.flux_layers], "ttt_flux_adapters.pt")
    print("βœ… Saved to 'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt'")

# Script entry point: run the automated TTT loop when executed directly.
if __name__ == "__main__":
    run_auto_ttt()