Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| import torch | |
| import re | |
| import os | |
# Model configuration: base checkpoint + LoRA adapter trained with SFT + GRPO.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_ID = "zhman/llama-SFT-GRPO"

# Module-level cache for the lazily loaded model and tokenizer
# (populated by load_model() on the first predict() call).
model = None
tokenizer = None
def load_model():
    """Load the tokenizer, base model, and LoRA adapter into the module globals.

    Returns:
        (model, tokenizer): the PEFT-wrapped model in eval mode and its tokenizer.
    """
    global model, tokenizer
    print("正在加载模型...")

    # Tokenizer first; Llama tokenizers ship without a pad token, so fall
    # back to EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base model, then the LoRA adapter stacked on top of it.
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
    model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
    model.eval()

    print("模型加载完成!")
    return model, tokenizer
def extract_boxed_answer(text):
    """Extract the final answer from a model response.

    Tries, in order:
      1. the LAST ``\\boxed{...}`` span, using balanced-brace matching so that
         nested braces (e.g. ``\\boxed{\\frac{1}{2}}``) are captured in full —
         the previous regex ``[^}]+`` stopped at the first inner ``}`` and
         truncated such answers;
      2. a set of "Answer:"-style fallback patterns, last match wins.

    Args:
        text: raw model output (reasoning text).

    Returns:
        The extracted answer string, or None if nothing matched.
    """
    boxed = _last_balanced_boxed(text)
    if boxed is not None:
        return boxed
    # Fallback heuristics (patterns unchanged from the original; the
    # full-width colons match Chinese punctuation).
    patterns = [
        r'答案[::]\s*([^\n]+)',
        r'Answer[::]\s*([^\n]+)',
        r'= *([^\n]+)',
        r'因此[::]\s*([^\n]+)',
        r'所以[::]\s*([^\n]+)',
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            return matches[-1].strip()
    return None


def _last_balanced_boxed(text):
    """Return the contents of the last balanced ``\\boxed{...}`` span, or None.

    Scans for every ``\\boxed{`` marker and walks forward counting brace depth,
    so nested braces are handled correctly. Empty spans are skipped (matching
    the old regex, which required at least one character).
    """
    marker = '\\boxed{'
    spans = []
    i = text.find(marker)
    while i != -1:
        j = i + len(marker)
        depth = 1
        while j < len(text) and depth:
            if text[j] == '{':
                depth += 1
            elif text[j] == '}':
                depth -= 1
            j += 1
        if depth != 0:
            # Unbalanced tail — no closing brace; nothing more to extract.
            break
        content = text[i + len(marker):j - 1].strip()
        if content:
            spans.append(content)
        i = text.find(marker, j)
    return spans[-1] if spans else None
def predict(question, max_new_tokens=1024, temperature=0.7):
    """
    Run the fine-tuned model on a math question.

    Args:
        question: the math problem, in natural language.
        max_new_tokens: generation budget.
        temperature: sampling temperature; values <= 0 select greedy decoding.

    Returns:
        (response, answer): the model's reasoning text and the extracted final
        answer (or a placeholder string when extraction fails).
    """
    global model, tokenizer
    # Lazy-load on the first call so the Space starts quickly.
    if model is None or tokenizer is None:
        load_model()

    # Build the prompt and tokenize (truncated to the model's context budget).
    prompt = f"User: {question}\nPlease reason step by step.\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # BUG FIX: the old expression `min(temperature, 0.01) if temperature
        # <= 0 else temperature` left non-positive temperatures unchanged
        # (min of a non-positive value and 0.01 is the value itself), which
        # transformers rejects. Temperature is now passed only when sampling;
        # greedy decoding ignores it.
        gen_kwargs["temperature"] = temperature

    input_ids = inputs.input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            # Passing the attention mask explicitly avoids the transformers
            # warning and is correct even if padding is ever introduced.
            attention_mask=inputs.attention_mask.to(model.device),
            **gen_kwargs,
        )

    # Decode only the newly generated tokens. Slicing by the prompt length is
    # robust, unlike the old `full_output.replace(prompt, "")`, which could
    # fail when the decoded text differs in whitespace from the prompt or
    # when the prompt text recurs inside the model's output.
    generated = outputs[0][input_ids.shape[1]:]
    response = tokenizer.decode(generated, skip_special_tokens=True).strip()

    answer = extract_boxed_answer(response)
    return response, answer if answer else "未能提取到答案"
# --- Gradio UI (also exposed as an HTTP API via api_name below) ---
with gr.Blocks(title="数学问题求解 API") as demo:
    gr.Markdown("# 🧮 数学问题求解 API 后端")
    gr.Markdown("基于 Llama-3.2-1B + SFT + GRPO 微调模型")
    with gr.Row():
        with gr.Column():
            # Input side: the question text plus generation controls.
            question_input = gr.Textbox(
                label="数学问题",
                placeholder="例如: 求解方程 x^2 + 5x + 6 = 0",
                lines=5
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=128,
                    label="最大生成长度"
                )
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            submit_btn = gr.Button("求解", variant="primary")
        with gr.Column():
            # Output side: the full reasoning trace and the extracted answer.
            reasoning_output = gr.Textbox(
                label="推理过程",
                lines=15,
                max_lines=20
            )
            answer_output = gr.Textbox(
                label="提取的答案",
                lines=2
            )
    # Clickable example questions.
    gr.Examples(
        examples=[
            ["Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer."],
            ["求解方程 3×5 等于多少?"],
        ],
        inputs=question_input
    )
    submit_btn.click(
        fn=predict,
        inputs=[question_input, max_tokens, temp],
        outputs=[reasoning_output, answer_output],
        api_name="predict"  # Important: enables programmatic access at /api/predict
    )
    gr.Markdown("---")
    gr.Markdown("### API 使用说明")
    gr.Markdown("""
**API 端点**: `/api/predict`
**POST 请求示例**:
```python
import requests
response = requests.post(
    "https://YOUR_SPACE_URL/api/predict",
    json={
        "data": [
            "你的数学问题", # question
            1024, # max_new_tokens
            0.7 # temperature
        ]
    }
)
result = response.json()
reasoning = result["data"][0] # 推理过程
answer = result["data"][1] # 提取的答案
```
""")
# Launch the app when this file is run directly (Spaces executes it as a script).
if __name__ == "__main__":
    demo.launch()