import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re

# Model configuration
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_ID = "zhman/llama-SFT-GRPO"

# Global variables holding the model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the base model, the LoRA adapter, and the tokenizer."""
    global model, tokenizer

    print("Loading model...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    # Load LoRA adapter
    model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
    model.eval()

    print("Model loaded!")
    return model, tokenizer


def extract_boxed_answer(text):
    """Extract an answer in \\boxed{} format, with fallback patterns."""
    # Look for \boxed{} first
    boxed_pattern = r'\\boxed\{([^}]+)\}'
    matches = re.findall(boxed_pattern, text)
    if matches:
        return matches[-1].strip()

    # Fall back to other common answer markers (English and Chinese)
    patterns = [
        r'答案[::]\s*([^\n]+)',   # "Answer:" in Chinese
        r'Answer[::]\s*([^\n]+)',
        r'= *([^\n]+)',
        r'因此[::]\s*([^\n]+)',   # "Therefore:" in Chinese
        r'所以[::]\s*([^\n]+)',   # "So:" in Chinese
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            return matches[-1].strip()

    return None


def predict(question, max_new_tokens=1024, temperature=0.7):
    """
    Run model inference.

    Args:
        question: the math problem
        max_new_tokens: maximum number of tokens to generate
        temperature: sampling temperature

    Returns:
        (full model output, extracted answer)
    """
    global model, tokenizer

    # Load the model on first call
    if model is None or tokenizer is None:
        load_model()

    # Build the prompt
    prompt = f"User: {question}\nPlease reason step by step.\nAssistant:"

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

    # Generate (greedy decoding when temperature <= 0, sampling otherwise)
    do_sample = temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids.to(model.device),
            attention_mask=inputs.attention_mask.to(model.device),
            max_new_tokens=max_new_tokens,
            temperature=temperature if do_sample else 1.0,  # ignored when do_sample=False
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = full_output.replace(prompt, "").strip()

    # Extract the answer
    answer = extract_boxed_answer(response)

    return response, answer if answer else "No answer could be extracted"


# Build the Gradio interface
with gr.Blocks(title="Math Problem Solver API") as demo:
    gr.Markdown("# 🧮 Math Problem Solver API Backend")
    gr.Markdown("Powered by Llama-3.2-1B fine-tuned with SFT + GRPO")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Math problem",
                placeholder="e.g. Solve the equation x^2 + 5x + 6 = 0",
                lines=5
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    value=1024,
                    step=128,
                    label="Max new tokens"
                )
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            submit_btn = gr.Button("Solve", variant="primary")

        with gr.Column():
            reasoning_output = gr.Textbox(
                label="Reasoning",
                lines=15,
                max_lines=20
            )
            answer_output = gr.Textbox(
                label="Extracted answer",
                lines=2
            )

    # Examples
    gr.Examples(
        examples=[
            ["Find the positive integer n such that 10^n cubic centimeters is the same as 1 cubic kilometer."],
            ["What is 3 × 5?"],
        ],
        inputs=question_input
    )

    submit_btn.click(
        fn=predict,
        inputs=[question_input, max_tokens, temp],
        outputs=[reasoning_output, answer_output],
        api_name="predict"  # Important: exposes this event as a named API endpoint
    )

    gr.Markdown("---")
    gr.Markdown("### API Usage")
    gr.Markdown("""
**API endpoint**: `/api/predict`

**Example POST request**:
```python
import requests

response = requests.post(
    "https://YOUR_SPACE_URL/api/predict",
    json={
        "data": [
            "Your math problem",  # question
            1024,                 # max_new_tokens
            0.7                   # temperature
        ]
    }
)
result = response.json()
reasoning = result["data"][0]  # reasoning
answer = result["data"][1]     # extracted answer
```
""")


# Launch the app
if __name__ == "__main__":
    demo.launch()
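
# ---------------------------------------------------------------------------
# Client-side usage sketch (not executed by this app). In addition to the raw
# HTTP example shown in the UI above, a Space whose event is registered with
# api_name="predict" can typically be called through the official
# gradio_client package. The Space ID below is a hypothetical placeholder;
# substitute the actual "username/space-name" of the deployed Space.
#
#   from gradio_client import Client
#
#   client = Client("YOUR_USERNAME/YOUR_SPACE_NAME")  # placeholder Space ID
#   reasoning, answer = client.predict(
#       "Find the positive integer n such that 10^n cubic centimeters "
#       "is the same as 1 cubic kilometer.",  # question
#       1024,                                 # max_new_tokens
#       0.7,                                  # temperature
#       api_name="/predict",
#   )
#   print(reasoning)
#   print(answer)
# ---------------------------------------------------------------------------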