import os
import sys

import torch

from modeling_physics_rl import Config, PhysicsModel


def _load_trained_weights(model, device):
    """Load the trained TTT controller and Flux adapter weights onto *model*.

    Missing checkpoint files only produce warnings, and any load error
    falls back to the model's existing (random/base) weights rather than
    aborting the session.
    """
    controller_path = "final_physics_controller.pt"
    adapters_path = "final_flux_adapters.pt"
    try:
        # NOTE(review): torch.load on a pickle executes arbitrary code if the
        # checkpoint is untrusted; pass weights_only=True if these files are
        # plain state dicts — confirm against how they were saved.
        if os.path.exists(controller_path):
            print(f"   📂 Loading Controller: {controller_path}")
            model.controller.load_state_dict(
                torch.load(controller_path, map_location=device)
            )
        else:
            print(f"   ⚠️ Warning: Controller weights not found at {controller_path}")

        if os.path.exists(adapters_path):
            print(f"   📂 Loading Flux Adapters: {adapters_path}")
            states = torch.load(adapters_path, map_location=device)
            # The adapter checkpoint may hold either a list of per-layer
            # state dicts or a single ModuleList state dict.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)
        else:
            print(f"   ⚠️ Warning: Adapter weights not found at {adapters_path}")
    except Exception as e:
        print(f"   ❌ Error loading weights: {e}")
        print("   ⚠️ Proceeding with random/base weights...")


def _build_prompt(user_input):
    """Format one user turn exactly like training (system prompt + chat)."""
    return f"{Config.SYSTEM_PROMPT}\nUser: {user_input}\nModel:"


def _generate_response(model, device, user_input):
    """Run one TTT-modulated generation pass.

    Returns (response_text, modulation_tensor). The active modulation is
    always cleared before returning, even if generation fails.
    """
    full_prompt = _build_prompt(user_input)
    inputs = model.tokenizer(full_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        # 1. Predict the modulation from the prompt embeddings.
        h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
        modulation = model.controller(h_init)
        model.set_active_modulation(modulation)
        try:
            # 2. Generate the response under the active modulation.
            out_ids = model.llm.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.6,          # Match Training (0.6)
                top_p=0.9,                # Match Training (0.9)
                repetition_penalty=1.2,   # Match Training (1.2)
                pad_token_id=model.tokenizer.eos_token_id,
            )
        finally:
            # BUGFIX: clear even when generate() raises, so a failed turn
            # cannot leak its modulation into subsequent turns.
            model.clear_modulation()

    response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    # Keep only the model's part of the transcript.
    if "Model:" in response:
        response = response.split("Model:")[-1].strip()
    # Fallback cleanup just in case the marker was stripped during decode.
    elif response.startswith(full_prompt):
        response = response[len(full_prompt):].strip()
    return response, modulation


def interactive_session():
    """Interactive REPL over the pre-trained Physics TTT model.

    Loads the model and any saved controller/adapter weights, then reads
    user turns from stdin until 'exit'/'quit' or Ctrl-C.
    """
    print("\n============================================================")
    print("  🧪 FLUX TTT INFERENCE LAB (Pre-Trained)")
    print("Commands:")
    print("  - Type your question")
    print("  - Type 'exit' to quit")
    print("============================================================\n")

    # 1. Load Model
    print("🧠 Initializing Physics Model...")
    model = PhysicsModel()

    # Force GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   🚀 Using Device: {device}")
    model.to(device)
    # Move the inner LLM explicitly — presumably it is not picked up by
    # model.to(device) alone; verify against PhysicsModel's definition.
    model.llm.to(device)

    # 2. Load Trained TTT Weights
    _load_trained_weights(model, device)
    print("   ✅ Ready for Inference!\n")

    # 3. Interactive Loop
    model.eval()
    while True:
        try:
            user_input = input("USER: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            if not user_input.strip():
                continue

            response, modulation = _generate_response(model, device, user_input)
            print(f"MODEL: {response}")
            print(f"  [Modulation Norm: {torch.norm(modulation).item():.2f}]")
            print("")
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    interactive_session()