Song Yi committed on
Commit
c913ed4
·
verified ·
1 Parent(s): 9e0ba17

Create inference_math.py

Browse files
Files changed (1) hide show
  1. inference_math.py +432 -0
inference_math.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kirim-1-Math Inference Script
3
+ Mathematical reasoning with tool calling capabilities
4
+ """
5
+
6
+ import torch
7
+ import json
8
+ import re
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from typing import List, Dict, Any, Optional
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+
15
class MathToolExecutor:
    """Execute mathematical tools called by the model.

    Each tool returns a plain string (result or error message) so the caller
    can feed the outcome straight back into the model conversation. SymPy and
    NumPy are optional dependencies; every tool degrades to a fixed message
    when they are missing instead of raising.

    SECURITY NOTE(review): ``sympy.sympify`` evaluates expression strings and
    is not safe on fully untrusted input; here the input comes from the model,
    which is assumed to be a trusted component — confirm before exposing this
    executor to external callers.
    """

    def __init__(self):
        # Import lazily so the rest of the script works without SymPy/NumPy;
        # the handlers check ``self.sp`` before doing any symbolic work.
        try:
            import sympy as sp
            import numpy as np
            self.sp = sp
            self.np = np
        except ImportError:
            print("Warning: SymPy or NumPy not installed. Tool execution limited.")
            self.sp = None
            self.np = None

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Dispatch ``tool_name`` to its handler and return its result string.

        Args:
            tool_name: One of calculator, symbolic_solver, derivative,
                integrate, simplify, latex_formatter.
            arguments: Tool-specific keyword arguments (all optional; each
                handler supplies defaults via ``dict.get``).

        Returns:
            The handler's result, or an error message string. Never raises.
        """
        # Table-driven dispatch replaces the original if/elif chain: adding a
        # tool is now a one-line change and lookup is O(1).
        handlers = {
            "calculator": self._calculator,
            "symbolic_solver": self._symbolic_solver,
            "derivative": self._derivative,
            "integrate": self._integrate,
            "simplify": self._simplify,
            "latex_formatter": self._latex_formatter,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            return f"Unknown tool: {tool_name}"
        try:
            return handler(arguments)
        except Exception as e:
            # Tool failures are reported as text, not exceptions, so the
            # surrounding generation loop keeps running.
            return f"Tool execution error: {str(e)}"

    def _calculator(self, args: Dict) -> str:
        """Numerically evaluate ``expression`` to ``precision`` digits."""
        expr = args.get("expression", "")
        precision = args.get("precision", 15)

        if not self.sp:
            return "SymPy not available"

        try:
            result = self.sp.sympify(expr)
            result = self.sp.N(result, precision)
            return f"Result: {result}"
        except Exception as e:
            return f"Calculation error: {e}"

    def _symbolic_solver(self, args: Dict) -> str:
        """Solve ``equation`` (expression assumed equal to 0) for ``variable``."""
        equation = args.get("equation", "")
        variable = args.get("variable", "x")

        if not self.sp:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            eq = self.sp.sympify(equation)
            solutions = self.sp.solve(eq, var)
            return f"Solutions: {solutions}"
        except Exception as e:
            return f"Solver error: {e}"

    def _derivative(self, args: Dict) -> str:
        """Differentiate ``function`` w.r.t. ``variable``, ``order`` times."""
        function = args.get("function", "")
        variable = args.get("variable", "x")
        order = args.get("order", 1)

        if not self.sp:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)
            result = self.sp.diff(func, var, order)
            return f"Derivative: {result}"
        except Exception as e:
            return f"Derivative error: {e}"

    def _integrate(self, args: Dict) -> str:
        """Integrate ``function``; definite when both bounds are given."""
        function = args.get("function", "")
        variable = args.get("variable", "x")
        lower = args.get("lower_bound")
        upper = args.get("upper_bound")

        if not self.sp:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)

            # Definite integral only when BOTH bounds are present; a single
            # bound silently falls back to the antiderivative (original
            # behavior, kept for compatibility).
            if lower is not None and upper is not None:
                result = self.sp.integrate(func, (var, lower, upper))
            else:
                result = self.sp.integrate(func, var)

            return f"Integral: {result}"
        except Exception as e:
            return f"Integration error: {e}"

    def _simplify(self, args: Dict) -> str:
        """Algebraically simplify ``expression``."""
        expression = args.get("expression", "")

        if not self.sp:
            return "SymPy not available"

        try:
            expr = self.sp.sympify(expression)
            result = self.sp.simplify(expr)
            return f"Simplified: {result}"
        except Exception as e:
            return f"Simplification error: {e}"

    def _latex_formatter(self, args: Dict) -> str:
        """Render ``expression`` as LaTeX, inline ($...$) or display ($$...$$)."""
        expression = args.get("expression", "")
        inline = args.get("inline", False)

        if not self.sp:
            return "SymPy not available"

        try:
            expr = self.sp.sympify(expression)
            latex = self.sp.latex(expr)

            if inline:
                return f"${latex}$"
            else:
                return f"$$\n{latex}\n$$"
        except Exception as e:
            return f"LaTeX formatting error: {e}"
