convaiinnovations committed on
Commit
c9e94b2
·
verified ·
1 Parent(s): e91b2cb

Upload simple_inference.py

Browse files
Files changed (1) hide show
  1. simple_inference.py +96 -0
simple_inference.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from modeling_physics_rl import PhysicsModel, Config
4
+
5
def simple_inference():
    """Interactive REPL for physics-modulated text generation.

    Assembles the PhysicsModel (controller + adapters + WALT head) from
    optional local checkpoint files, then loops: read a query, let the
    controller compute a modulation from the prompt embeddings, inject it
    into the adapters, generate a response, and reset the modulation.
    Type 'exit' or 'quit' (or press Ctrl-D/Ctrl-C) to leave the loop.

    Returns:
        None. Exits early (after printing an error) if weight loading fails.
    """
    print("🧪 Loading Physics Model (Controller + Adapters + WALT)...")

    # 1. Initialize model structure
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # 2. Load all weights. Each checkpoint file is optional; a missing file
    # is silently skipped (best-effort assembly, matching training output).
    try:
        # A. Controller (the brain)
        if os.path.exists("final_physics_controller.pt"):
            print(" loading controller...")
            model.controller.load_state_dict(
                torch.load("final_physics_controller.pt", map_location=device)
            )

        # B. Adapters (the muscles)
        if os.path.exists("final_flux_adapters.pt"):
            print(" loading adapters...")
            states = torch.load("final_flux_adapters.pt", map_location=device)
            # Checkpoints may store either a list of per-layer state dicts
            # or one state dict for the whole module list — handle both.
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            else:
                model.flux_layers.load_state_dict(states)

        # C. WALT head (the imagination)
        if os.path.exists("final_walt_head.pt"):
            print(" loading walt head...")
            model.walt.load_state_dict(
                torch.load("final_walt_head.pt", map_location=device)
            )

        print("✅ Model Assembled Successfully.")

    except Exception as e:
        # Abort on any load failure: running with partially loaded weights
        # would silently produce garbage output.
        print(f"❌ Error loading weights: {e}")
        return

    # 3. Inference loop
    print("\n💡 Physics-Injected Inference (Type 'exit' to quit)")
    print(" Input -> [Controller] -> Modulation -> [Adapters] -> Output\n")

    while True:
        try:
            query = input("Query: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly instead of crashing.
            print()
            break
        if query.lower() in ["exit", "quit"]:
            break

        # Format for the trained distribution
        prompt = f"User: {query}\nModel: "
        inputs = model.tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # A. Extract physics layout from the prompt token embeddings
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)

            # B. Controller decision
            modulation = model.controller(h_init)
            mod_norm = torch.norm(modulation).item()

            # C. Inject physics (activate adapters)
            model.set_active_modulation(modulation)

            # D. (Optional) WALT prediction (just to verify it runs)
            # z, z_next = model.walt(h_init)

            try:
                # E. Generate with the modulation active
                out_ids = model.llm.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    repetition_penalty=1.2,
                    pad_token_id=model.tokenizer.eos_token_id,
                )
            finally:
                # BUG FIX: always reset the modulation, even when generate()
                # raises — otherwise a stale injection leaks into the next
                # query (the original only cleared it on the success path).
                model.clear_modulation()

        response = model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # Strip the echoed prompt from the decoded output
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        elif "Model:" in response:
            response = response.split("Model:")[-1].strip()

        print(f"Response: {response}")
        print(f" [Physics Injection Intensity: {mod_norm:.4f}]\n")
95
# Script entry point: run the interactive inference loop when executed
# directly (no side effects on import).
if __name__ == "__main__":
    simple_inference()