zhman committed on
Commit
7a5eef0
·
verified ·
1 Parent(s): cf5cb11

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +82 -13
  2. app.py +194 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,13 +1,82 @@
1
- ---
2
- title: Math Solver Api
3
- emoji: 💻
4
- colorFrom: gray
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.2.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Math Solver API
3
+ emoji: 🧮
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ # 数学问题求解 API 后端
14
+
15
+ 这是一个基于 Gradio 的 API 后端服务,用于为数学问题求解提供模型推理能力。
16
+
17
+ ## 模型信息
18
+
19
+ - **基础模型**: meta-llama/Llama-3.2-1B-Instruct
20
+ - **微调适配器**: zhman/llama-SFT-GRPO
21
+ - **训练方法**: SFT + GRPO
22
+ - **准确率**: 97%
23
+
24
+ ## API 使用
25
+
26
+ ### 端点
27
+ `POST /api/predict`
28
+
29
+ ### 请求格式
30
+ ```json
31
+ {
32
+ "data": [
33
+ "你的数学问题",
34
+ 1024,
35
+ 0.7
36
+ ]
37
+ }
38
+ ```
39
+
40
+ ### 响应格式
41
+ ```json
42
+ {
43
+ "data": [
44
+ "推理过程...",
45
+ "提取的答案"
46
+ ]
47
+ }
48
+ ```
49
+
50
+ ### JavaScript 示例
51
+ ```javascript
52
+ const response = await fetch('https://YOUR_SPACE_URL/api/predict', {
53
+ method: 'POST',
54
+ headers: {
55
+ 'Content-Type': 'application/json'
56
+ },
57
+ body: JSON.stringify({
58
+ data: [
59
+ "Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer.",
60
+ 1024,
61
+ 0.7
62
+ ]
63
+ })
64
+ });
65
+
66
+ const result = await response.json();
67
+ console.log('推理过程:', result.data[0]);
68
+ console.log('答案:', result.data[1]);
69
+ ```
70
+
71
+ ## 部署说明
72
+
73
+ 1. 创建新的 Gradio Space
74
+ 2. 上传 `app.py` 和 `requirements.txt`
75
+ 3. 等待模型加载(首次约1-2分钟)
76
+ 4. Space URL 即为 API 基础地址
77
+
78
+ ## 注意事项
79
+
80
+ - Space 长时间不使用会休眠
81
+ - 休眠后首次调用会唤醒(约10-20秒)
82
+ - 推荐使用 GPU 硬件加速
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from peft import PeftModel
4
+ import torch
5
+ import re
6
+ import os
7
+
8
+ # 模型配置
9
+ MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
10
+ ADAPTER_ID = "zhman/llama-SFT-GRPO"
11
+
12
+ # 全局变量存储模型和tokenizer
13
+ model = None
14
+ tokenizer = None
15
+
16
def load_model():
    """Load the tokenizer and the LoRA-adapted model into the module globals.

    The tokenizer and base model are fetched for ``MODEL_ID``; the adapter
    identified by ``ADAPTER_ID`` is then attached on top of the base model and
    the result is switched to evaluation mode.

    Returns:
        The ``(model, tokenizer)`` pair that was stored in the globals.
    """
    global model, tokenizer

    print("正在加载模型...")

    # Tokenizer first; fall back to the EOS token when no pad token is defined.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base model in bfloat16, with device placement delegated to device_map.
    base_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    backbone = AutoModelForCausalLM.from_pretrained(MODEL_ID, **base_kwargs)

    # Attach the LoRA adapter and disable training-only behaviour.
    model = PeftModel.from_pretrained(backbone, ADAPTER_ID)
    model.eval()

    print("模型加载完成!")
    return model, tokenizer
41
+
42
# Fallback patterns tried, in order, when no \boxed{...} answer is present.
# The colon classes accept both the ASCII and the full-width (U+FF1A) colon.
_FALLBACK_ANSWER_PATTERNS = (
    r'答案[::]\s*([^\n]+)',
    r'Answer[::]\s*([^\n]+)',
    r'= *([^\n]+)',
    r'因此[::]\s*([^\n]+)',
    r'所以[::]\s*([^\n]+)',
)


def _find_boxed_contents(text):
    """Return the contents of every brace-balanced ``\\boxed{...}`` in *text*."""
    marker = "\\boxed{"
    contents = []
    pos = 0
    while True:
        start = text.find(marker, pos)
        if start == -1:
            return contents
        body_start = start + len(marker)
        depth = 1
        i = body_start
        # Walk forward until the brace opened by the marker is closed.
        while i < len(text) and depth:
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                depth -= 1
            i += 1
        if depth:
            # Never closed (e.g. truncated output): skip this marker but keep
            # scanning after it for a later, well-formed \boxed{...}.
            pos = body_start
            continue
        contents.append(text[body_start:i - 1])
        pos = i


