# flux-test-time-training / interactive_inference.py
# Uploaded by convaiinnovations (commit e91b2cb, verified).
import torch
from modeling_physics_rl import PhysicsModel, Config
import os
import sys
def _load_trained_weights(model, device):
    """Load the trained TTT controller and Flux adapter weights, if present.

    These are the weights learned from continuous_learning_cumulative.py.
    Missing files only print a warning; any load error falls back to the
    base/random weights instead of aborting the session.

    Args:
        model: PhysicsModel instance whose `controller` and `flux_layers`
            sub-modules receive the loaded state dicts.
        device: torch.device used as `map_location` when deserializing.
    """
    controller_path = "final_physics_controller.pt"
    adapters_path = "final_flux_adapters.pt"
    try:
        if os.path.exists(controller_path):
            print(f" πŸ“‚ Loading Controller: {controller_path}")
            # NOTE(security): torch.load uses pickle — only load checkpoints
            # from trusted sources.
            model.controller.load_state_dict(torch.load(controller_path, map_location=device))
        else:
            print(f" ⚠️ Warning: Controller weights not found at {controller_path}")
        if os.path.exists(adapters_path):
            print(f" πŸ“‚ Loading Flux Adapters: {adapters_path}")
            states = torch.load(adapters_path, map_location=device)
            # Checkpoints may store either a list of per-layer state dicts
            # or a single ModuleList state dict — handle both layouts.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)
        else:
            print(f" ⚠️ Warning: Adapter weights not found at {adapters_path}")
    except Exception as e:
        print(f" ❌ Error loading weights: {e}")
        print(" ⚠️ Proceeding with random/base weights...")


def _generate_reply(model, device, user_input):
    """Generate one modulated reply for a single user turn.

    The controller predicts a modulation vector from the prompt embeddings;
    generation runs with that modulation active, and the echoed prompt is
    stripped from the decoded output.

    Args:
        model: PhysicsModel in eval mode.
        device: torch.device the input tensors are moved to.
        user_input: raw text typed by the user.

    Returns:
        (response, modulation): cleaned response string and the modulation
        tensor (returned so the caller can report its norm).
    """
    # Format the prompt to match the training distribution.
    prompt = f"User: {user_input}\nModel: "
    inputs = model.tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        # 1. Predict modulation from the prompt embeddings.
        h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
        modulation = model.controller(h_init)
        model.set_active_modulation(modulation)
        try:
            # 2. Generate the response with modulation active.
            out_ids = model.llm.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=model.tokenizer.eos_token_id
            )
        finally:
            # BUGFIX: always deactivate modulation, even if generate()
            # raises — otherwise stale modulation leaks into the next turn.
            model.clear_modulation()
    response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Clean up response (remove the echoed prompt part).
    if response.startswith(prompt):
        response = response[len(prompt):].strip()
    elif "Model:" in response:
        response = response.split("Model:")[-1].strip()
    return response, modulation


def interactive_session():
    """Run the interactive inference REPL for the pre-trained FLUX TTT model.

    Loads the model and trained TTT weights, then reads user prompts from
    stdin until 'exit'/'quit' or Ctrl+C. Side effects: reads weight files
    from the working directory and prints to stdout.
    """
    print("\n============================================================")
    print(" πŸ§ͺ FLUX TTT INFERENCE LAB (Pre-Trained)")
    print("Commands:")
    print(" - Type your question")
    print(" - Type 'exit' to quit")
    print("============================================================\n")
    # 1. Load Model
    print("🧠 Initializing Physics Model...")
    model = PhysicsModel()
    # Force GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" πŸš€ Using Device: {device}")
    model.to(device)
    # Redundant but harmless: Module.to() already moves sub-modules.
    model.llm.to(device)
    # 2. Load Trained TTT Weights
    _load_trained_weights(model, device)
    print(" βœ… Ready for Inference!\n")
    # 3. Interactive Loop
    model.eval()
    while True:
        try:
            user_input = input("USER: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            response, modulation = _generate_reply(model, device, user_input)
            print(f"MODEL: {response}")
            print(f" [Modulation Norm: {torch.norm(modulation).item():.2f}]")
            print("")
        except KeyboardInterrupt:
            break
        except Exception as e:
            # Keep the REPL alive on per-turn failures; report and continue.
            print(f"Error: {e}")
# Script entry point: launch the REPL only when executed directly.
if __name__ == "__main__":
    interactive_session()