jinv2's picture
Update app.py
bb27837 verified
# 这是一个测试注释,用来触发更改检测
# -*- coding: utf-8 -*-
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
import os
from peft import PeftModel # 引入 PeftModel 用于加载适配器
# 引入 accelerate 是个好习惯,即使 device_map='auto' 隐式使用它
import accelerate
print("--- 程序开始 ---")
# --- 配置区域 ---
# 基础模型 ID (Gemma 2B)
BASE_MODEL_ID = "google/gemma-2b"
# PEFT 适配器模型 ID (你的微调模型)
ADAPTER_MODEL_ID = "jinv2/safesky-ai-gemma-2b-sft"
print(f"基础模型 ID: {BASE_MODEL_ID}")
print(f"适配器 ID: {ADAPTER_MODEL_ID}")
# --- 设备和数据类型检测 ---
print("--- 开始检测设备和数据类型 ---")
try:
if torch.cuda.is_available():
# 使用 device_map='auto' 让 accelerate 自动处理多GPU或显存分配
device_map = "auto"
# 记录设备类型用于界面显示
device_type = "GPU"
print("检测到可用的 CUDA (GPU)。将使用 GPU 并自动分配设备 (device_map='auto')。")
# 检查 GPU 是否支持 bfloat16 (通常需要 Ampere 架构或更新)
# Gemma 模型使用 bfloat16 训练,优先使用
if torch.cuda.get_device_capability(0)[0] >= 8: # 检查第一个GPU的能力
torch_dtype = torch.bfloat16
print("GPU 支持 bfloat16,将使用 bfloat16 数据类型。")
else:
torch_dtype = torch.float16 # 不支持 bfloat16 时回退到 float16
print("GPU 不支持 bfloat16,将使用 float16 数据类型 (精度可能低于预期)。")
else:
device_type = "CPU"
device_map = None # CPU 不需要 device_map
torch_dtype = torch.float32 # CPU 通常使用 float32
print("未检测到 CUDA (GPU)。将使用 CPU。")
# 注意这里的 except 块与上面的 try 配对
except Exception as e:
print(f"检测设备或数据类型时出错: {e}。将默认使用 CPU 和 float32。")
device_type = "CPU"
device_map = None
torch_dtype = torch.float32
print(f"最终选择的设备类型: {device_type}, 数据类型: {torch_dtype}, device_map: {device_map}")
print("--- 设备检测完成 ---")
# --- 定义模型加载和 Pipeline 创建的变量 ---
# 将 pipe 初始化为 None,以便后续检查是否加载成功
pipe = None
model = None
tokenizer = None
# --- 加载基础模型、适配器、Tokenizer 并创建 Pipeline ---
# 将整个加载过程包裹在 try...except 中
print(f"--- 开始加载模型与 Pipeline ---")
try: # <--- 第一个 try 块开始
print(f"开始加载基础模型: {BASE_MODEL_ID}...")
# 1. 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code=True
)
print("基础模型加载成功。")
# 2. 加载 Tokenizer
print(f"开始加载 Tokenizer: {ADAPTER_MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
print("Tokenizer 加载成功。")
# 3. 加载 PEFT 适配器并应用到基础模型上
print(f"开始加载 PEFT 适配器: {ADAPTER_MODEL_ID}...")
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
print("PEFT 适配器加载并应用成功。")
# 4. 创建文本生成 Pipeline
print("正在创建文本生成 Pipeline...")
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=torch_dtype,
device_map=device_map # 传递 device_map
)
print("模型、适配器和 Pipeline 全部加载成功。")
# 注意这里的 except 块是和上面的 try 配对的
except Exception as e: # <--- 第一个 try 块的 except 部分
print(f"!!! 加载模型、适配器或 Pipeline 时发生严重错误: {e}")
# 准备更详细的错误信息
error_message_detail = (
f"加载模型失败!\n\n"
f"基础模型: '{BASE_MODEL_ID}'\n"
f"适配器: '{ADAPTER_MODEL_ID}'\n\n"
f"错误类型: {type(e).__name__}\n"
f"错误详情: {e}\n\n"
f"请检查 Space 的日志获取更详细的回溯信息,并确认硬件配置(推荐使用 GPU 运行 Gemma 2B)。"
f"同时请确保模型 '{BASE_MODEL_ID}' 和适配器 '{ADAPTER_MODEL_ID}' 在 Hugging Face Hub 上存在且可访问。"
)
print(error_message_detail) # 在日志中打印详细信息
# 直接抛出错误,阻止后续代码执行和界面渲染
raise gr.Error(error_message_detail)
# <--- 第一个 try...except 结构在这里结束
# --- 文本生成函数 ---
# 确保这个函数定义在 try...except 块之外,并且是顶层缩进
print("--- 定义文本生成函数 ---")
def generate_text(prompt, max_new_tokens=256, temperature=0.7, top_p=0.95):
"""使用加载的 Pipeline 生成文本"""
# 增加检查,确保 pipeline 已成功加载
if pipe is None:
return ("错误:文本生成 Pipeline未能成功加载,无法生成文本。"
"请检查 Space 日志了解详细加载错误。")
if not prompt:
return "请输入提示语。"
print(f"\n收到提示语: '{prompt}'")
print(f"生成参数: max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}")
# 应用聊天模板
messages = [{"role": "user", "content": prompt}]
# 将模板应用也放入 try...except,以防模板本身或 tokenizer 出错
try: # <--- 第二个 try 块开始 (处理模板应用)
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(f"使用的格式化提示语 (请检查是否符合 Gemma 预期):\n{formatted_prompt}")
except Exception as e: # <--- 第二个 try 块的 except 部分
print(f"应用聊天模板时出错: {e}")
return f"处理输入格式时出错: {e}"
# <--- 第二个 try...except 结构结束
# 将文本生成过程放入 try...except
try: # <--- 第三个 try 块开始 (处理文本生成)
# 生成文本
outputs = pipe(
formatted_prompt,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=max(temperature, 0.01),
top_p=top_p,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
num_return_sequences=1
)
generated_text = outputs[0]['generated_text']
print(f"\nPipeline 原始输出:\n{generated_text}")
# --- 后处理:提取模型真实的回复 ---
# 保证这部分逻辑的缩进正确
assistant_marker = "<start_of_turn>model\n"
marker_index = generated_text.find(assistant_marker)
if marker_index != -1:
# 如果找到标记,则提取标记之后的内容
response = generated_text[marker_index + len(assistant_marker):].strip()
# 移除可能存在的结束符
if response.endswith(tokenizer.eos_token):
response = response[:-len(tokenizer.eos_token)].strip()
else: # <--- 这个 else 和上面的 if marker_index != -1: 配对
# 回退方案:如果找不到标记
print("警告: 在输出中未找到 '<start_of_turn>model\n' 标记。尝试移除格式化提示语作为回退。")
if generated_text.startswith(formatted_prompt):
# 这下面的代码块属于 if generated_text.startswith(formatted_prompt):
response = generated_text[len(formatted_prompt):].strip()
if response.endswith(tokenizer.eos_token):
response = response[:-len(tokenizer.eos_token)].strip()
else: # <--- 这个 else 和 if generated_text.startswith(formatted_prompt): 配对
# ############################################################ #
# ### 确保下面的 print 和 response = ... 紧跟在这个 else: 下方 ### #
# ### 并且它们相对于这个 else: 有正确的缩进 (通常是4个空格) ### #
# ############################################################ #
print("警告: 回退方案也失败,可能无法准确提取模型回复。返回原始生成文本。")
response = generated_text.strip() # <--- 这一行必须在 else 下面且缩进
# 这个 print 和 return 应该与 if marker_index != -1: 在同一缩进级别
print(f"提取的模型回复: {response}")
return response
except Exception as e: # <--- 第三个 try 块的 except 部分
print(f"!!! 生成过程中发生错误: {e}")
# 返回错误信息给用户界面
return f"生成文本时发生错误,请检查日志。\n错误类型: {type(e).__name__}"
# <--- 第三个 try...except 结构结束
# --- 创建 Gradio 界面 ---
# 确保这部分代码也在所有 try...except 块之外,并且是顶层缩进
print("--- 开始创建 Gradio 界面 ---")
with gr.Blocks(theme=gr.themes.Soft(), title=f"SafeSky Gemma SFT 测试") as demo:
gr.Markdown(f"""
# SafeSky AI Gemma 2B SFT ({ADAPTER_MODEL_ID}) 测试界面
在下方输入你的提示语,并可以调整生成参数。
**运行环境:** 模型当前运行在 **{device_type}** 环境。推理速度可能会受硬件影响,尤其是在 CPU 上会比较慢。
**基础模型:** `{BASE_MODEL_ID}`
""")
# 增加一个状态显示,告知用户模型是否加载成功
load_status = "模型已成功加载并准备就绪。" if pipe is not None else "错误:模型未能加载,界面功能可能受限或无法使用。请检查日志!"
gr.Markdown(f"**模型加载状态:** <span style='color: {'green' if pipe is not None else 'red'};'>{load_status}</span>")
with gr.Row():
with gr.Column(scale=2):
prompt_input = gr.Textbox(
label="你的提示语",
lines=5,
placeholder="例如:写一个关于友好机器人的短故事。"
)
with gr.Accordion("生成参数调整", open=False):
max_new_tokens_slider = gr.Slider(
minimum=10, maximum=1024, value=256, step=10,
label="最大新词元数 (Max New Tokens)",
info="控制生成文本的最大长度。"
)
temperature_slider = gr.Slider(
minimum=0.1, maximum=1.5, value=0.7, step=0.05,
label="温度 (Temperature)",
info="控制生成文本的随机性。值越低越确定,值越高越多样。"
)
top_p_slider = gr.Slider(
minimum=0.1, maximum=1.0, value=0.95, step=0.05,
label="Top-p (Nucleus Sampling)",
info="从概率最高的词元中进行采样,累积概率不超过 P。"
)
submit_button = gr.Button("生成文本", variant="primary")
with gr.Column(scale=3):
output_text = gr.Textbox(
label="模型回复",
lines=15,
interactive=False # 输出框不可编辑
)
gr.Examples(
examples=[
["写一个 Python 函数来计算一个数的阶乘。", 100, 0.7, 0.95],
["使用可再生能源有哪些主要好处?", 150, 0.8, 0.9],
["给我讲个关于电脑的笑话。", 50, 0.9, 0.95],
["简单解释一下“Safesky AI”的概念。", 200, 0.6, 0.9],
["中国的首都是哪里?", 30, 0.5, 0.9],
],
inputs=[prompt_input, max_new_tokens_slider, temperature_slider, top_p_slider],
outputs=output_text,
fn=generate_text,
cache_examples=False # 禁用缓存,以便每次都重新生成
)
# 连接按钮点击事件到生成函数
submit_button.click(
fn=generate_text,
inputs=[prompt_input, max_new_tokens_slider, temperature_slider, top_p_slider],
outputs=output_text,
api_name="generate" # 允许通过 API 调用 (可选)
)
print("--- Gradio 界面定义完成 ---")
# --- 启动 Gradio 应用 ---
# 确保 demo.launch() 在所有定义之后,并且在全局作用域 (顶层缩进)
print("--- 准备启动 Gradio 应用 ---")
# 可以在启动前最后检查一次 pipe 是否加载成功
if pipe is None:
print("!!! 警告: Pipeline未能加载,Gradio 应用可能无法正常工作。")
# 你可以选择在这里也 raise 一个错误来彻底阻止启动
# raise RuntimeError("无法启动应用,因为模型 Pipeline 加载失败。")
# 启动界面
demo.launch() # <--- 必须是顶层缩进,前面没有空格
print("--- 程序结束 (Gradio 服务器正在运行) ---")