Spaces:
Sleeping
Sleeping
File size: 5,583 Bytes
7a5eef0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re
import os
# Model configuration: base checkpoint on the Hub plus the fine-tuned LoRA adapter.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_ID = "zhman/llama-SFT-GRPO"
# Module-level cache for the lazily loaded model and tokenizer (filled by load_model()).
model = None
tokenizer = None
def load_model():
    """Load tokenizer, base model, and LoRA adapter into the module globals.

    Returns:
        The (model, tokenizer) pair, also stored in the module-level
        ``model`` / ``tokenizer`` globals for reuse across calls.
    """
    global model, tokenizer
    print("正在加载模型...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    # Base model first, then the fine-tuned LoRA weights on top of it.
    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(backbone, ADAPTER_ID).eval()
    print("模型加载完成!")
    return model, tokenizer
def extract_boxed_answer(text):
    """Extract the final answer from generated model text.

    Prefers the last LaTeX ``\\boxed{...}`` expression; if none is present,
    falls back to common "Answer:"-style patterns (Chinese and English).

    Args:
        text: the model's reasoning text.

    Returns:
        The extracted answer string, stripped of surrounding whitespace,
        or ``None`` when no pattern matches.
    """
    # Match \boxed{...} allowing one level of nested braces, so answers like
    # \boxed{\frac{1}{2}} are captured whole. (The previous pattern [^}]+
    # stopped at the first '}' and truncated such answers to "\frac{1".)
    boxed_pattern = r'\\boxed\{((?:[^{}]|\{[^{}]*\})+)\}'
    matches = re.findall(boxed_pattern, text)
    if matches:
        return matches[-1].strip()
    # Fallback heuristics for answers not wrapped in \boxed{}.
    patterns = [
        r'答案[::]\s*([^\n]+)',
        r'Answer[::]\s*([^\n]+)',
        r'= *([^\n]+)',
        r'因此[::]\s*([^\n]+)',
        r'所以[::]\s*([^\n]+)',
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            # Use the last occurrence: final conclusions come last.
            return matches[-1].strip()
    return None
def predict(question, max_new_tokens=1024, temperature=0.7):
    """
    Run inference on a math question.

    Args:
        question: the math problem text.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature; values <= 0 trigger greedy decoding.

    Returns:
        (full reasoning text, extracted answer or a fallback message)
    """
    global model, tokenizer
    # Lazy-load the model on first call.
    if model is None or tokenizer is None:
        load_model()
    # Build the prompt.
    prompt = f"User: {question}\nPlease reason step by step.\nAssistant:"
    # Tokenize (truncate overly long questions).
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # BUG FIX: the original used `min(temperature, 0.01) if temperature <= 0`,
    # which still passed a non-positive temperature to generate(). With
    # do_sample=False temperature is meaningless, so only pass it when sampling.
    if do_sample:
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        # Pass attention_mask explicitly to avoid pad/eos ambiguity warnings
        # and incorrect attention when pad_token == eos_token.
        outputs = model.generate(input_ids, attention_mask=attention_mask, **gen_kwargs)
    # Decode only the newly generated tokens. (The original decoded everything
    # and did full_output.replace(prompt, ""), which breaks if re-decoding
    # alters whitespace or the prompt text recurs in the output.)
    generated = outputs[0][input_ids.shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    # Extract the final answer from the reasoning.
    answer = extract_boxed_answer(response)
    return response, answer if answer else "未能提取到答案"
# 创建 Gradio 界面
# Build the Gradio interface; it doubles as the HTTP API backend for the Space.
with gr.Blocks(title="数学问题求解 API") as demo:
    gr.Markdown("# 🧮 数学问题求解 API 后端")
    gr.Markdown("基于 Llama-3.2-1B + SFT + GRPO 微调模型")
    with gr.Row():
        with gr.Column():
            # Left column: question input plus generation controls.
            question_input = gr.Textbox(
                label="数学问题",
                placeholder="例如: 求解方程 x^2 + 5x + 6 = 0",
                lines=5
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=128,
                    label="最大生成长度"
                )
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            submit_btn = gr.Button("求解", variant="primary")
        with gr.Column():
            # Right column: full reasoning trace and the extracted final answer.
            reasoning_output = gr.Textbox(
                label="推理过程",
                lines=15,
                max_lines=20
            )
            answer_output = gr.Textbox(
                label="提取的答案",
                lines=2
            )
    # Clickable example questions.
    gr.Examples(
        examples=[
            ["Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer."],
            ["求解方程 3×5 等于多少?"],
        ],
        inputs=question_input
    )
    # Wire the button to predict(); api_name exposes this as /api/predict.
    submit_btn.click(
        fn=predict,
        inputs=[question_input, max_tokens, temp],
        outputs=[reasoning_output, answer_output],
        api_name="predict"  # Important: enables API access
    )
    gr.Markdown("---")
    gr.Markdown("### API 使用说明")
    gr.Markdown("""
**API 端点**: `/api/predict`
**POST 请求示例**:
```python
import requests
response = requests.post(
"https://YOUR_SPACE_URL/api/predict",
json={
"data": [
"你的数学问题", # question
1024, # max_new_tokens
0.7 # temperature
]
}
)
result = response.json()
reasoning = result["data"][0] # 推理过程
answer = result["data"][1] # 提取的答案
```
""")
# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()
|