# flux-test-time-training / simple_inference.py
# Uploaded by convaiinnovations (commit c9e94b2, verified).
import torch
import os
from modeling_physics_rl import PhysicsModel, Config
def simple_inference():
    """Interactive REPL for the physics-modulated model.

    Loads the controller, adapter, and WALT checkpoints from the working
    directory, then reads queries from stdin until the user types
    'exit' or 'quit'. Each query is routed through the controller, whose
    modulation signal is injected into the adapters before generation.
    """
    print("🧪 Loading Physics Model (Controller + Adapters + WALT)...")

    # 1. Initialize model structure on the best available device.
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # 2. Load all weights; abort cleanly on failure.
    if not _load_weights(model, device):
        return

    # 3. Inference loop.
    print("\n💡 Physics-Injected Inference (Type 'exit' to quit)")
    print(" Input -> [Controller] -> Modulation -> [Adapters] -> Output\n")
    while True:
        query = input("Query: ")
        if query.lower() in ["exit", "quit"]:
            break

        # Format for the trained distribution.
        prompt = f"User: {query}\nModel: "
        inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # A. Extract physics layout from the input embeddings.
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            # B. Controller decision.
            modulation = model.controller(h_init)
            mod_norm = torch.norm(modulation).item()
            # C. Inject physics (activate adapters).
            model.set_active_modulation(modulation)
            # D. (Optional) WALT prediction, e.g.:
            #    z, z_next = model.walt(h_init)
            try:
                # E. Generate with the modulated adapters active.
                out_ids = model.llm.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    repetition_penalty=1.2,
                    pad_token_id=model.tokenizer.eos_token_id,
                )
            finally:
                # Bug fix: always reset, even if generate() raises, so stale
                # modulation never leaks into the next query. (The original
                # only cleared on the success path.)
                model.clear_modulation()

        response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # Strip the echoed prompt from the decoded output.
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        elif "Model:" in response:
            response = response.split("Model:")[-1].strip()

        print(f"Response: {response}")
        print(f" [Physics Injection Intensity: {mod_norm:.4f}]\n")


def _load_weights(model, device):
    """Best-effort load of the three checkpoints from the working directory.

    Missing files are skipped silently (same behavior as the original inline
    code). Returns True on success, False if any present checkpoint failed
    to load.
    """
    try:
        # A. Controller (the brain).
        if os.path.exists("final_physics_controller.pt"):
            print(" loading controller...")
            model.controller.load_state_dict(
                torch.load("final_physics_controller.pt", map_location=device)
            )
        # B. Adapters (the muscles).
        if os.path.exists("final_flux_adapters.pt"):
            print(" loading adapters...")
            states = torch.load("final_flux_adapters.pt", map_location=device)
            # Checkpoints were saved either as a list (one state dict per
            # layer) or as a single flat state dict — handle both.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)
        # C. WALT head (the imagination).
        if os.path.exists("final_walt_head.pt"):
            print(" loading walt head...")
            model.walt.load_state_dict(
                torch.load("final_walt_head.pt", map_location=device)
            )
        print("✅ Model Assembled Successfully.")
        return True
    except Exception as e:  # boundary handler: report, let caller abort
        print(f"❌ Error loading weights: {e}")
        return False
# Script entry point: launch the interactive inference loop when run directly.
if __name__ == "__main__":
    simple_inference()