def extract_boxed_answer(text):
    r"""Extract the final answer from a model response.

    The last non-empty ``\boxed{...}`` expression wins. Braces are matched by
    depth, so nested LaTeX such as ``\boxed{\frac{1}{2}}`` is returned whole
    instead of being cut at the first ``}``. When no boxed answer exists, a
    list of textual fallbacks ("答案:", "Answer:", "=", "因此:", "所以:") is
    tried in order and the last match of the first matching pattern is used.

    Args:
        text: The model response. ``None`` or an empty string is allowed.

    Returns:
        The stripped answer string, or ``None`` when nothing could be found.
    """
    if not text:
        return None

    boxed = [content for content in _find_boxed_contents(text) if content]
    if boxed:
        return boxed[-1].strip()

    for pattern in _FALLBACK_ANSWER_PATTERNS:
        matches = re.findall(pattern, text)
        if matches:
            return matches[-1].strip()

    return None
66
+
67
def predict(question, max_new_tokens=1024, temperature=0.7):
    """Generate a step-by-step solution for a math question.

    Args:
        question: The math problem as text. A missing or blank question
            returns immediately without touching the model.
        max_new_tokens: Upper bound on the number of generated tokens. It is
            coerced to ``int`` because UI/JSON clients may send a float.
        temperature: Sampling temperature. Values greater than 0 enable
            sampling; ``None`` or values <= 0 select greedy decoding.

    Returns:
        A ``(response, answer)`` tuple: the generated reasoning text and the
        answer extracted from it, or a fixed "not found" message when no
        answer could be extracted.
    """
    global model, tokenizer

    no_answer = "未能提取到答案"

    # Nothing to solve: avoid loading the model or generating from an empty prompt.
    if question is None or not str(question).strip():
        return "", no_answer

    # Lazy-load on first use.
    if model is None or tokenizer is None:
        load_model()

    prompt = f"User: {question}\nPlease reason step by step.\nAssistant:"

    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max(1, int(max_new_tokens)),
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
    else:
        # Greedy decoding; a non-positive temperature is not forwarded.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        # Keyword arguments only: the attention mask is passed explicitly
        # because the pad token may be the same as the EOS token.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **generation_kwargs,
        )

    # Decode only the newly generated tokens. Removing the prompt with
    # str.replace on the decoded text breaks whenever decoding does not
    # reproduce the prompt exactly, leaking the question into the response.
    generated_ids = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    answer = extract_boxed_answer(response)

    return response, answer if answer else no_answer
110
+
111
# Build the Gradio interface.
# NOTE(review): the nesting below and the indentation inside the final
# Markdown string are reconstructed from a view that lost leading
# whitespace — confirm against the original app.py.
with gr.Blocks(title="数学问题求解 API") as demo:
    gr.Markdown("# 🧮 数学问题求解 API 后端")
    gr.Markdown("基于 Llama-3.2-1B + SFT + GRPO 微调模型")

    with gr.Row():
        # Left column: the question and the generation controls.
        with gr.Column():
            question_input = gr.Textbox(
                label="数学问题",
                placeholder="例如: 求解方程 x^2 + 5x + 6 = 0",
                lines=5
            )
            with gr.Row():
                # Passed to predict() as max_new_tokens.
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=128,
                    label="最大生成长度"
                )
                # Passed to predict() as temperature; the UI minimum is 0.1.
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            submit_btn = gr.Button("求解", variant="primary")

        # Right column: the two values returned by predict().
        with gr.Column():
            reasoning_output = gr.Textbox(
                label="推理过程",
                lines=15,
                max_lines=20
            )
            answer_output = gr.Textbox(
                label="提取的答案",
                lines=2
            )

    # Example questions that fill the question textbox.
    gr.Examples(
        examples=[
            ["Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer."],
            ["求解方程 3×5 等于多少?"],
        ],
        inputs=question_input
    )

    # Wire the button to predict(); the input order matches its parameters.
    submit_btn.click(
        fn=predict,
        inputs=[question_input, max_tokens, temp],
        outputs=[reasoning_output, answer_output],
        api_name="predict"  # Important: enables API access
    )

    # Usage notes rendered at the bottom of the page.
    gr.Markdown("---")
    gr.Markdown("### API 使用说明")
    gr.Markdown("""
    **API 端点**: `/api/predict`

    **POST 请求示例**:
    ```python
    import requests

    response = requests.post(
        "https://YOUR_SPACE_URL/api/predict",
        json={
            "data": [
                "你的数学问题", # question
                1024, # max_new_tokens
                0.7 # temperature
            ]
        }
    )
    result = response.json()
    reasoning = result["data"][0] # 推理过程
    answer = result["data"][1] # 提取的答案
    ```
    """)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ transformers>=4.36.0
3
+ torch>=2.0.0
4
+ peft>=0.7.0
5
+ accelerate>=0.25.0
6
+ bitsandbytes>=0.41.0