153
+
154
class KirimMath:
    """Kirim-1-Math inference with tool calling.

    Loads a causal LM + tokenizer from HF hub (or a local path), generates
    solutions to math problems, and executes ``<tool_call>`` blocks emitted by
    the model via :class:`MathToolExecutor`, feeding results back for
    continuation.
    """

    def __init__(
        self,
        model_path: str = "Kirim-ai/Kirim-1-Math",
        device: str = "auto",
        load_in_8bit: bool = False,
        load_in_4bit: bool = False
    ):
        """Load tokenizer and model and attach the tool executor.

        Args:
            model_path: HF hub id or local checkpoint directory.
            device: "auto" for accelerate device_map placement, otherwise an
                explicit device string (e.g. "cuda:0", "cpu").
            load_in_8bit: Quantized 8-bit loading (takes precedence over 4-bit
                if both flags are set).
            load_in_4bit: Quantized 4-bit loading.
        """
        print(f"Loading Kirim-1-Math from {model_path}...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=True
        )

        # Configure model loading
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "low_cpu_mem_usage": True,
        }

        # NOTE(review): passing load_in_8bit/load_in_4bit as bare kwargs is
        # deprecated in recent transformers in favor of
        # quantization_config=BitsAndBytesConfig(...); kept as-is for
        # compatibility with the transformers version this was written for.
        if load_in_8bit:
            model_kwargs["load_in_8bit"] = True
            print("Loading in 8-bit mode (30GB VRAM)")
        elif load_in_4bit:
            model_kwargs["load_in_4bit"] = True
            print("Loading in 4-bit mode (20GB VRAM)")
        else:
            print("Loading in full precision (80GB VRAM)")

        if device == "auto":
            model_kwargs["device_map"] = "auto"

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **model_kwargs
        )

        # Quantized (bitsandbytes) models are already placed on devices at
        # load time and must not be .to()-moved; only move for an explicit
        # non-auto device in full precision.
        if device != "auto" and not (load_in_8bit or load_in_4bit):
            self.model = self.model.to(device)

        self.model.eval()

        # Initialize tool executor
        self.tool_executor = MathToolExecutor()

        print("✓ Model loaded successfully!")
        print("✓ Tool calling enabled\n")

    def solve_problem(
        self,
        problem: str,
        show_work: bool = True,
        use_tools: bool = True,
        max_new_tokens: int = 4096,
        temperature: float = 0.1
    ) -> str:
        """
        Solve a mathematical problem

        Args:
            problem: Math problem to solve
            show_work: Show step-by-step solution
            use_tools: Enable tool calling
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more deterministic)

        Returns:
            Solution with reasoning
        """
        # Construct prompt
        system_prompt = "You are Kirim-1-Math, an advanced mathematical reasoning AI. "

        if show_work:
            system_prompt += "Show your work step-by-step. "

        if use_tools:
            system_prompt += "You can use tools for calculations. Available tools: calculator, symbolic_solver, derivative, integrate, simplify."

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": problem}
        ]

        # Generate initial response
        response = self._generate(messages, max_new_tokens, temperature)

        # Check for tool calls
        if use_tools and "<tool_call>" in response:
            response = self._handle_tool_calls(response, messages, max_new_tokens, temperature)

        return response

    def _generate(self, messages: List[Dict], max_new_tokens: int, temperature: float) -> str:
        """Run one generation pass over ``messages`` and return the new text.

        Applies the tokenizer's chat template, truncates the prompt to the
        model context budget, and strips everything up to the last
        ``<|assistant|>`` marker from the decoded output.
        """
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=28672  # leave headroom below the 32k context for generation
        )

        if hasattr(self.model, 'device'):
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # FIX: many causal-LM tokenizers define no pad token, in which case
        # generate() warns or fails when pad_token_id is None — fall back to
        # the EOS token id, the standard workaround for decoder-only models.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": 0.95,
            "do_sample": temperature > 0,  # greedy decoding at temperature 0
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Keep special tokens so the assistant marker below is still visible.
        full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract assistant response
        if "<|assistant|>" in full_response:
            response = full_response.split("<|assistant|>")[-1]
            response = response.replace("<|end_of_text|>", "").strip()
            return response

        return full_response.strip()

    def _handle_tool_calls(self, response: str, messages: List[Dict], max_new_tokens: int, temperature: float) -> str:
        """Execute ``<tool_call>`` blocks in ``response`` and regenerate.

        Each tool call is parsed as JSON ({"name": ..., "arguments": {...}}),
        executed, and its result appended to the conversation before asking
        the model to continue.

        NOTE(review): the loop iterates tool calls extracted from the INITIAL
        response only; calls emitted in regenerated continuations are not
        re-scanned — confirm whether a re-scan loop is intended.
        """
        # Extract tool calls
        tool_pattern = r'<tool_call>(.*?)</tool_call>'
        tool_calls = re.findall(tool_pattern, response, re.DOTALL)

        if not tool_calls:
            return response

        # Execute each tool call
        for tool_call_str in tool_calls:
            try:
                tool_call = json.loads(tool_call_str.strip())
                tool_name = tool_call.get("name", "")
                arguments = tool_call.get("arguments", {})

                print(f"\n🔧 Executing tool: {tool_name}")
                print(f"   Arguments: {arguments}")

                # Execute tool
                result = self.tool_executor.execute_tool(tool_name, arguments)

                print(f"   Result: {result}\n")

                # Add tool result to messages
                messages.append({"role": "assistant", "content": response})
                messages.append({"role": "tool", "content": f"<tool_result>{result}</tool_result>"})

                # Generate continuation with tool result
                response = self._generate(messages, max_new_tokens, temperature)

            except json.JSONDecodeError:
                # Malformed tool JSON is skipped rather than aborting the solve.
                print(f"⚠️ Failed to parse tool call: {tool_call_str}")
                continue

        return response

    def interactive_math(self):
        """Interactive REPL: read problems from stdin and print solutions."""
        print("\n" + "="*60)
        print("  Kirim-1-Math - Interactive Mode")
        print("  First model with tool calling!")
        print("="*60)
        print("\nCommands:")
        print("  'quit' or 'exit' - End session")
        print("  'tools off/on' - Toggle tool calling")
        print("  'work off/on' - Toggle showing work")
        print("\n" + "="*60 + "\n")

        use_tools = True
        show_work = True

        while True:
            try:
                user_input = input("Problem: ").strip()

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("\nGoodbye! Happy solving! 🧮\n")
                    break

                if user_input.lower().startswith('tools'):
                    use_tools = 'on' in user_input.lower()
                    print(f"✓ Tool calling: {'enabled' if use_tools else 'disabled'}\n")
                    continue

                if user_input.lower().startswith('work'):
                    show_work = 'on' in user_input.lower()
                    print(f"✓ Show work: {'enabled' if show_work else 'disabled'}\n")
                    continue

                if not user_input:
                    continue

                # Solve problem
                print("\n" + "-"*60)
                solution = self.solve_problem(
                    user_input,
                    show_work=show_work,
                    use_tools=use_tools
                )
                print(solution)
                print("-"*60 + "\n")

            except KeyboardInterrupt:
                print("\n\nGoodbye! 🧮\n")
                break
            except Exception as e:
                # Keep the REPL alive on per-problem failures.
                print(f"\n❌ Error: {e}\n")
