DaertML committed
Commit 78e22d5 · verified · Parent: 5721942

Upload 2 files

Files changed (2)
  1. batch_inference.py +304 -0
  2. questions.txt +7 -0
batch_inference.py ADDED
@@ -0,0 +1,304 @@
+ """
+ Interactive REPL and batch-inference runner for a trained physics
+ problem-solving model.
+ """
+ import argparse
+ import json
+ import re
+ from pathlib import Path
+
+ import torch
+ import yaml
+
+ from qwen2_model import Transformer
+ from tokenizer import Tokenizer
+ from generation_utils import generate
+ from tokenizer_wrapper import decode_token_ids
+
+
+ SYSTEM_MESSAGE = (
+     "You are a helpful physics tutor. You first think about the reasoning process "
+     "in your mind and then provide the user with the answer."
+ )
+ USER_TEMPLATE = (
+     "{question}\n"
+     "Show your reasoning in <think> </think> tags. "
+     "Then provide your final answer in <answer> </answer> tags."
+ )
+ RESPONSE_PROMPT = "Let me solve this step by step.\n<think>"
+
+
+ def load_model_and_tokenizer(config_path, checkpoint_path=None):
+     """Load the model and tokenizer from a config file and an optional checkpoint."""
+     with open(config_path, "r") as f:
+         config = yaml.safe_load(f)
+
+     pretrained_model_path = Path(config["model"]["pretrained_model_path"])
+     device = torch.device(config["model"]["device"])
+
+     dtype_map = {
+         "bfloat16": torch.bfloat16,
+         "float16": torch.float16,
+         "float32": torch.float32,
+     }
+     dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16)
+
+     # Load tokenizer
+     tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json"))
+
+     # Load model
+     model = Transformer.from_pretrained(pretrained_model_path, device=device)
+
+     # Load checkpoint if provided
+     if checkpoint_path:
+         print(f"Loading checkpoint from {checkpoint_path}...")
+         checkpoint = torch.load(checkpoint_path, map_location=device)
+
+         # Handle both checkpoint formats: a training checkpoint that bundles
+         # model_state_dict, optimizer_state_dict, etc., or a bare state dict.
+         if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
+             state_dict = checkpoint["model_state_dict"]
+             print(f"Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
+         else:
+             state_dict = checkpoint
+
+         model.load_state_dict(state_dict)
+         print("Checkpoint loaded successfully!")
+
+     model.eval()
+
+     return model, tokenizer, device, dtype, config
+
+
+ def generate_response(model, tokenizer, question, device, dtype, max_gen_len=512, temperature=0.7, top_p=0.9):
+     """Generate a response for a given physics question."""
+     # Format the prompt
+     user_message = USER_TEMPLATE.format(question=question)
+     prefix = tokenizer.encode_chat_with_response_prompt(
+         [
+             {"role": "system", "content": SYSTEM_MESSAGE},
+             {"role": "user", "content": user_message},
+         ],
+         RESPONSE_PROMPT,
+     )
+
+     # Tokenize
+     tokens = tokenizer.tokenize(prefix)
+     prefix_token_ids = tokens.ids
+
+     # Generate
+     print("\nGenerating response...")
+     with torch.inference_mode():
+         generated_token_ids, is_finished = generate(
+             model=model,
+             tokenizer=tokenizer,
+             prompt_token_ids=prefix_token_ids,
+             max_gen_len=max_gen_len,
+             temperature=temperature,
+             top_p=top_p,
+             device=device,
+             dtype=dtype,
+         )
+
+     # Decode
+     generated_text = decode_token_ids(tokenizer, generated_token_ids)
+
+     return prefix + generated_text, is_finished
+
+
+ def extract_answer(text):
+     """Extract the answer from <answer> tags, or return None if absent."""
+     answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
+     if answer_match:
+         return answer_match.group(1).strip()
+     return None
+
+
+ def print_response(full_text):
+     """Pretty-print the model's response."""
+     # Try to extract the think and answer sections
+     think_match = re.search(r"<think>(.*?)</think>", full_text, re.DOTALL)
+     answer_match = re.search(r"<answer>(.*?)</answer>", full_text, re.DOTALL)
+
+     print("\n" + "=" * 80)
+
+     if think_match:
+         print("\n🤔 REASONING:")
+         print("-" * 80)
+         print(think_match.group(1).strip())
+
+     if answer_match:
+         print("\n✅ ANSWER:")
+         print("-" * 80)
+         print(answer_match.group(1).strip())
+     else:
+         print("\n⚠️ WARNING: No answer tags found in response")
+         print("\nFull response:")
+         print("-" * 80)
+         print(full_text)
+
+     print("=" * 80 + "\n")
+
+
+ def interactive_mode(model, tokenizer, device, dtype, config):
+     """Run the interactive REPL mode."""
+     print("\n" + "=" * 80)
+     print("Physics Problem Solver - Interactive Mode")
+     print("=" * 80)
+     print("\nCommands:")
+     print("  - Type your physics question and press Enter")
+     print("  - Type 'quit' or 'exit' to exit")
+     print("  - Type 'config' to change generation parameters")
+     print("  - Type 'example' to see example questions")
+     print("=" * 80 + "\n")
+
+     # Default generation parameters
+     max_gen_len = config["training"].get("max_gen_len", 512)
+     temperature = 0.7
+     top_p = 0.9
+
+     while True:
+         try:
+             user_input = input("\n📝 Enter physics question (or command): ").strip()
+
+             if not user_input:
+                 continue
+
+             if user_input.lower() in ("quit", "exit", "q"):
+                 print("\nGoodbye! 👋")
+                 break
+
+             if user_input.lower() == "example":
+                 print("\nExample questions:")
+                 print("  1. A ball is thrown upward with velocity 20 m/s. What is its maximum height?")
+                 print("  2. Calculate the force needed to accelerate a 5 kg object at 3 m/s².")
+                 print("  3. What is the wavelength of light with frequency 5×10¹⁴ Hz?")
+                 print("  4. A 2 kg block slides down a 30° incline. What is its acceleration?")
+                 continue
+
+             if user_input.lower() == "config":
+                 print("\nCurrent settings:")
+                 print(f"  max_gen_len: {max_gen_len}")
+                 print(f"  temperature: {temperature}")
+                 print(f"  top_p: {top_p}")
+
+                 try:
+                     new_max_len = input(f"\nNew max_gen_len [{max_gen_len}]: ").strip()
+                     if new_max_len:
+                         max_gen_len = int(new_max_len)
+
+                     new_temp = input(f"New temperature [{temperature}]: ").strip()
+                     if new_temp:
+                         temperature = float(new_temp)
+
+                     new_top_p = input(f"New top_p [{top_p}]: ").strip()
+                     if new_top_p:
+                         top_p = float(new_top_p)
+
+                     print("\n✓ Configuration updated!")
+                 except ValueError:
+                     print("\n✗ Invalid input. Configuration unchanged.")
+                 continue
+
+             # Generate and print the response
+             full_text, is_finished = generate_response(
+                 model=model,
+                 tokenizer=tokenizer,
+                 question=user_input,
+                 device=device,
+                 dtype=dtype,
+                 max_gen_len=max_gen_len,
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+
+             print_response(full_text)
+
+             if not is_finished:
+                 print("⚠️ Note: Response was truncated (reached max_gen_len)")
+
+         except KeyboardInterrupt:
+             print("\n\nInterrupted. Type 'quit' to exit.\n")
+             continue
+         except Exception as e:
+             print(f"\n✗ Error: {e}\n")
+             continue
+
+
+ def batch_inference_mode(model, tokenizer, device, dtype, config, questions_file, output_file):
+     """Run batch inference on a file of questions, one question per line."""
+     print(f"\nRunning batch inference on {questions_file}...")
+
+     max_gen_len = config["training"].get("max_gen_len", 512)
+
+     # Read questions, skipping blank lines
+     with open(questions_file, "r") as f:
+         questions = [line.strip() for line in f if line.strip()]
+
+     print(f"Found {len(questions)} questions")
+
+     results = []
+     for i, question in enumerate(questions, 1):
+         print(f"\n[{i}/{len(questions)}] Processing: {question[:60]}...")
+
+         full_text, is_finished = generate_response(
+             model=model,
+             tokenizer=tokenizer,
+             question=question,
+             device=device,
+             dtype=dtype,
+             max_gen_len=max_gen_len,
+             temperature=0.7,
+             top_p=0.9,
+         )
+
+         answer = extract_answer(full_text)
+
+         results.append({
+             "question": question,
+             "full_response": full_text,
+             "answer": answer,
+             "is_finished": is_finished,
+         })
+
+     # Save results
+     with open(output_file, "w") as f:
+         json.dump(results, f, indent=2)
+
+     print(f"\n✓ Results saved to {output_file}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Interactive inference for the physics problem solver")
+     parser.add_argument("--config", type=str, required=True, help="Path to config YAML file")
+     parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint (optional)")
+     parser.add_argument("--batch", action="store_true", help="Run batch inference mode")
+     parser.add_argument("--questions", type=str, help="Path to questions file (for batch mode)")
+     parser.add_argument("--output", type=str, default="results.json", help="Output file (for batch mode)")
+
+     args = parser.parse_args()
+
+     # Load model and tokenizer
+     print("Loading model and tokenizer...")
+     model, tokenizer, device, dtype, config = load_model_and_tokenizer(
+         args.config,
+         args.checkpoint,
+     )
+     print("✓ Model loaded successfully!\n")
+
+     if args.batch:
+         if not args.questions:
+             print("Error: --questions file required for batch mode")
+             return
+         batch_inference_mode(model, tokenizer, device, dtype, config, args.questions, args.output)
+     else:
+         interactive_mode(model, tokenizer, device, dtype, config)
+
+
+ if __name__ == "__main__":
+     main()
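
Usage note: batch_inference.py drives both modes through the flags defined in main(). The invocations below are a minimal sketch; config.yaml and checkpoints/ckpt.pt are placeholder paths, not files included in this commit.

python batch_inference.py --config config.yaml
python batch_inference.py --config config.yaml --checkpoint checkpoints/ckpt.pt
python batch_inference.py --config config.yaml --batch --questions questions.txt --output results.json
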
questions.txt ADDED
@@ -0,0 +1,7 @@
+ A train moves 500 kilometers in 5 hours. What is its speed in meters per second?
+ Name the following organic compound: CH3CH2CH2OH.
+ Implement a Python program that outputs the shortest path in a graph.
+ Prove the Pythagorean theorem.
+ Current Position Details: White Pieces: King on h1, Rook on e1. Black Pieces: King on h8, Rook on g8. The Challenge: It is White's turn to move. Find the move that results in checkmate (Mate in 1).
+ Facing an $8 bet into a $20 pot on the A♠ 8♣ 4♠ flop holding K♠ 9♠, your pot odds are 28:8 (or 3.5-to-1), which is slightly less than the 4.1-to-1 needed for a pure 9-out flush draw, but the call is justified by implied odds since you can expect to win a larger bet if you hit your hand.
+
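
Given how batch_inference_mode assembles its records, each entry in the saved results.json should take roughly the following shape. The values here are illustrative, not actual model output; note that full_response includes the chat prompt prefix, since generate_response returns prefix + generated_text.

[
  {
    "question": "A train moves 500 kilometers in 5 hours. What is its speed in meters per second?",
    "full_response": "<chat prompt prefix>Let me solve this step by step.\n<think>...</think>\n<answer>...</answer>",
    "answer": "Approximately 27.8 m/s",
    "is_finished": true
  }
]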