File size: 4,292 Bytes
059e110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c5e7e4
 
 
 
 
 
 
059e110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e91b2cb
 
 
 
059e110
 
 
 
 
 
 
 
 
 
 
 
e91b2cb
 
 
059e110
 
 
 
 
 
e91b2cb
 
 
 
 
 
 
059e110
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from modeling_physics_rl import PhysicsModel, Config
import os
import sys

def _load_ttt_weights(model, device):
    """Restore trained TTT weights (controller + Flux adapters) from disk.

    Best-effort: a missing file only prints a warning, and any load error
    leaves the model on its current random/base weights.

    Args:
        model: the PhysicsModel whose ``controller`` and ``flux_layers``
            sub-modules get their state dicts restored.
        device: torch.device used as ``map_location`` so CUDA checkpoints
            load cleanly on CPU-only machines and vice versa.
    """
    # These are the weights learned from continuous_learning_cumulative.py
    controller_path = "final_physics_controller.pt"
    adapters_path = "final_flux_adapters.pt"

    try:
        if os.path.exists(controller_path):
            print(f"   📂 Loading Controller: {controller_path}")
            model.controller.load_state_dict(torch.load(controller_path, map_location=device))
        else:
            print(f"   ⚠️ Warning: Controller weights not found at {controller_path}")

        if os.path.exists(adapters_path):
            print(f"   📂 Loading Flux Adapters: {adapters_path}")
            states = torch.load(adapters_path, map_location=device)
            # Checkpoint format varies: either a list of per-layer state
            # dicts or a single ModuleList state dict — handle both.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)
        else:
            print(f"   ⚠️ Warning: Adapter weights not found at {adapters_path}")

    except Exception as e:
        # Deliberate best-effort: inference can still proceed unmodulated.
        print(f"   ❌ Error loading weights: {e}")
        print("   ⚠️ Proceeding with random/base weights...")


def _answer(model, device, user_input):
    """Run one REPL turn: predict modulation, generate, print the response.

    Args:
        model: a PhysicsModel already moved to ``device`` and in eval mode.
        device: torch.device the tokenized inputs are moved to.
        user_input: raw text the user typed.
    """
    # Format the prompt to match the training distribution.
    prompt = f"User: {user_input}\nModel: "
    inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        # 1. Predict modulation from the prompt embeddings — this is the
        # effect of the trained controller we want to observe.
        h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
        modulation = model.controller(h_init)
        model.set_active_modulation(modulation)

        try:
            # 2. Generate the response with the modulation active.
            out_ids = model.llm.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=model.tokenizer.eos_token_id
            )
        finally:
            # FIX: always clear the modulation, even if generate() raises,
            # so a failed turn can't leak state into the next one.
            model.clear_modulation()

    # FIX: decode only the newly generated tokens. The old approach decoded
    # the full sequence and string-stripped the prompt, which is fragile:
    # tokenizer round-trips don't always reproduce the prompt verbatim.
    prompt_len = inputs.input_ids.shape[1]
    response = model.tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True).strip()

    print(f"MODEL: {response}")
    print(f"   [Modulation Norm: {torch.norm(modulation).item():.2f}]")
    print("")


def interactive_session():
    """Interactive inference REPL for the pre-trained TTT physics model.

    Loads a PhysicsModel, restores trained controller/adapter weights if
    present, then loops: read a prompt, apply the predicted modulation,
    generate and print a response. Exits on 'exit'/'quit', Ctrl-C, or EOF.
    """
    print("\n============================================================")
    print(" 🧪 FLUX TTT INFERENCE LAB (Pre-Trained)")
    print("Commands:")
    print("   - Type your question")
    print("   - Type 'exit' to quit")
    print("============================================================\n")

    # 1. Load Model
    print("🧠 Initializing Physics Model...")
    model = PhysicsModel()

    # Force GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   🚀 Using Device: {device}")
    model.to(device)
    # Ensure inner LLM is also on device just in case (though module.to handles it)
    model.llm.to(device)

    # 2. Load Trained TTT Weights
    _load_ttt_weights(model, device)

    print("   ✅ Ready for Inference!\n")

    # 3. Interactive Loop
    model.eval()

    while True:
        try:
            user_input = input("USER: ")
        except (KeyboardInterrupt, EOFError):
            # FIX: EOFError (closed/piped stdin) previously fell through to
            # the generic handler and spun forever printing "Error: ...".
            break

        if user_input.lower() in ["exit", "quit"]:
            break

        try:
            _answer(model, device, user_input)
        except KeyboardInterrupt:
            break
        except Exception as e:
            # Keep the REPL alive on per-turn failures (OOM, bad generate).
            print(f"Error: {e}")

# Script entry point: launch the interactive inference REPL.
if __name__ == "__main__":
    interactive_session()