File size: 3,542 Bytes
c9e94b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import os
from modeling_physics_rl import PhysicsModel, Config

def simple_inference(
    controller_path="final_physics_controller.pt",
    adapters_path="final_flux_adapters.pt",
    walt_path="final_walt_head.pt",
):
    """Interactive REPL running physics-modulated text generation.

    Assembles a ``PhysicsModel``, loads whichever of the three checkpoint
    files exist on disk (missing files are silently skipped), then loops
    reading queries from stdin and printing generated responses until the
    user types ``exit`` or ``quit``.

    Args:
        controller_path: Checkpoint file for the controller module.
        adapters_path: Checkpoint file for the flux adapter layers.
        walt_path: Checkpoint file for the WALT head.
    """
    print("🧪 Loading Physics Model (Controller + Adapters + WALT)...")

    # 1. Initialize Model Structure
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # 2. Load All Weights
    if not _load_weights(model, device, controller_path, adapters_path, walt_path):
        return

    # 3. Inference Loop
    print("\n💡 Physics-Injected Inference (Type 'exit' to quit)")
    print("   Input -> [Controller] -> Modulation -> [Adapters] -> Output\n")

    while True:
        query = input("Query: ")
        if query.lower() in ["exit", "quit"]:
            break

        response, mod_norm = _answer_query(model, device, query)
        print(f"Response: {response}")
        print(f"   [Physics Injection Intensity: {mod_norm:.4f}]\n")


def _load_weights(model, device, controller_path, adapters_path, walt_path):
    """Load controller / adapter / WALT checkpoints into *model*.

    Each file is optional; only existing files are loaded. Returns True on
    success, False if any load failed (after printing the error).
    """
    try:
        # A. Controller (The Brain)
        if os.path.exists(controller_path):
            print("   loading controller...")
            model.controller.load_state_dict(
                torch.load(controller_path, map_location=device)
            )

        # B. Adapters (The Muscles)
        if os.path.exists(adapters_path):
            print("   loading adapters...")
            states = torch.load(adapters_path, map_location=device)
            # Some checkpoints store one state dict per layer in a list;
            # others store a single combined state dict for the ModuleList.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)

        # C. WALT Head (The Imagination)
        if os.path.exists(walt_path):
            print("   loading walt head...")
            model.walt.load_state_dict(torch.load(walt_path, map_location=device))

        print("✅ Model Assembled Successfully.")
        return True

    except Exception as e:
        print(f"❌ Error loading weights: {e}")
        return False


def _answer_query(model, device, query):
    """Run one physics-modulated generation pass.

    Returns a ``(response_text, modulation_norm)`` tuple, where the norm is
    the L2 magnitude of the controller's modulation signal.
    """
    # Format for the trained distribution
    prompt = f"User: {query}\nModel: "
    inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        # A. Extract Physics Layout
        h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)

        # B. Controller Decision
        modulation = model.controller(h_init)
        mod_norm = torch.norm(modulation).item()

        # C. Inject Physics (Activate Adapters)
        model.set_active_modulation(modulation)
        try:
            # D. Generate
            out_ids = model.llm.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=model.tokenizer.eos_token_id,
            )
        finally:
            # Always reset, even if generation raises, so a failure can't
            # leave stale modulation injected into subsequent queries.
            model.clear_modulation()

    response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Clean Prompt: strip the echoed prompt (exact prefix first, then the
    # looser "Model:" split for tokenizers that alter whitespace).
    if response.startswith(prompt):
        response = response[len(prompt):].strip()
    elif "Model:" in response:
        response = response.split("Model:")[-1].strip()

    return response, mod_norm

# Script entry point: launch the interactive inference loop when run directly.
if __name__ == "__main__":
    simple_inference()