import torch
from modeling_physics_rl import PhysicsModel, Config
import os
import sys
def interactive_session():
    """Run an interactive inference REPL for the pre-trained physics TTT model.

    Initializes ``PhysicsModel`` on GPU when available, restores the trained
    controller and Flux-adapter weights from disk if the checkpoint files
    exist, then loops: read a user prompt, predict a modulation vector with
    the controller, generate a response with that modulation active, and
    print the response plus the modulation norm. Exits on 'exit'/'quit' or
    Ctrl-C.

    Side effects: reads ``final_physics_controller.pt`` and
    ``final_flux_adapters.pt`` from the working directory; reads stdin and
    writes stdout. Returns None.
    """
    print("\n============================================================")
    print(" 🧪 FLUX TTT INFERENCE LAB (Pre-Trained)")
    print("Commands:")
    print(" - Type your question")
    print(" - Type 'exit' to quit")
    print("============================================================\n")

    # --- 1. Load model and move it to the compute device ---------------
    print("🧠 Initializing Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" 🚀 Using Device: {device}")
    model.to(device)
    # Redundant with model.to(device) since the LLM is a submodule, but
    # kept defensively as in the original.
    model.llm.to(device)

    # --- 2. Restore trained TTT weights ---------------------------------
    # These checkpoints are produced by continuous_learning_cumulative.py.
    controller_path = "final_physics_controller.pt"
    adapters_path = "final_flux_adapters.pt"
    try:
        if os.path.exists(controller_path):
            print(f" 📂 Loading Controller: {controller_path}")
            model.controller.load_state_dict(
                torch.load(controller_path, map_location=device)
            )
        else:
            print(f" ⚠️ Warning: Controller weights not found at {controller_path}")

        if os.path.exists(adapters_path):
            print(f" 📂 Loading Flux Adapters: {adapters_path}")
            states = torch.load(adapters_path, map_location=device)
            # Checkpoint may be a list of per-layer state dicts or a single
            # ModuleList state dict; support both layouts.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)
        else:
            print(f" ⚠️ Warning: Adapter weights not found at {adapters_path}")
    except Exception as e:
        # Best-effort load: fall back to whatever weights the model has.
        print(f" ❌ Error loading weights: {e}")
        print(" ⚠️ Proceeding with random/base weights...")
    print(" ✅ Ready for Inference!\n")

    # --- 3. Interactive loop ---------------------------------------------
    model.eval()
    while True:
        try:
            user_input = input("USER: ")
            if user_input.lower() in ["exit", "quit"]:
                break

            # Format the prompt to match the training distribution so the
            # controller's modulation prediction is in-distribution.
            prompt = f"User: {user_input}\nModel: "
            inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

            with torch.no_grad():
                # Predict per-prompt modulation from the input embeddings.
                h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
                modulation = model.controller(h_init)
                model.set_active_modulation(modulation)
                try:
                    out_ids = model.llm.generate(
                        **inputs,
                        max_new_tokens=128,
                        do_sample=True,
                        temperature=0.6,
                        top_p=0.9,
                        repetition_penalty=1.2,
                        pad_token_id=model.tokenizer.eos_token_id
                    )
                finally:
                    # BUGFIX: originally clear_modulation() ran only on the
                    # success path, so an exception inside generate() left
                    # stale modulation active for the next turn. Always
                    # clear it.
                    model.clear_modulation()

            response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
            # Strip the echoed prompt from the decoded output; fall back to
            # splitting on the "Model:" marker if decoding altered the echo.
            if response.startswith(prompt):
                response = response[len(prompt):].strip()
            elif "Model:" in response:
                response = response.split("Model:")[-1].strip()

            print(f"MODEL: {response}")
            print(f" [Modulation Norm: {torch.norm(modulation).item():.2f}]")
            print("")
        except KeyboardInterrupt:
            break
        except Exception as e:
            # Keep the REPL alive on per-turn failures.
            print(f"Error: {e}")
# Script entry point: launch the interactive REPL only when executed
# directly (not when imported as a module).
if __name__ == "__main__":
    interactive_session()