Spaces:
Sleeping
Sleeping
| # 这是一个测试注释,用来触发更改检测 | |
| # -*- coding: utf-8 -*- | |
| import gradio as gr | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import os | |
| from peft import PeftModel # 引入 PeftModel 用于加载适配器 | |
| # 引入 accelerate 是个好习惯,即使 device_map='auto' 隐式使用它 | |
| import accelerate | |
| print("--- 程序开始 ---") | |
| # --- 配置区域 --- | |
| # 基础模型 ID (Gemma 2B) | |
| BASE_MODEL_ID = "google/gemma-2b" | |
| # PEFT 适配器模型 ID (你的微调模型) | |
| ADAPTER_MODEL_ID = "jinv2/safesky-ai-gemma-2b-sft" | |
| print(f"基础模型 ID: {BASE_MODEL_ID}") | |
| print(f"适配器 ID: {ADAPTER_MODEL_ID}") | |
| # --- 设备和数据类型检测 --- | |
| print("--- 开始检测设备和数据类型 ---") | |
| try: | |
| if torch.cuda.is_available(): | |
| # 使用 device_map='auto' 让 accelerate 自动处理多GPU或显存分配 | |
| device_map = "auto" | |
| # 记录设备类型用于界面显示 | |
| device_type = "GPU" | |
| print("检测到可用的 CUDA (GPU)。将使用 GPU 并自动分配设备 (device_map='auto')。") | |
| # 检查 GPU 是否支持 bfloat16 (通常需要 Ampere 架构或更新) | |
| # Gemma 模型使用 bfloat16 训练,优先使用 | |
| if torch.cuda.get_device_capability(0)[0] >= 8: # 检查第一个GPU的能力 | |
| torch_dtype = torch.bfloat16 | |
| print("GPU 支持 bfloat16,将使用 bfloat16 数据类型。") | |
| else: | |
| torch_dtype = torch.float16 # 不支持 bfloat16 时回退到 float16 | |
| print("GPU 不支持 bfloat16,将使用 float16 数据类型 (精度可能低于预期)。") | |
| else: | |
| device_type = "CPU" | |
| device_map = None # CPU 不需要 device_map | |
| torch_dtype = torch.float32 # CPU 通常使用 float32 | |
| print("未检测到 CUDA (GPU)。将使用 CPU。") | |
| # 注意这里的 except 块与上面的 try 配对 | |
| except Exception as e: | |
| print(f"检测设备或数据类型时出错: {e}。将默认使用 CPU 和 float32。") | |
| device_type = "CPU" | |
| device_map = None | |
| torch_dtype = torch.float32 | |
| print(f"最终选择的设备类型: {device_type}, 数据类型: {torch_dtype}, device_map: {device_map}") | |
| print("--- 设备检测完成 ---") | |
| # --- 定义模型加载和 Pipeline 创建的变量 --- | |
| # 将 pipe 初始化为 None,以便后续检查是否加载成功 | |
| pipe = None | |
| model = None | |
| tokenizer = None | |
| # --- 加载基础模型、适配器、Tokenizer 并创建 Pipeline --- | |
| # 将整个加载过程包裹在 try...except 中 | |
| print(f"--- 开始加载模型与 Pipeline ---") | |
| try: # <--- 第一个 try 块开始 | |
| print(f"开始加载基础模型: {BASE_MODEL_ID}...") | |
| # 1. 加载基础模型 | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL_ID, | |
| torch_dtype=torch_dtype, | |
| device_map=device_map, | |
| trust_remote_code=True | |
| ) | |
| print("基础模型加载成功。") | |
| # 2. 加载 Tokenizer | |
| print(f"开始加载 Tokenizer: {ADAPTER_MODEL_ID}...") | |
| tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True) | |
| print("Tokenizer 加载成功。") | |
| # 3. 加载 PEFT 适配器并应用到基础模型上 | |
| print(f"开始加载 PEFT 适配器: {ADAPTER_MODEL_ID}...") | |
| model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID) | |
| print("PEFT 适配器加载并应用成功。") | |
| # 4. 创建文本生成 Pipeline | |
| print("正在创建文本生成 Pipeline...") | |
| pipe = pipeline( | |
| "text-generation", | |
| model=model, | |
| tokenizer=tokenizer, | |
| torch_dtype=torch_dtype, | |
| device_map=device_map # 传递 device_map | |
| ) | |
| print("模型、适配器和 Pipeline 全部加载成功。") | |
| # 注意这里的 except 块是和上面的 try 配对的 | |
| except Exception as e: # <--- 第一个 try 块的 except 部分 | |
| print(f"!!! 加载模型、适配器或 Pipeline 时发生严重错误: {e}") | |
| # 准备更详细的错误信息 | |
| error_message_detail = ( | |
| f"加载模型失败!\n\n" | |
| f"基础模型: '{BASE_MODEL_ID}'\n" | |
| f"适配器: '{ADAPTER_MODEL_ID}'\n\n" | |
| f"错误类型: {type(e).__name__}\n" | |
| f"错误详情: {e}\n\n" | |
| f"请检查 Space 的日志获取更详细的回溯信息,并确认硬件配置(推荐使用 GPU 运行 Gemma 2B)。" | |
| f"同时请确保模型 '{BASE_MODEL_ID}' 和适配器 '{ADAPTER_MODEL_ID}' 在 Hugging Face Hub 上存在且可访问。" | |
| ) | |
| print(error_message_detail) # 在日志中打印详细信息 | |
| # 直接抛出错误,阻止后续代码执行和界面渲染 | |
| raise gr.Error(error_message_detail) | |
| # <--- 第一个 try...except 结构在这里结束 | |
| # --- 文本生成函数 --- | |
| # 确保这个函数定义在 try...except 块之外,并且是顶层缩进 | |
| print("--- 定义文本生成函数 ---") | |
| def generate_text(prompt, max_new_tokens=256, temperature=0.7, top_p=0.95): | |
| """使用加载的 Pipeline 生成文本""" | |
| # 增加检查,确保 pipeline 已成功加载 | |
| if pipe is None: | |
| return ("错误:文本生成 Pipeline未能成功加载,无法生成文本。" | |
| "请检查 Space 日志了解详细加载错误。") | |
| if not prompt: | |
| return "请输入提示语。" | |
| print(f"\n收到提示语: '{prompt}'") | |
| print(f"生成参数: max_new_tokens={max_new_tokens}, temperature={temperature}, top_p={top_p}") | |
| # 应用聊天模板 | |
| messages = [{"role": "user", "content": prompt}] | |
| # 将模板应用也放入 try...except,以防模板本身或 tokenizer 出错 | |
| try: # <--- 第二个 try 块开始 (处理模板应用) | |
| formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| print(f"使用的格式化提示语 (请检查是否符合 Gemma 预期):\n{formatted_prompt}") | |
| except Exception as e: # <--- 第二个 try 块的 except 部分 | |
| print(f"应用聊天模板时出错: {e}") | |
| return f"处理输入格式时出错: {e}" | |
| # <--- 第二个 try...except 结构结束 | |
| # 将文本生成过程放入 try...except | |
| try: # <--- 第三个 try 块开始 (处理文本生成) | |
| # 生成文本 | |
| outputs = pipe( | |
| formatted_prompt, | |
| max_new_tokens=max_new_tokens, | |
| do_sample=True, | |
| temperature=max(temperature, 0.01), | |
| top_p=top_p, | |
| pad_token_id=tokenizer.eos_token_id, | |
| eos_token_id=tokenizer.eos_token_id, | |
| num_return_sequences=1 | |
| ) | |
| generated_text = outputs[0]['generated_text'] | |
| print(f"\nPipeline 原始输出:\n{generated_text}") | |
| # --- 后处理:提取模型真实的回复 --- | |
| # 保证这部分逻辑的缩进正确 | |
| assistant_marker = "<start_of_turn>model\n" | |
| marker_index = generated_text.find(assistant_marker) | |
| if marker_index != -1: | |
| # 如果找到标记,则提取标记之后的内容 | |
| response = generated_text[marker_index + len(assistant_marker):].strip() | |
| # 移除可能存在的结束符 | |
| if response.endswith(tokenizer.eos_token): | |
| response = response[:-len(tokenizer.eos_token)].strip() | |
| else: # <--- 这个 else 和上面的 if marker_index != -1: 配对 | |
| # 回退方案:如果找不到标记 | |
| print("警告: 在输出中未找到 '<start_of_turn>model\n' 标记。尝试移除格式化提示语作为回退。") | |
| if generated_text.startswith(formatted_prompt): | |
| # 这下面的代码块属于 if generated_text.startswith(formatted_prompt): | |
| response = generated_text[len(formatted_prompt):].strip() | |
| if response.endswith(tokenizer.eos_token): | |
| response = response[:-len(tokenizer.eos_token)].strip() | |
| else: # <--- 这个 else 和 if generated_text.startswith(formatted_prompt): 配对 | |
| # ############################################################ # | |
| # ### 确保下面的 print 和 response = ... 紧跟在这个 else: 下方 ### # | |
| # ### 并且它们相对于这个 else: 有正确的缩进 (通常是4个空格) ### # | |
| # ############################################################ # | |
| print("警告: 回退方案也失败,可能无法准确提取模型回复。返回原始生成文本。") | |
| response = generated_text.strip() # <--- 这一行必须在 else 下面且缩进 | |
| # 这个 print 和 return 应该与 if marker_index != -1: 在同一缩进级别 | |
| print(f"提取的模型回复: {response}") | |
| return response | |
| except Exception as e: # <--- 第三个 try 块的 except 部分 | |
| print(f"!!! 生成过程中发生错误: {e}") | |
| # 返回错误信息给用户界面 | |
| return f"生成文本时发生错误,请检查日志。\n错误类型: {type(e).__name__}" | |
| # <--- 第三个 try...except 结构结束 | |
| # --- 创建 Gradio 界面 --- | |
| # 确保这部分代码也在所有 try...except 块之外,并且是顶层缩进 | |
| print("--- 开始创建 Gradio 界面 ---") | |
| with gr.Blocks(theme=gr.themes.Soft(), title=f"SafeSky Gemma SFT 测试") as demo: | |
| gr.Markdown(f""" | |
| # SafeSky AI Gemma 2B SFT ({ADAPTER_MODEL_ID}) 测试界面 | |
| 在下方输入你的提示语,并可以调整生成参数。 | |
| **运行环境:** 模型当前运行在 **{device_type}** 环境。推理速度可能会受硬件影响,尤其是在 CPU 上会比较慢。 | |
| **基础模型:** `{BASE_MODEL_ID}` | |
| """) | |
| # 增加一个状态显示,告知用户模型是否加载成功 | |
| load_status = "模型已成功加载并准备就绪。" if pipe is not None else "错误:模型未能加载,界面功能可能受限或无法使用。请检查日志!" | |
| gr.Markdown(f"**模型加载状态:** <span style='color: {'green' if pipe is not None else 'red'};'>{load_status}</span>") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| prompt_input = gr.Textbox( | |
| label="你的提示语", | |
| lines=5, | |
| placeholder="例如:写一个关于友好机器人的短故事。" | |
| ) | |
| with gr.Accordion("生成参数调整", open=False): | |
| max_new_tokens_slider = gr.Slider( | |
| minimum=10, maximum=1024, value=256, step=10, | |
| label="最大新词元数 (Max New Tokens)", | |
| info="控制生成文本的最大长度。" | |
| ) | |
| temperature_slider = gr.Slider( | |
| minimum=0.1, maximum=1.5, value=0.7, step=0.05, | |
| label="温度 (Temperature)", | |
| info="控制生成文本的随机性。值越低越确定,值越高越多样。" | |
| ) | |
| top_p_slider = gr.Slider( | |
| minimum=0.1, maximum=1.0, value=0.95, step=0.05, | |
| label="Top-p (Nucleus Sampling)", | |
| info="从概率最高的词元中进行采样,累积概率不超过 P。" | |
| ) | |
| submit_button = gr.Button("生成文本", variant="primary") | |
| with gr.Column(scale=3): | |
| output_text = gr.Textbox( | |
| label="模型回复", | |
| lines=15, | |
| interactive=False # 输出框不可编辑 | |
| ) | |
| gr.Examples( | |
| examples=[ | |
| ["写一个 Python 函数来计算一个数的阶乘。", 100, 0.7, 0.95], | |
| ["使用可再生能源有哪些主要好处?", 150, 0.8, 0.9], | |
| ["给我讲个关于电脑的笑话。", 50, 0.9, 0.95], | |
| ["简单解释一下“Safesky AI”的概念。", 200, 0.6, 0.9], | |
| ["中国的首都是哪里?", 30, 0.5, 0.9], | |
| ], | |
| inputs=[prompt_input, max_new_tokens_slider, temperature_slider, top_p_slider], | |
| outputs=output_text, | |
| fn=generate_text, | |
| cache_examples=False # 禁用缓存,以便每次都重新生成 | |
| ) | |
| # 连接按钮点击事件到生成函数 | |
| submit_button.click( | |
| fn=generate_text, | |
| inputs=[prompt_input, max_new_tokens_slider, temperature_slider, top_p_slider], | |
| outputs=output_text, | |
| api_name="generate" # 允许通过 API 调用 (可选) | |
| ) | |
| print("--- Gradio 界面定义完成 ---") | |
| # --- 启动 Gradio 应用 --- | |
| # 确保 demo.launch() 在所有定义之后,并且在全局作用域 (顶层缩进) | |
| print("--- 准备启动 Gradio 应用 ---") | |
| # 可以在启动前最后检查一次 pipe 是否加载成功 | |
| if pipe is None: | |
| print("!!! 警告: Pipeline未能加载,Gradio 应用可能无法正常工作。") | |
| # 你可以选择在这里也 raise 一个错误来彻底阻止启动 | |
| # raise RuntimeError("无法启动应用,因为模型 Pipeline 加载失败。") | |
| # 启动界面 | |
| demo.launch() # <--- 必须是顶层缩进,前面没有空格 | |
| print("--- 程序结束 (Gradio 服务器正在运行) ---") |