File size: 5,583 Bytes
7a5eef0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re
import os

# Model configuration: base checkpoint on the Hugging Face Hub plus the
# LoRA adapter produced by SFT + GRPO fine-tuning.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_ID = "zhman/llama-SFT-GRPO"

# Module-level cache for the model and tokenizer; populated lazily by
# load_model() on the first call to predict().
model = None
tokenizer = None

def load_model():
    """Initialize the global model and tokenizer.

    Loads the base Llama checkpoint, stacks the LoRA adapter on top,
    switches to eval mode, and caches both objects in module globals.

    Returns:
        The (model, tokenizer) pair.
    """
    global model, tokenizer

    print("正在加载模型...")

    # Tokenizer first; fall back to EOS as the padding token when the
    # checkpoint does not define one (Llama checkpoints typically don't).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base weights in bfloat16, placed automatically across available devices.
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Attach the fine-tuned LoRA adapter to the frozen base model.
    model = PeftModel.from_pretrained(base, ADAPTER_ID)
    model.eval()

    print("模型加载完成!")
    return model, tokenizer

def extract_boxed_answer(text):
    """Extract the final answer from model output.

    Looks for the last ``\\boxed{...}`` expression first, balancing nested
    braces (the previous regex ``[^}]+`` truncated answers such as
    ``\\boxed{\\frac{1}{2}}`` at the first closing brace).  Falls back to a
    set of "Answer:"-style prose patterns when no boxed expression exists.

    Args:
        text: Raw model output.

    Returns:
        The extracted answer string, or None when nothing matched.
    """
    boxed = _last_boxed_content(text)
    if boxed is not None and boxed.strip():
        return boxed.strip()

    # Fallback patterns for answers stated in prose (Chinese and English).
    patterns = [
        r'答案[::]\s*([^\n]+)',
        r'Answer[::]\s*([^\n]+)',
        r'= *([^\n]+)',
        r'因此[::]\s*([^\n]+)',
        r'所以[::]\s*([^\n]+)',
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            return matches[-1].strip()

    return None


def _last_boxed_content(text):
    """Return the content of the last ``\\boxed{...}`` in *text*.

    Scans forward from the last occurrence of ``\\boxed{`` counting brace
    depth, so nested LaTeX braces are kept intact.  Returns None when no
    boxed expression is present or the braces are unbalanced.
    """
    marker = '\\boxed{'
    start = text.rfind(marker)
    if start == -1:
        return None

    begin = start + len(marker)
    depth = 1
    i = begin
    while i < len(text) and depth:
        ch = text[i]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
        i += 1

    if depth:  # unbalanced braces — malformed output
        return None
    return text[begin:i - 1]

def predict(question, max_new_tokens=1024, temperature=0.7):
    """Run inference on a math question.

    Args:
        question: The math problem, in natural language.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.

    Returns:
        Tuple of (full reasoning text, extracted answer or a fallback message).
    """
    global model, tokenizer

    # Load the model lazily on the first request.
    if model is None or tokenizer is None:
        load_model()

    # Simple chat-style prompt matching the fine-tuning format.
    prompt = f"User: {question}\nPlease reason step by step.\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = inputs.input_ids.to(model.device)
    # Pass the attention mask explicitly to avoid pad/eos ambiguity warnings.
    attention_mask = inputs.attention_mask.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            # BUG FIX: the original used min(), which left non-positive
            # temperatures unclamped; clamp upward to a small positive value.
            temperature=max(temperature, 0.01) if temperature <= 0 else temperature,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens.  The previous string-level
    # .replace(prompt, "") was fragile: it removed every occurrence of the
    # prompt text and silently failed when decoding did not round-trip the
    # prompt byte-for-byte.
    generated = outputs[0][input_ids.shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip()

    # Pull out the final answer from the reasoning text.
    answer = extract_boxed_answer(response)

    return response, answer if answer else "未能提取到答案"

# Build the Gradio interface; `demo` is both the web UI and the API surface.
with gr.Blocks(title="数学问题求解 API") as demo:
    gr.Markdown("# 🧮 数学问题求解 API 后端")
    gr.Markdown("基于 Llama-3.2-1B + SFT + GRPO 微调模型")

    with gr.Row():
        with gr.Column():
            # Input column: the question plus generation hyperparameters.
            question_input = gr.Textbox(
                label="数学问题",
                placeholder="例如: 求解方程 x^2 + 5x + 6 = 0",
                lines=5
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=128,
                    label="最大生成长度"
                )
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            submit_btn = gr.Button("求解", variant="primary")

        with gr.Column():
            # Output column: the full reasoning trace and the extracted answer.
            reasoning_output = gr.Textbox(
                label="推理过程",
                lines=15,
                max_lines=20
            )
            answer_output = gr.Textbox(
                label="提取的答案",
                lines=2
            )

    # Clickable example questions.
    gr.Examples(
        examples=[
            ["Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer."],
            ["求解方程 3×5 等于多少?"],
        ],
        inputs=question_input
    )

    submit_btn.click(
        fn=predict,
        inputs=[question_input, max_tokens, temp],
        outputs=[reasoning_output, answer_output],
        api_name="predict"  # Important: exposes this handler as an API endpoint
    )

    gr.Markdown("---")
    gr.Markdown("### API 使用说明")
    gr.Markdown("""

    **API 端点**: `/api/predict`

    

    **POST 请求示例**:

    ```python

    import requests

    

    response = requests.post(

        "https://YOUR_SPACE_URL/api/predict",

        json={

            "data": [

                "你的数学问题",  # question

                1024,           # max_new_tokens

                0.7             # temperature

            ]

        }

    )

    result = response.json()

    reasoning = result["data"][0]  # 推理过程

    answer = result["data"][1]     # 提取的答案

    ```

    """)

# Entry point: launch the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()