|
|
""" |
|
|
Kirim-1-Math Inference Script |
|
|
Mathematical reasoning with tool calling capabilities |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import json |
|
|
import re |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from typing import List, Dict, Any, Optional |
|
|
import warnings |
|
|
# Suppress all Python warnings process-wide. This runs at import time, so it
# also silences warnings for any program that imports this module.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
class MathToolExecutor:
    """Execute mathematical tools called by the model.

    Each tool receives the ``arguments`` object decoded from a model tool call
    and returns a human-readable result string. Errors are reported as strings
    instead of being raised, so a malformed tool call never aborts generation.

    SECURITY NOTE: ``sympy.sympify`` parses its input with ``eval``. Every
    expression handled here originates from model output; do not run this
    executor on prompts from untrusted users without a sandbox.
    """

    # A single assignment-style "=" (not part of "==", "<=", ">=" or "!=").
    _EQUALS_PATTERN = re.compile(r"(?<![<>!=])=(?!=)")

    def __init__(self):
        """Import SymPy/NumPy if available; tools report an error without SymPy."""
        # Import independently: SymPy alone powers every tool, so a missing
        # NumPy must not disable tool execution.
        try:
            import sympy as sp
        except ImportError:
            sp = None
        try:
            import numpy as np
        except ImportError:
            np = None

        self.sp = sp  # sympy module, or None when not installed
        self.np = np  # numpy module, or None (not used by the current tools)

        if sp is None:
            print("Warning: SymPy not installed. Tool execution limited.")

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Dispatch ``tool_name`` to its implementation and return the result.

        Args:
            tool_name: One of ``calculator``, ``symbolic_solver``,
                ``derivative``, ``integrate``, ``simplify``,
                ``latex_formatter``.
            arguments: Tool arguments; ``None`` is treated as an empty dict.

        Returns:
            The tool's result string, ``"Unknown tool: ..."`` for an
            unrecognised name, or ``"Tool execution error: ..."`` when the
            arguments are not an object or the tool raised.
        """
        handlers = {
            "calculator": self._calculator,
            "symbolic_solver": self._symbolic_solver,
            "derivative": self._derivative,
            "integrate": self._integrate,
            "simplify": self._simplify,
            "latex_formatter": self._latex_formatter,
        }
        # Model output may put a non-string (even unhashable) value in "name".
        handler = handlers.get(tool_name) if isinstance(tool_name, str) else None
        if handler is None:
            return f"Unknown tool: {tool_name}"

        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            return (
                "Tool execution error: arguments must be a JSON object, "
                f"got {type(arguments).__name__}"
            )

        try:
            return handler(arguments)
        except Exception as e:
            return f"Tool execution error: {str(e)}"

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """Coerce a JSON-supplied integer argument (possibly a string) to ``int``."""
        if value is None:
            return default
        return int(value)

    @classmethod
    def _split_equation(cls, equation: Any) -> Optional[tuple]:
        """Split ``"lhs = rhs"`` into ``(lhs, rhs)``; return ``None`` if not an equation.

        ``sympify`` cannot parse ``=``, so equations must be split first. Input
        with zero or several bare ``=``, an ``=`` inside parentheses (e.g. a
        keyword argument), or an empty side is left for ``sympify`` to handle.
        """
        if not isinstance(equation, str):
            return None
        parts = cls._EQUALS_PATTERN.split(equation)
        if len(parts) != 2:
            return None
        lhs, rhs = parts[0].strip(), parts[1].strip()
        if not lhs or not rhs or lhs.count("(") != lhs.count(")"):
            return None
        return lhs, rhs

    def _calculator(self, args: Dict) -> str:
        """Evaluate ``expression`` numerically to ``precision`` significant digits (default 15)."""
        expr = args.get("expression", "")
        precision = args.get("precision", 15)

        if self.sp is None:
            return "SymPy not available"

        try:
            result = self.sp.sympify(expr)
            result = self.sp.N(result, self._as_int(precision, 15))
            return f"Result: {result}"
        except Exception as e:
            return f"Calculation error: {e}"

    def _symbolic_solver(self, args: Dict) -> str:
        """Solve ``equation`` for ``variable`` (default ``x``).

        ``equation`` may be an expression assumed equal to zero or an
        ``"lhs = rhs"`` string.
        """
        equation = args.get("equation", "")
        variable = args.get("variable", "x")

        if self.sp is None:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            sides = self._split_equation(equation)
            if sides is None:
                eq = self.sp.sympify(equation)
            else:
                lhs, rhs = sides
                eq = self.sp.Eq(self.sp.sympify(lhs), self.sp.sympify(rhs))
            solutions = self.sp.solve(eq, var)
            return f"Solutions: {solutions}"
        except Exception as e:
            return f"Solver error: {e}"

    def _derivative(self, args: Dict) -> str:
        """Differentiate ``function`` ``order`` times (default 1) w.r.t. ``variable``."""
        function = args.get("function", "")
        variable = args.get("variable", "x")
        order = args.get("order", 1)

        if self.sp is None:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)
            result = self.sp.diff(func, var, self._as_int(order, 1))
            return f"Derivative: {result}"
        except Exception as e:
            return f"Derivative error: {e}"

    def _integrate(self, args: Dict) -> str:
        """Integrate ``function`` w.r.t. ``variable``.

        The integral is definite only when both ``lower_bound`` and
        ``upper_bound`` are given; otherwise an antiderivative is returned.
        """
        function = args.get("function", "")
        variable = args.get("variable", "x")
        lower = args.get("lower_bound")
        upper = args.get("upper_bound")

        if self.sp is None:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)

            if lower is not None and upper is not None:
                # Bounds may arrive as strings such as "0" or "oo".
                limits = (var, self.sp.sympify(lower), self.sp.sympify(upper))
                result = self.sp.integrate(func, limits)
            else:
                result = self.sp.integrate(func, var)

            return f"Integral: {result}"
        except Exception as e:
            return f"Integration error: {e}"

    def _simplify(self, args: Dict) -> str:
        """Return ``sympy.simplify`` applied to ``expression``."""
        expression = args.get("expression", "")

        if self.sp is None:
            return "SymPy not available"

        try:
            expr = self.sp.sympify(expression)
            result = self.sp.simplify(expr)
            return f"Simplified: {result}"
        except Exception as e:
            return f"Simplification error: {e}"

    def _latex_formatter(self, args: Dict) -> str:
        """Render ``expression`` as LaTeX, inline (``$...$``) or display (``$$...$$``)."""
        expression = args.get("expression", "")
        inline = args.get("inline", False)

        if self.sp is None:
            return "SymPy not available"

        try:
            expr = self.sp.sympify(expression)
            latex = self.sp.latex(expr)

            if inline:
                return f"${latex}$"
            return f"$$\n{latex}\n$$"
        except Exception as e:
            return f"LaTeX formatting error: {e}"
