"""
Kirim-1-Math API Server

FastAPI-based REST API for mathematical reasoning
"""

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Any
import uvicorn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import logging
from datetime import datetime
import asyncio

from inference_math import KirimMath, MathToolExecutor

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Kirim-1-Math API",
    description="Advanced Mathematical Reasoning API with Tool Calling",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

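# Global handle to the loaded KirimMath instance; populated by the startup hook below.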
model_instance = None


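# Pydantic request/response models used by the endpoints below.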
class MathProblemRequest(BaseModel):
    problem: str = Field(..., description="Mathematical problem to solve")
    show_work: bool = Field(True, description="Show step-by-step solution")
    use_tools: bool = Field(True, description="Enable tool calling")
    temperature: float = Field(0.1, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: int = Field(4096, ge=1, le=8192, description="Maximum tokens to generate")
    language: Optional[str] = Field("auto", description="Response language: 'auto', 'en', 'zh'")


class ToolCallRequest(BaseModel):
    tool_name: str = Field(..., description="Name of the tool to call")
    arguments: Dict[str, Any] = Field(..., description="Tool arguments")


class BatchMathRequest(BaseModel):
    problems: List[str] = Field(..., description="List of problems to solve")
    show_work: bool = Field(True, description="Show work for all problems")
    use_tools: bool = Field(True, description="Enable tool calling")
    temperature: float = Field(0.1, ge=0.0, le=2.0)


class MathProblemResponse(BaseModel):
    problem: str
    solution: str
    tools_used: List[str] = []
    execution_time_ms: float
    tokens_generated: int
    model: str = "Kirim-1-Math"


class ToolCallResponse(BaseModel):
    tool_name: str
    result: str
    success: bool
    execution_time_ms: float


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    cuda_available: bool
    gpu_memory_used_gb: float
    gpu_memory_total_gb: float


class ModelInfoResponse(BaseModel):
    model_name: str
    parameters: str
    capabilities: List[str]
    supported_tools: List[str]
    version: str


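# The model is loaded once at application startup; endpoints that depend on it
# return 503 until loading finishes. (Newer FastAPI versions prefer a lifespan
# handler over the deprecated on_event hook, but on_event still works.)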
@app.on_event("startup")
async def load_model():
    """Load the model on startup"""
    global model_instance

    try:
        logger.info("Loading Kirim-1-Math model...")
        model_instance = KirimMath(
            model_path="Kirim-ai/Kirim-1-Math",
            device="auto",
            load_in_4bit=False
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


@app.get("/health", response_model=HealthResponse) |
|
|
async def health_check(): |
|
|
"""Check API health and model status""" |
|
|
cuda_available = torch.cuda.is_available() |
|
|
|
|
|
if cuda_available: |
|
|
gpu_memory_allocated = torch.cuda.memory_allocated() / 1e9 |
|
|
gpu_memory_total = torch.cuda.get_device_properties(0).total_memory / 1e9 |
|
|
else: |
|
|
gpu_memory_allocated = 0 |
|
|
gpu_memory_total = 0 |
|
|
|
|
|
return HealthResponse( |
|
|
status="healthy" if model_instance else "model_not_loaded", |
|
|
model_loaded=model_instance is not None, |
|
|
cuda_available=cuda_available, |
|
|
gpu_memory_used_gb=round(gpu_memory_allocated, 2), |
|
|
gpu_memory_total_gb=round(gpu_memory_total, 2) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/info", response_model=ModelInfoResponse) |
|
|
async def model_info(): |
|
|
"""Get model information""" |
|
|
return ModelInfoResponse( |
|
|
model_name="Kirim-1-Math", |
|
|
parameters="30B", |
|
|
capabilities=[ |
|
|
"mathematical_reasoning", |
|
|
"tool_calling", |
|
|
"code_execution", |
|
|
"symbolic_computation", |
|
|
"bilingual (Chinese/English)" |
|
|
], |
|
|
supported_tools=[ |
|
|
"calculator", |
|
|
"symbolic_solver", |
|
|
"derivative", |
|
|
"integrate", |
|
|
"simplify", |
|
|
"latex_formatter", |
|
|
"code_executor" |
|
|
], |
|
|
version="1.0.0" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
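# Example request (assuming the server is running locally on the default port 8000):
#   curl -X POST http://localhost:8000/solve \
#        -H "Content-Type: application/json" \
#        -d '{"problem": "Solve x^2 - 5x + 6 = 0", "show_work": true}'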
@app.post("/solve", response_model=MathProblemResponse) |
|
|
async def solve_problem(request: MathProblemRequest): |
|
|
"""Solve a mathematical problem""" |
|
|
if not model_instance: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
try: |
|
|
start_time = datetime.now() |
|
|
|
|
|
logger.info(f"Solving problem: {request.problem[:100]}...") |
|
|
|
|
|
solution = model_instance.solve_problem( |
|
|
problem=request.problem, |
|
|
show_work=request.show_work, |
|
|
use_tools=request.use_tools, |
|
|
max_new_tokens=request.max_tokens, |
|
|
temperature=request.temperature |
|
|
) |
|
|
|
|
|
end_time = datetime.now() |
|
|
execution_time = (end_time - start_time).total_seconds() * 1000 |
|
|
|
|
|
|
|
|
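        # Best-effort extraction of tool names from <tool_call> blocks in the generated text.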
        tools_used = []
        if "<tool_call>" in solution:
            import re
            tool_pattern = r'"name":\s*"([^"]+)"'
            tools_used = list(set(re.findall(tool_pattern, solution)))

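        # Rough token estimate: whitespace-delimited words scaled by ~1.3; not an exact tokenizer count.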
        tokens_generated = len(solution.split()) * 1.3

        return MathProblemResponse(
            problem=request.problem,
            solution=solution,
            tools_used=tools_used,
            execution_time_ms=round(execution_time, 2),
            tokens_generated=int(tokens_generated)
        )

    except Exception as e:
        logger.error(f"Error solving problem: {e}")
        raise HTTPException(status_code=500, detail=str(e))


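# Problems are solved sequentially in-process, so total latency grows roughly
# linearly with the batch size.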
@app.post("/solve/batch") |
|
|
async def solve_batch(request: BatchMathRequest): |
|
|
"""Solve multiple problems in batch""" |
|
|
if not model_instance: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
results = [] |
|
|
|
|
|
for problem in request.problems: |
|
|
try: |
|
|
solution = model_instance.solve_problem( |
|
|
problem=problem, |
|
|
show_work=request.show_work, |
|
|
use_tools=request.use_tools, |
|
|
temperature=request.temperature |
|
|
) |
|
|
|
|
|
results.append({ |
|
|
"problem": problem, |
|
|
"solution": solution, |
|
|
"success": True |
|
|
}) |
|
|
except Exception as e: |
|
|
results.append({ |
|
|
"problem": problem, |
|
|
"solution": None, |
|
|
"success": False, |
|
|
"error": str(e) |
|
|
}) |
|
|
|
|
|
return {"results": results, "total": len(request.problems)} |
|
|
|
|
|
|
|
|
|
|
|
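# Example direct tool invocation (assuming the default host/port):
#   curl -X POST http://localhost:8000/tools/call \
#        -H "Content-Type: application/json" \
#        -d '{"tool_name": "calculator", "arguments": {"expression": "2**10 + 5"}}'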
@app.post("/tools/call", response_model=ToolCallResponse) |
|
|
async def call_tool(request: ToolCallRequest): |
|
|
"""Directly call a mathematical tool""" |
|
|
try: |
|
|
start_time = datetime.now() |
|
|
|
|
|
tool_executor = MathToolExecutor() |
|
|
result = tool_executor.execute_tool(request.tool_name, request.arguments) |
|
|
|
|
|
end_time = datetime.now() |
|
|
execution_time = (end_time - start_time).total_seconds() * 1000 |
|
|
|
|
|
return ToolCallResponse( |
|
|
tool_name=request.tool_name, |
|
|
result=result, |
|
|
success="error" not in result.lower(), |
|
|
execution_time_ms=round(execution_time, 2) |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
return ToolCallResponse( |
|
|
tool_name=request.tool_name, |
|
|
result=str(e), |
|
|
success=False, |
|
|
execution_time_ms=0 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/tools/list") |
|
|
async def list_tools(): |
|
|
"""List all available mathematical tools""" |
|
|
tools = [ |
|
|
{ |
|
|
"name": "calculator", |
|
|
"description": "Perform precise arithmetic calculations", |
|
|
"parameters": ["expression", "precision"] |
|
|
}, |
|
|
{ |
|
|
"name": "symbolic_solver", |
|
|
"description": "Solve algebraic equations symbolically", |
|
|
"parameters": ["equation", "variable", "domain"] |
|
|
}, |
|
|
{ |
|
|
"name": "derivative", |
|
|
"description": "Calculate symbolic derivatives", |
|
|
"parameters": ["function", "variable", "order"] |
|
|
}, |
|
|
{ |
|
|
"name": "integrate", |
|
|
"description": "Calculate integrals", |
|
|
"parameters": ["function", "variable", "lower_bound", "upper_bound"] |
|
|
}, |
|
|
{ |
|
|
"name": "simplify", |
|
|
"description": "Simplify mathematical expressions", |
|
|
"parameters": ["expression", "method"] |
|
|
}, |
|
|
{ |
|
|
"name": "latex_formatter", |
|
|
"description": "Format expressions in LaTeX", |
|
|
"parameters": ["expression", "inline"] |
|
|
} |
|
|
] |
|
|
|
|
|
return {"tools": tools, "total": len(tools)} |
|
|
|
|
|
|
|
|
|
|
|
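# The statistics below are placeholders; request counting and timing are not implemented yet.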
@app.get("/stats") |
|
|
async def get_stats(): |
|
|
"""Get API usage statistics""" |
|
|
|
|
|
return { |
|
|
"requests_processed": "N/A", |
|
|
"average_response_time_ms": "N/A", |
|
|
"model_status": "active" if model_instance else "inactive" |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def main():
    import argparse

    parser = argparse.ArgumentParser(description="Kirim-1-Math API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=8000, help="Port number")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")

    args = parser.parse_args()

    logger.info(f"Starting Kirim-1-Math API server on {args.host}:{args.port}")

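    # Each worker process loads its own copy of the model at startup; uvicorn
    # ignores the workers setting when reload is enabled.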
    uvicorn.run(
        "api_server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()