import os

import torch

from modeling_physics_rl import PhysicsModel, Config


def _load_model():
    """Instantiate the physics model and move it to the best device.

    Returns:
        (model, device): the `PhysicsModel` instance and the `torch.device`
        it was moved to (CUDA when available, else CPU).
    """
    print("🧠 Initializing Physics Model...")
    model = PhysicsModel()

    # Force GPU if available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   🚀 Using Device: {device}")
    model.to(device)
    # Ensure the inner LLM is also on device just in case (module.to already
    # covers registered submodules, but be explicit about the wrapped LLM).
    model.llm.to(device)
    return model, device


def _load_trained_weights(model, device):
    """Load the TTT controller and Flux adapter checkpoints, if present.

    These are the weights learned by continuous_learning_cumulative.py.
    Missing files produce a warning; any load failure is reported and the
    model keeps its random/base weights (best-effort by design).
    """
    controller_path = "final_physics_controller.pt"
    adapters_path = "final_flux_adapters.pt"

    try:
        if os.path.exists(controller_path):
            print(f"   📂 Loading Controller: {controller_path}")
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources (or pass
            # weights_only=True on torch >= 1.13).
            model.controller.load_state_dict(
                torch.load(controller_path, map_location=device)
            )
        else:
            print(f"   ⚠️ Warning: Controller weights not found at {controller_path}")

        if os.path.exists(adapters_path):
            print(f"   📂 Loading Flux Adapters: {adapters_path}")
            states = torch.load(adapters_path, map_location=device)
            # Handle both checkpoint layouts: a list of per-layer state
            # dicts, or a single ModuleList state dict.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)
        else:
            print(f"   ⚠️ Warning: Adapter weights not found at {adapters_path}")
    except Exception as e:
        print(f"   ❌ Error loading weights: {e}")
        print("   ⚠️ Proceeding with random/base weights...")


def _generate_reply(model, device, user_input):
    """Run one modulated generation pass for a single user turn.

    Args:
        model: the loaded `PhysicsModel` (already on `device`, in eval mode).
        device: device the input tensors should be placed on.
        user_input: raw text typed by the user.

    Returns:
        (response, modulation): the decoded reply text with the echoed
        prompt stripped, and the controller's modulation tensor (returned so
        the caller can report its norm).
    """
    # Format the prompt to match the training distribution.
    prompt = f"User: {user_input}\nModel: "
    inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        # 1. Predict modulation: the trained controller maps the prompt
        #    embeddings to a modulation signal applied during generation.
        h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
        modulation = model.controller(h_init)
        model.set_active_modulation(modulation)

        # 2. Generate the response under the active modulation.
        out_ids = model.llm.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=model.tokenizer.eos_token_id,
        )
        model.clear_modulation()

    response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Clean up the response: drop the echoed prompt so only the reply shows.
    if response.startswith(prompt):
        response = response[len(prompt):].strip()
    elif "Model:" in response:
        response = response.split("Model:")[-1].strip()
    return response, modulation


def interactive_session():
    """Interactive REPL over the pre-trained Flux TTT physics model.

    Loads the model and any trained TTT weights, then loops reading user
    questions and printing modulated generations until 'exit'/'quit' or
    Ctrl-C.
    """
    print("\n============================================================")
    print("   🧪 FLUX TTT INFERENCE LAB (Pre-Trained)")
    print("Commands:")
    print("  - Type your question")
    print("  - Type 'exit' to quit")
    print("============================================================\n")

    # 1. Load model onto the best available device.
    model, device = _load_model()

    # 2. Load trained TTT weights (best-effort; warns if missing).
    _load_trained_weights(model, device)
    print("   ✅ Ready for Inference!\n")

    # 3. Interactive loop.
    model.eval()
    while True:
        try:
            user_input = input("USER: ")
            if user_input.lower() in ["exit", "quit"]:
                break

            response, modulation = _generate_reply(model, device, user_input)
            print(f"MODEL: {response}")
            print(f"   [Modulation Norm: {torch.norm(modulation).item():.2f}]")
            print("")
        except KeyboardInterrupt:
            break
        except Exception as e:
            # Top-level REPL boundary: report and keep the session alive.
            print(f"Error: {e}")


if __name__ == "__main__":
    interactive_session()