Kirim-1-Math / inference_math.py
Kirim1's picture
Create inference_math.py
c913ed4 verified
"""
Kirim-1-Math Inference Script
Mathematical reasoning with tool calling capabilities
"""
import torch
import json
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict, Any, Optional
import warnings
# NOTE(review): this silences *all* warnings process-wide (including any
# deprecation warnings emitted by transformers/torch) as a side effect of
# importing this module; narrow the filter if those warnings are wanted.
warnings.filterwarnings('ignore')
class MathToolExecutor:
    """Execute mathematical tools called by the model.

    Each tool receives the ``arguments`` object decoded from a model
    ``<tool_call>`` and returns a human-readable result string. Failures are
    reported as strings instead of being raised, so a malformed tool call never
    aborts generation.

    SECURITY NOTE: expressions are parsed with ``sympy.sympify``, which uses
    ``eval`` internally. Tool arguments are model output and can be steered by
    the user's prompt, so do not expose this executor to untrusted users
    without sandboxing.
    """

    # Tool name -> handler method name.
    _TOOL_METHODS = {
        "calculator": "_calculator",
        "symbolic_solver": "_symbolic_solver",
        "derivative": "_derivative",
        "integrate": "_integrate",
        "simplify": "_simplify",
        "latex_formatter": "_latex_formatter",
    }

    # Matches a single "=" or "==" used as an equation separator, but not the
    # "=" inside "<=", ">=" or "!=".
    _EQUALS_RE = re.compile(r"(?<![<>!=])==?(?!=)")

    def __init__(self):
        # Import separately: only SymPy is used by the tools, so a missing
        # NumPy must not disable them.
        try:
            import sympy as sp
        except ImportError:
            print("Warning: SymPy not installed. Tool execution limited.")
            sp = None
        try:
            import numpy as np
        except ImportError:
            np = None
        self.sp = sp
        self.np = np

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Execute a tool and return its result string.

        Args:
            tool_name: One of the keys of ``_TOOL_METHODS``.
            arguments: Tool arguments as a dict, or as a JSON-encoded object
                string (the OpenAI-style tool-call format). ``None`` is
                treated as no arguments.

        Returns:
            The tool's output, ``"Unknown tool: ..."`` for unsupported names,
            or ``"Tool execution error: ..."`` if the call itself failed.
        """
        handler_name = self._TOOL_METHODS.get(tool_name) if isinstance(tool_name, str) else None
        if handler_name is None:
            return f"Unknown tool: {tool_name}"
        try:
            args = self._coerce_arguments(arguments)
            return getattr(self, handler_name)(args)
        except Exception as e:
            return f"Tool execution error: {str(e)}"

    @staticmethod
    def _coerce_arguments(arguments: Any) -> Dict[str, Any]:
        """Normalize tool arguments to a dict (decoding JSON strings)."""
        if arguments is None:
            return {}
        if isinstance(arguments, str):
            arguments = json.loads(arguments) if arguments.strip() else {}
        if not isinstance(arguments, dict):
            raise TypeError(f"arguments must be a JSON object, got {type(arguments).__name__}")
        return arguments

    @classmethod
    def _split_equation(cls, text: Any) -> List[str]:
        """Split ``"lhs = rhs"`` (or ``"lhs == rhs"``) into its sides.

        Text without a bare equals sign (including inequalities) is returned
        as a single-element list.
        """
        return cls._EQUALS_RE.split(str(text))

    def _parse_equation(self, text: Any):
        """Sympify an expression or an ``lhs = rhs`` equation into ``Eq``."""
        sides = self._split_equation(text)
        if len(sides) == 1:
            return self.sp.sympify(sides[0])
        if len(sides) == 2:
            return self.sp.Eq(self.sp.sympify(sides[0]), self.sp.sympify(sides[1]))
        raise ValueError(f"expected at most one '=' in equation: {text!r}")

    def _parse_variables(self, variable: Any):
        """Return one Symbol, or a list of Symbols for list / comma-separated input."""
        if isinstance(variable, (list, tuple)):
            return [self.sp.Symbol(str(name).strip()) for name in variable]
        if isinstance(variable, str) and "," in variable:
            return [self.sp.Symbol(name.strip()) for name in variable.split(",") if name.strip()]
        return self.sp.Symbol(variable)

    def _calculator(self, args: Dict) -> str:
        """Evaluate ``expression`` numerically to ``precision`` significant digits (default 15)."""
        expr = args.get("expression", "")
        precision = args.get("precision", 15)
        if not self.sp:
            return "SymPy not available"
        try:
            result = self.sp.sympify(expr)
            # JSON may deliver the precision as a string or float.
            result = self.sp.N(result, int(precision))
            return f"Result: {result}"
        except Exception as e:
            return f"Calculation error: {e}"

    def _symbolic_solver(self, args: Dict) -> str:
        """Solve ``equation`` for ``variable`` (default ``"x"``).

        ``equation`` may be an expression (solved for ``= 0``), an
        ``lhs = rhs`` string, or a list of those for a system. ``variable``
        may be a name, a comma-separated string, or a list of names.
        """
        equation = args.get("equation", "")
        variable = args.get("variable", "x")
        if not self.sp:
            return "SymPy not available"
        try:
            var = self._parse_variables(variable)
            if isinstance(equation, (list, tuple)):
                eq = [self._parse_equation(item) for item in equation]
            else:
                eq = self._parse_equation(equation)
            solutions = self.sp.solve(eq, var)
            return f"Solutions: {solutions}"
        except Exception as e:
            return f"Solver error: {e}"

    def _derivative(self, args: Dict) -> str:
        """Differentiate ``function`` ``order`` times (default 1) w.r.t. ``variable``."""
        function = args.get("function", "")
        variable = args.get("variable", "x")
        order = args.get("order", 1)
        if not self.sp:
            return "SymPy not available"
        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)
            # JSON may deliver the order as a string or float.
            result = self.sp.diff(func, var, int(order))
            return f"Derivative: {result}"
        except Exception as e:
            return f"Derivative error: {e}"

    def _integrate(self, args: Dict) -> str:
        """Integrate ``function`` w.r.t. ``variable``.

        A definite integral is computed only when both ``lower_bound`` and
        ``upper_bound`` are given; otherwise the antiderivative is returned.
        """
        function = args.get("function", "")
        variable = args.get("variable", "x")
        lower = args.get("lower_bound")
        upper = args.get("upper_bound")
        if not self.sp:
            return "SymPy not available"
        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)
            if lower is not None and upper is not None:
                result = self.sp.integrate(func, (var, lower, upper))
            else:
                result = self.sp.integrate(func, var)
            return f"Integral: {result}"
        except Exception as e:
            return f"Integration error: {e}"

    def _simplify(self, args: Dict) -> str:
        """Simplify ``expression`` with ``sympy.simplify``."""
        expression = args.get("expression", "")
        if not self.sp:
            return "SymPy not available"
        try:
            expr = self.sp.sympify(expression)
            result = self.sp.simplify(expr)
            return f"Simplified: {result}"
        except Exception as e:
            return f"Simplification error: {e}"

    def _latex_formatter(self, args: Dict) -> str:
        """Render ``expression`` as LaTeX, inline (``$...$``) or display (``$$...$$``)."""
        expression = args.get("expression", "")
        inline = args.get("inline", False)
        if not self.sp:
            return "SymPy not available"
        try:
            expr = self.sp.sympify(expression)
            latex = self.sp.latex(expr)
            if inline:
                return f"${latex}$"
            return f"$$\n{latex}\n$$"
        except Exception as e:
            return f"LaTeX formatting error: {e}"
class KirimMath:
    """Kirim-1-Math inference with tool calling.

    The model may emit ``<tool_call>{"name": ..., "arguments": {...}}</tool_call>``
    spans. Each span is executed by ``MathToolExecutor``, its result is fed back
    as a ``tool`` message, and the model continues from there.
    """

    # Non-greedy + DOTALL so several (possibly multi-line) calls in one
    # response are captured separately.
    _TOOL_CALL_RE = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)

    # Tool names advertised in the system prompt; keep in sync with the
    # tools MathToolExecutor implements.
    _AVAILABLE_TOOLS = "calculator, symbolic_solver, derivative, integrate, simplify, latex_formatter"

    def __init__(
        self,
        model_path: str = "Kirim-ai/Kirim-1-Math",
        device: str = "auto",
        load_in_8bit: bool = False,
        load_in_4bit: bool = False
    ):
        """Load the tokenizer and model.

        Args:
            model_path: Hugging Face hub id or local path of the model.
            device: ``"auto"`` passes ``device_map="auto"`` to ``from_pretrained``;
                any other value is a device the full-precision model is moved to.
            load_in_8bit: Quantize to 8-bit (takes precedence over 4-bit).
            load_in_4bit: Quantize to 4-bit.
        """
        print(f"Loading Kirim-1-Math from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=True
        )
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "low_cpu_mem_usage": True,
        }
        if load_in_8bit or load_in_4bit:
            # Passing load_in_8bit/load_in_4bit straight to from_pretrained is
            # deprecated; a BitsAndBytesConfig is the supported form.
            from transformers import BitsAndBytesConfig
            if load_in_8bit:
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                print("Loading in 8-bit mode (30GB VRAM)")
            else:
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
                print("Loading in 4-bit mode (20GB VRAM)")
        else:
            print("Loading in full precision (80GB VRAM)")
        if device == "auto":
            model_kwargs["device_map"] = "auto"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **model_kwargs
        )
        # Quantized models are left where from_pretrained placed them.
        if device != "auto" and not (load_in_8bit or load_in_4bit):
            self.model = self.model.to(device)
        self.model.eval()
        self.tool_executor = MathToolExecutor()
        print("✓ Model loaded successfully!")
        print("✓ Tool calling enabled\n")

    def solve_problem(
        self,
        problem: str,
        show_work: bool = True,
        use_tools: bool = True,
        max_new_tokens: int = 4096,
        temperature: float = 0.1
    ) -> str:
        """
        Solve a mathematical problem

        Args:
            problem: Math problem to solve
            show_work: Show step-by-step solution
            use_tools: Enable tool calling
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more deterministic;
                <= 0 means greedy decoding)

        Returns:
            Solution with reasoning (the model's final message after any
            tool calls have been resolved)
        """
        system_prompt = "You are Kirim-1-Math, an advanced mathematical reasoning AI. "
        if show_work:
            system_prompt += "Show your work step-by-step. "
        if use_tools:
            system_prompt += f"You can use tools for calculations. Available tools: {self._AVAILABLE_TOOLS}."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": problem}
        ]
        response = self._generate(messages, max_new_tokens, temperature)
        if use_tools and "<tool_call>" in response:
            response = self._handle_tool_calls(response, messages, max_new_tokens, temperature)
        return response

    def _generate(self, messages: List[Dict], max_new_tokens: int, temperature: float) -> str:
        """Run one chat turn and return only the newly generated text."""
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # The rendered template already contains the model's special tokens,
        # so don't let the tokenizer add them (e.g. a second BOS) again.
        # NOTE(review): with the tokenizer's default right-side truncation an
        # overlong prompt would lose its tail (including the generation
        # prompt) — confirm 28672 leaves room for the conversation.
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=28672,
            add_special_tokens=False
        )
        if hasattr(self.model, 'device'):
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Sampling parameters are meaningless (and warned about) for greedy decoding.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.95
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation; keep only the continuation
        # instead of searching the decoded text for a template-specific marker.
        prompt_length = inputs["input_ids"].shape[-1]
        # Special tokens are kept so that <tool_call> markers survive decoding.
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)
        for end_marker in ("<|end_of_text|>", self.tokenizer.eos_token):
            if end_marker:
                response = response.replace(end_marker, "")
        return response.strip()

    def _run_tool_call(self, tool_call_str: str) -> Optional[str]:
        """Parse and execute one ``<tool_call>`` payload.

        Returns the tool result string, or None when the payload is not a JSON
        object (a warning is printed and the call is skipped).
        """
        try:
            tool_call = json.loads(tool_call_str.strip())
        except json.JSONDecodeError:
            tool_call = None
        if not isinstance(tool_call, dict):
            print(f"⚠️ Failed to parse tool call: {tool_call_str}")
            return None
        tool_name = tool_call.get("name", "")
        arguments = tool_call.get("arguments", {})
        print(f"\n🔧 Executing tool: {tool_name}")
        print(f" Arguments: {arguments}")
        result = self.tool_executor.execute_tool(tool_name, arguments)
        print(f" Result: {result}\n")
        return result

    def _handle_tool_calls(
        self,
        response: str,
        messages: List[Dict],
        max_new_tokens: int,
        temperature: float,
        max_rounds: int = 5
    ) -> str:
        """Resolve tool calls until the model answers without one.

        In each round every tool call in ``response`` is executed. The
        assistant message is appended to ``messages`` once, followed by one
        ``tool`` message per result, and the model is asked to continue.
        Continuations that request more tools are handled in later rounds,
        up to ``max_rounds``.

        Args:
            response: Model output containing ``<tool_call>`` spans.
            messages: Conversation so far; mutated in place.
            max_new_tokens: Forwarded to ``_generate``.
            temperature: Forwarded to ``_generate``.
            max_rounds: Upper bound on generate/execute cycles, to avoid
                looping forever on a model that keeps calling tools.

        Returns:
            The last generated response. If no call in a round could be
            parsed, that round's response is returned unchanged.
        """
        for _ in range(max_rounds):
            tool_calls = self._TOOL_CALL_RE.findall(response)
            if not tool_calls:
                break
            results = [r for r in map(self._run_tool_call, tool_calls) if r is not None]
            if not results:
                break
            messages.append({"role": "assistant", "content": response})
            messages.extend(
                {"role": "tool", "content": f"<tool_result>{result}</tool_result>"}
                for result in results
            )
            response = self._generate(messages, max_new_tokens, temperature)
        return response

    def interactive_math(self):
        """Interactive math problem solver (REPL on stdin).

        Commands are matched exactly (case-insensitive), so problems that merely
        start with "work" or "tools" are still solved.
        """
        print("\n" + "="*60)
        print(" Kirim-1-Math - Interactive Mode")
        print(" First model with tool calling!")
        print("="*60)
        print("\nCommands:")
        print(" 'quit' or 'exit' - End session")
        print(" 'tools off/on' - Toggle tool calling")
        print(" 'work off/on' - Toggle showing work")
        print("\n" + "="*60 + "\n")
        use_tools = True
        show_work = True
        while True:
            try:
                user_input = input("Problem: ").strip()
                # Normalized form used only for command matching.
                command = " ".join(user_input.lower().split())
                if command in ('quit', 'exit', 'q'):
                    print("\nGoodbye! Happy solving! 🧮\n")
                    break
                if command in ('tools on', 'tools off'):
                    use_tools = command == 'tools on'
                    print(f"✓ Tool calling: {'enabled' if use_tools else 'disabled'}\n")
                    continue
                if command in ('work on', 'work off'):
                    show_work = command == 'work on'
                    print(f"✓ Show work: {'enabled' if show_work else 'disabled'}\n")
                    continue
                if not user_input:
                    continue
                print("\n" + "-"*60)
                solution = self.solve_problem(
                    user_input,
                    show_work=show_work,
                    use_tools=use_tools
                )
                print(solution)
                print("-"*60 + "\n")
            except (KeyboardInterrupt, EOFError):
                # EOF (Ctrl-D / end of piped input) would otherwise make input()
                # fail on every iteration forever.
                print("\n\nGoodbye! 🧮\n")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}\n")
def main():
    """Command-line entry point.

    Loads the model, then runs the interactive REPL (``--interactive``), a
    single problem (``--problem``), or a fixed list of demo problems.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Kirim-1-Math Inference")
    parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math",
                        help="Hugging Face hub id or local path of the model")
    parser.add_argument("--device", type=str, default="auto",
                        help="'auto' for automatic placement, or a torch device such as 'cuda:0'")
    # Only one quantization mode can be active; previously passing both
    # silently selected 8-bit.
    quantization = parser.add_mutually_exclusive_group()
    quantization.add_argument("--load_in_8bit", action="store_true",
                              help="Load weights quantized to 8-bit")
    quantization.add_argument("--load_in_4bit", action="store_true",
                              help="Load weights quantized to 4-bit")
    parser.add_argument("--interactive", action="store_true",
                        help="Start the interactive problem solver")
    parser.add_argument("--problem", type=str, help="Single problem to solve")
    args = parser.parse_args()

    kirim_math = KirimMath(
        model_path=args.model_path,
        device=args.device,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )
    if args.interactive:
        kirim_math.interactive_math()
    elif args.problem:
        solution = kirim_math.solve_problem(args.problem)
        print(f"\nProblem: {args.problem}")
        print(f"\nSolution:\n{solution}\n")
    else:
        # Demo examples (one is intentionally non-English).
        print("="*60)
        print(" Demo Examples")
        print("="*60 + "\n")
        demos = [
            "Solve: x² - 5x + 6 = 0",
            "Calculate the derivative of x³ + 2x² - x + 1",
            "解方程: 2x + 3y = 12, 4x - y = 5",
            "Integrate: ∫(x² + 1)dx"
        ]
        for problem in demos:
            print(f"\nProblem: {problem}")
            print("-" * 60)
            solution = kirim_math.solve_problem(problem)
            print(solution)
            print("=" * 60)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()