|
|
|
|
|
|
|
|
class KirimMath:
    """Kirim-1-Math inference with tool calling.

    Wraps a Hugging Face causal language model and its tokenizer. When tool
    use is enabled the model may emit
    ``<tool_call>{"name": ..., "arguments": ...}</tool_call>`` blocks; these
    are executed by ``MathToolExecutor`` and the results are fed back as
    ``tool`` messages before the model generates again.
    """

    # Non-greedy and DOTALL so several (possibly multi-line) calls in one
    # reply are captured separately.
    _TOOL_CALL_PATTERN = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)

    def __init__(
        self,
        model_path: str = "Kirim-ai/Kirim-1-Math",
        device: str = "auto",
        load_in_8bit: bool = False,
        load_in_4bit: bool = False
    ):
        """Load the tokenizer and model.

        Args:
            model_path: Hugging Face Hub id or local directory of the model.
            device: ``"auto"`` sets ``device_map="auto"``; any other value
                (e.g. ``"cuda:0"``, ``"cpu"``) is passed to ``model.to`` for
                non-quantized models.
            load_in_8bit: Load weights quantized to 8 bits via bitsandbytes.
            load_in_4bit: Load weights quantized to 4 bits via bitsandbytes
                (ignored when ``load_in_8bit`` is also set).
        """
        print(f"Loading Kirim-1-Math from {model_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=True
        )
        # Over-long prompts must lose their *oldest* tokens: right-side
        # truncation would drop the latest turn and the generation prompt.
        self.tokenizer.truncation_side = "left"

        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "low_cpu_mem_usage": True,
        }

        if load_in_8bit or load_in_4bit:
            # Passing load_in_8bit/load_in_4bit directly to from_pretrained is
            # deprecated in transformers; a BitsAndBytesConfig is expected.
            from transformers import BitsAndBytesConfig

            if load_in_8bit:
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                print("Loading in 8-bit mode (30GB VRAM)")
            else:
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
                print("Loading in 4-bit mode (20GB VRAM)")
        else:
            print("Loading in full precision (80GB VRAM)")

        if device == "auto":
            model_kwargs["device_map"] = "auto"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **model_kwargs
        )

        # Quantized models cannot be moved with .to(); leave them where loaded.
        if device != "auto" and not (load_in_8bit or load_in_4bit):
            self.model = self.model.to(device)

        self.model.eval()

        self.tool_executor = MathToolExecutor()

        print("✓ Model loaded successfully!")
        print("✓ Tool calling enabled\n")

    def solve_problem(
        self,
        problem: str,
        show_work: bool = True,
        use_tools: bool = True,
        max_new_tokens: int = 4096,
        temperature: float = 0.1
    ) -> str:
        """
        Solve a mathematical problem

        Args:
            problem: Math problem to solve
            show_work: Show step-by-step solution
            use_tools: Enable tool calling
            max_new_tokens: Maximum tokens to generate per generation round
            temperature: Sampling temperature (lower = more deterministic;
                0 selects greedy decoding)

        Returns:
            Solution with reasoning (the model's final reply)
        """
        system_prompt = "You are Kirim-1-Math, an advanced mathematical reasoning AI. "

        if show_work:
            system_prompt += "Show your work step-by-step. "

        if use_tools:
            system_prompt += "You can use tools for calculations. Available tools: calculator, symbolic_solver, derivative, integrate, simplify."

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": problem}
        ]

        response = self._generate(messages, max_new_tokens, temperature)

        if use_tools and "<tool_call>" in response:
            response = self._handle_tool_calls(response, messages, max_new_tokens, temperature)

        return response

    def _generate(self, messages: List[Dict], max_new_tokens: int, temperature: float) -> str:
        """Run one generation over ``messages`` and return only the new text.

        Special tokens are kept when decoding so ``<tool_call>`` markup
        survives; end-of-text markers are stripped.
        """
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The chat template already contains any BOS/special tokens; adding
        # them again during tokenization would duplicate them.
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=28672,
            add_special_tokens=False
        )

        if hasattr(self.model, 'device'):
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Many causal-LM tokenizers define no pad token.
            pad_token_id = self.tokenizer.eos_token_id

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Sampling parameters are only meaningful (and only accepted without
        # warnings) when sampling is enabled.
        if temperature > 0:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.95

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation; decode only the
        # continuation so the prompt never leaks into the answer regardless
        # of the chat template's role markers.
        prompt_length = inputs["input_ids"].shape[-1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

        for marker in ("<|end_of_text|>", self.tokenizer.eos_token):
            if marker:
                response = response.replace(marker, "")

        return response.strip()

    def _run_tool_call(self, tool_call_str: str) -> Optional[str]:
        """Parse and execute one ``<tool_call>`` payload; ``None`` if unparseable."""
        try:
            tool_call = json.loads(tool_call_str.strip())
        except json.JSONDecodeError:
            tool_call = None

        if not isinstance(tool_call, dict):
            print(f"⚠️ Failed to parse tool call: {tool_call_str}")
            return None

        tool_name = tool_call.get("name", "")
        arguments = tool_call.get("arguments", {})
        if isinstance(arguments, str):
            # Some models serialize arguments as a JSON string (OpenAI style).
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                pass  # the executor reports the non-object arguments

        print(f"\n🔧 Executing tool: {tool_name}")
        print(f" Arguments: {arguments}")

        result = self.tool_executor.execute_tool(tool_name, arguments)

        print(f" Result: {result}\n")
        return result

    def _handle_tool_calls(
        self,
        response: str,
        messages: List[Dict],
        max_new_tokens: int,
        temperature: float,
        max_tool_rounds: int = 5
    ) -> str:
        """Execute tool calls and regenerate until the model stops calling tools.

        Each round executes every ``<tool_call>`` in the current reply, appends
        the reply once as an ``assistant`` message followed by one ``tool``
        message per result, and generates again. The loop stops when a reply
        has no (parseable) tool calls or after ``max_tool_rounds`` rounds.
        ``messages`` is extended in place.

        Returns:
            The last model reply.
        """
        for _ in range(max_tool_rounds):
            tool_calls = self._TOOL_CALL_PATTERN.findall(response)
            if not tool_calls:
                break

            results = [self._run_tool_call(call) for call in tool_calls]
            results = [result for result in results if result is not None]
            if not results:
                # Nothing executable; regenerating would repeat the same context.
                break

            messages.append({"role": "assistant", "content": response})
            for result in results:
                messages.append({"role": "tool", "content": f"<tool_result>{result}</tool_result>"})

            response = self._generate(messages, max_new_tokens, temperature)

        return response

    @staticmethod
    def _parse_toggle(user_input: str) -> Optional[tuple]:
        """Parse an exact ``tools on|off`` / ``work on|off`` command.

        Returns ``(command, enabled)`` for a toggle command, otherwise ``None``
        so the input is treated as a problem (e.g. "Work out 3 + 4").
        """
        parts = user_input.lower().split()
        if len(parts) == 2 and parts[0] in ("tools", "work") and parts[1] in ("on", "off"):
            return parts[0], parts[1] == "on"
        return None

    def interactive_math(self):
        """Interactive math problem solver reading problems from stdin.

        ``quit``/``exit``/``q``, Ctrl-C, or end of input (e.g. Ctrl-D or an
        exhausted piped stdin) ends the session; ``tools on|off`` and
        ``work on|off`` toggle tool calling and step-by-step output.
        """
        print("\n" + "="*60)
        print(" Kirim-1-Math - Interactive Mode")
        print(" First model with tool calling!")
        print("="*60)
        print("\nCommands:")
        print(" 'quit' or 'exit' - End session")
        print(" 'tools off/on' - Toggle tool calling")
        print(" 'work off/on' - Toggle showing work")
        print("\n" + "="*60 + "\n")

        use_tools = True
        show_work = True

        while True:
            try:
                user_input = input("Problem: ").strip()

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("\nGoodbye! Happy solving! 🧮\n")
                    break

                toggle = self._parse_toggle(user_input)
                if toggle is not None:
                    command, enabled = toggle
                    if command == "tools":
                        use_tools = enabled
                        print(f"✓ Tool calling: {'enabled' if use_tools else 'disabled'}\n")
                    else:
                        show_work = enabled
                        print(f"✓ Show work: {'enabled' if show_work else 'disabled'}\n")
                    continue

                if not user_input:
                    continue

                print("\n" + "-"*60)
                solution = self.solve_problem(
                    user_input,
                    show_work=show_work,
                    use_tools=use_tools
                )
                print(solution)
                print("-"*60 + "\n")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop too: input() would keep raising it.
                print("\n\nGoodbye! 🧮\n")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}\n")
|
|
|
|
|
|
|
|
def main(): |
|
|
import argparse |
|
|
|
|
|
parser = argparse.ArgumentParser(description="Kirim-1-Math Inference") |
|
|
parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math") |
|
|
parser.add_argument("--device", type=str, default="auto") |
|
|
parser.add_argument("--load_in_8bit", action="store_true") |
|
|
parser.add_argument("--load_in_4bit", action="store_true") |
|
|
parser.add_argument("--interactive", action="store_true") |
|
|
parser.add_argument("--problem", type=str, help="Single problem to solve") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
kirim_math = KirimMath( |
|
|
model_path=args.model_path, |
|
|
device=args.device, |
|
|
load_in_8bit=args.load_in_8bit, |
|
|
load_in_4bit=args.load_in_4bit |
|
|
) |
|
|
|
|
|
if args.interactive: |
|
|
kirim_math.interactive_math() |
|
|
elif args.problem: |
|
|
solution = kirim_math.solve_problem(args.problem) |
|
|
print(f"\nProblem: {args.problem}") |
|
|
print(f"\nSolution:\n{solution}\n") |
|
|
else: |
|
|
|
|
|
print("="*60) |
|
|
print(" Demo Examples") |
|
|
print("="*60 + "\n") |
|
|
|
|
|
demos = [ |
|
|
"Solve: x² - 5x + 6 = 0", |
|
|
"Calculate the derivative of x³ + 2x² - x + 1", |
|
|
"解方程: 2x + 3y = 12, 4x - y = 5", |
|
|
"Integrate: ∫(x² + 1)dx" |
|
|
] |
|
|
|
|
|
for problem in demos: |
|
|
print(f"\nProblem: {problem}") |
|
|
print("-" * 60) |
|
|
solution = kirim_math.solve_problem(problem) |
|
|
print(solution) |
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()