382
+
383
def main():
    """CLI entry point: parse flags, load the model, then run interactive
    mode, a single problem, or the built-in demo set."""
    import argparse

    parser = argparse.ArgumentParser(description="Kirim-1-Math Inference")
    parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math")
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--load_in_4bit", action="store_true")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--problem", type=str, help="Single problem to solve")
    opts = parser.parse_args()

    # Instantiate the solver once; all three modes below share it.
    solver = KirimMath(
        model_path=opts.model_path,
        device=opts.device,
        load_in_8bit=opts.load_in_8bit,
        load_in_4bit=opts.load_in_4bit
    )

    # Guard-clause dispatch: interactive beats single-problem beats demos.
    if opts.interactive:
        solver.interactive_math()
        return

    if opts.problem:
        solution = solver.solve_problem(opts.problem)
        print(f"\nProblem: {opts.problem}")
        print(f"\nSolution:\n{solution}\n")
        return

    _run_demos(solver)


def _run_demos(solver):
    """Solve a small fixed set of showcase problems and print each solution."""
    print("="*60)
    print("  Demo Examples")
    print("="*60 + "\n")

    demos = (
        "Solve: x² - 5x + 6 = 0",
        "Calculate the derivative of x³ + 2x² - x + 1",
        "解方程: 2x + 3y = 12, 4x - y = 5",
        "Integrate: ∫(x² + 1)dx",
    )

    for problem in demos:
        print(f"\nProblem: {problem}")
        print("-" * 60)
        print(solver.solve_problem(problem))
        print("=" * 60)


if __name__ == "__main__":
    main()