import os

import torch

from modeling_physics_rl import PhysicsModel, Config


def _load_weights(model, device):
    """Load controller, adapter, and WALT checkpoint files into *model*.

    Missing checkpoint files are skipped silently (fresh or partial runs).
    Returns True when loading completed without error, False otherwise.
    """
    # NOTE(review): torch.load is pickle-based — only load checkpoints from
    # trusted sources. Consider weights_only=True (torch >= 1.13) once all
    # checkpoints are plain tensor state dicts.
    try:
        # A. Controller (the brain)
        if os.path.exists("final_physics_controller.pt"):
            print(" loading controller...")
            model.controller.load_state_dict(
                torch.load("final_physics_controller.pt", map_location=device)
            )

        # B. Adapters (the muscles)
        if os.path.exists("final_flux_adapters.pt"):
            print(" loading adapters...")
            states = torch.load("final_flux_adapters.pt", map_location=device)
            # Checkpoints exist in two layouts: a list with one state dict per
            # layer, or a single state dict for the whole ModuleList.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)

        # C. WALT head (the imagination)
        if os.path.exists("final_walt_head.pt"):
            print(" loading walt head...")
            model.walt.load_state_dict(
                torch.load("final_walt_head.pt", map_location=device)
            )

        print("✅ Model Assembled Successfully.")
        return True
    except Exception as e:
        # Surface the failure but keep the process alive so the caller
        # can bail out gracefully.
        print(f"❌ Error loading weights: {e}")
        return False


def simple_inference():
    """Interactive physics-injected inference REPL.

    Assembles the PhysicsModel (controller + flux adapters + WALT head),
    loads any checkpoint files present in the working directory, then loops:
    reads a user query, derives a modulation vector from the controller,
    injects it into the adapters, and generates a response with the LLM.
    Type 'exit' or 'quit' (or press Ctrl-D) to stop.
    """
    print("🧪 Loading Physics Model (Controller + Adapters + WALT)...")

    # 1. Initialize model structure
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # 2. Load all weights; abort if any checkpoint fails to load
    if not _load_weights(model, device):
        return

    # 3. Inference loop
    print("\n💡 Physics-Injected Inference (Type 'exit' to quit)")
    print(" Input -> [Controller] -> Modulation -> [Adapters] -> Output\n")

    while True:
        try:
            query = input("Query: ")
        except (EOFError, KeyboardInterrupt):
            # FIX: exit cleanly on Ctrl-D / Ctrl-C instead of a traceback.
            print()
            break
        if query.lower() in ["exit", "quit"]:
            break

        # Format for the trained distribution
        prompt = f"User: {query}\nModel: "
        inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # A. Extract physics layout from the input embeddings
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)

            # B. Controller decision — modulation vector and its magnitude
            modulation = model.controller(h_init)
            mod_norm = torch.norm(modulation).item()

            # C. Inject physics (activate adapters)
            model.set_active_modulation(modulation)

            # D. (Optional) WALT prediction (just to verify it runs)
            # z, z_next = model.walt(h_init)

            # E. Generate — FIX: always clear the modulation, even when
            # generation raises, so one failure cannot poison later queries.
            try:
                out_ids = model.llm.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    repetition_penalty=1.2,
                    pad_token_id=model.tokenizer.eos_token_id,
                )
            finally:
                model.clear_modulation()

        response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # Strip the echoed prompt from the decoded output
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        elif "Model:" in response:
            response = response.split("Model:")[-1].strip()

        print(f"Response: {response}")
        print(f" [Physics Injection Intensity: {mod_norm:.4f}]\n")


if __name__ == "__main__":
    simple_inference()