# math-solver-api / app.py
# Hugging Face Space (owner: zhman, commit 7a5eef0, "Upload 3 files")
# — page header kept as comments so the file remains valid Python.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re
import os
# Model configuration: base checkpoint plus the SFT+GRPO LoRA adapter.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_ID = "zhman/llama-SFT-GRPO"
# Module-level cache for the lazily loaded model and tokenizer
# (populated by load_model() on the first predict() call).
model = None
tokenizer = None
def load_model():
    """Load the tokenizer and the LoRA-adapted model into the module cache.

    Populates the module-level ``model`` and ``tokenizer`` globals and
    returns them as a ``(model, tokenizer)`` tuple. Intended to be called
    once, lazily, from predict().
    """
    global model, tokenizer
    print("正在加载模型...")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the LoRA adapter and switch to inference mode
    # (.eval() returns the model, so the chain is equivalent to a
    # separate model.eval() call).
    model = PeftModel.from_pretrained(base, ADAPTER_ID).eval()

    print("模型加载完成!")
    return model, tokenizer
def extract_boxed_answer(text):
    """Extract the final answer from a model response.

    Patterns are tried in priority order: the LaTeX ``\\boxed{...}`` form
    first, then a set of "Answer:"-style fallbacks (Chinese and English).
    The stripped capture of the *last* match for the first pattern that
    hits is returned; ``None`` when nothing matches at all.
    """
    candidate_patterns = (
        r'\\boxed\{([^}]+)\}',      # preferred: LaTeX boxed answer
        r'答案[::]\s*([^\n]+)',
        r'Answer[::]\s*([^\n]+)',
        r'= *([^\n]+)',
        r'因此[::]\s*([^\n]+)',
        r'所以[::]\s*([^\n]+)',
    )
    for pattern in candidate_patterns:
        hits = re.findall(pattern, text)
        if hits:
            return hits[-1].strip()
    return None
def predict(question, max_new_tokens=1024, temperature=0.7):
    """Run the fine-tuned model on a math question.

    Args:
        question: The math problem, as free text.
        max_new_tokens: Generation budget for the response.
        temperature: Sampling temperature; non-positive values fall back
            to greedy decoding.

    Returns:
        (reasoning, answer): the full decoded response, and the extracted
        final answer (or the placeholder string when extraction fails).
    """
    global model, tokenizer
    # Lazy-load on the first request so the Space starts quickly.
    if model is None or tokenizer is None:
        load_model()

    prompt = f"User: {question}\nPlease reason step by step.\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,  # avoids the pad-token ambiguity warning
            max_new_tokens=max_new_tokens,
            # BUGFIX: the original used min(temperature, 0.01) for
            # temperature <= 0, which left the non-positive value in place;
            # max() enforces the intended 0.01 floor.
            temperature=max(temperature, 0.01),
            do_sample=temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. Slicing by input length is
    # robust, unlike str.replace(prompt, ""), which can miss (the decoded
    # prompt may not round-trip byte-for-byte) or strip repeated text.
    generated = outputs[0][input_ids.shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip()

    answer = extract_boxed_answer(response)
    return response, answer if answer else "未能提取到答案"
# Build the Gradio UI (also exposes the named /api/predict endpoint).
with gr.Blocks(title="数学问题求解 API") as demo:
    gr.Markdown("# 🧮 数学问题求解 API 后端")
    gr.Markdown("基于 Llama-3.2-1B + SFT + GRPO 微调模型")
    with gr.Row():
        with gr.Column():
            # Input column: the question text plus the two generation knobs.
            question_input = gr.Textbox(
                label="数学问题",
                placeholder="例如: 求解方程 x^2 + 5x + 6 = 0",
                lines=5
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=128,
                    label="最大生成长度"
                )
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            submit_btn = gr.Button("求解", variant="primary")
        with gr.Column():
            # Output column: full reasoning trace and the extracted answer.
            reasoning_output = gr.Textbox(
                label="推理过程",
                lines=15,
                max_lines=20
            )
            answer_output = gr.Textbox(
                label="提取的答案",
                lines=2
            )
    # Clickable example questions.
    gr.Examples(
        examples=[
            ["Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer."],
            ["求解方程 3×5 等于多少?"],
        ],
        inputs=question_input
    )
    submit_btn.click(
        fn=predict,
        inputs=[question_input, max_tokens, temp],
        outputs=[reasoning_output, answer_output],
        api_name="predict"  # important: enables programmatic API access
    )
    gr.Markdown("---")
    gr.Markdown("### API 使用说明")
    gr.Markdown("""
**API 端点**: `/api/predict`
**POST 请求示例**:
```python
import requests
response = requests.post(
"https://YOUR_SPACE_URL/api/predict",
json={
"data": [
"你的数学问题", # question
1024, # max_new_tokens
0.7 # temperature
]
}
)
result = response.json()
reasoning = result["data"][0] # 推理过程
answer = result["data"][1] # 提取的答案
```
""")
# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()