Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,136 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
from transformers import
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
from peft import PeftModel
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# --- Configuration ---
# Hub repository uploaded after fine-tuning; the tokenizer and the LoRA
# adapter are both loaded from it further down.
hub_repo_id = "yxccai/text-style"
# Name of the base Qwen checkpoint (must match the one used for fine-tuning),
# e.g. "Qwen/Qwen1.5-1.8B-Chat" or "Qwen/Qwen1.5-0.5B-Chat". It is usually
# recorded in the adapter's adapter_config.json as base_model_name_or_path.
# It has to be spelled out here because the base model is loaded first.
# Assumes the 1.8B variant was fine-tuned; change it to match the real setup.
base_model_name = "Qwen/Qwen1.5-1.8B-Chat"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Gradio App: Using device: {device}")

# --- Load the model and tokenizer ---
print(f"Gradio App: Loading base model: {base_model_name}")
# Step 1: base model.
# - torch_dtype could also be torch.float16 or torch.bfloat16.
# - device_map="auto" is deliberately not used; the original author notes that
#   a plain .to(device) may be more stable on Spaces.
# - If the base model needs quantization, a quantization_config belongs here.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    trust_remote_code=True,
)
base_model.to(device)

print(f"Gradio App: Loading tokenizer from: {hub_repo_id}")
# Step 2: tokenizer, taken from the uploaded repository (which is expected to
# carry the base model's tokenizer configuration).
tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS for padding on tokenizer and model config.
    tokenizer.pad_token = tokenizer.eos_token
    base_model.config.pad_token_id = tokenizer.eos_token_id

print(f"Gradio App: Loading LoRA adapter from: {hub_repo_id}")
# Step 3: wrap the base model with the LoRA adapter. hub_repo_id is expected
# to hold the adapter weights (adapter_model.bin) and adapter_config.json.
# No extra model.to(device) here: base_model was already moved above.
model = PeftModel.from_pretrained(base_model, hub_repo_id)

# Optional: calling model.merge_and_unload() would fold the adapter into the
# base weights for simplicity, at the cost of more memory/disk. Not done here.

model.eval()  # evaluation mode
print("Gradio App: Model and tokenizer loaded successfully.")
|
| 51 |
+
|
| 52 |
+
# --- 推理函数 ---
|
| 53 |
+
def chat(input_text):
    """Rewrite formal written text as a short, casual, spoken-style version.

    Builds a Qwen chat prompt (fixed system instruction + the user's text),
    samples a completion from the module-level ``model`` and returns only the
    assistant's part of the generated sequence.

    Args:
        input_text: Text entered in the Gradio textbox.

    Returns:
        The extracted reply as a string, or a short user-facing message when
        the input is blank or no reply could be extracted.
    """
    print(f"Gradio App: Received input: {input_text}")

    # Nothing to convert: answer immediately instead of running generate().
    if not input_text or not input_text.strip():
        return "请输入需要转换的文本。"

    # Messages in the shape expected by the Qwen chat template.
    messages = [
        {"role": "system", "content": "你是一个文本风格转换助手。请严格按照要求,仅将以下书面文本转换为自然、口语化的简洁表达方式,不要添加任何额外的解释、扩展信息或重复原文。"},
        {"role": "user", "content": input_text},
    ]

    # The transformers version on Spaces may differ from the one used while
    # testing, so a failing template call falls back to plain concatenation.
    try:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,  # tells the model where its turn starts
        )
    except Exception as e:  # deliberate best-effort fallback
        print(f"Error applying chat template: {e}")
        # Simple concatenation; probably not optimal for the model.
        prompt = messages[0]["content"] + "\n" + messages[1]["content"] + "\n" + tokenizer.eos_token

    print(f"Gradio App: Formatted prompt for model:\n{prompt}")

    # NOTE(review): with truncation=True an over-long prompt is cut to 1024
    # tokens; if the tokenizer truncates on the right this would drop the
    # generation prompt at the end - confirm truncation_side.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)
    prompt_len = inputs.input_ids.shape[-1]

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=2048,  # upper bound on the reply length
        num_beams=1,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generated_ids[0] holds the prompt followed by the newly generated tokens.
    # Special tokens are kept so the assistant turn can be located by marker.
    full_generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print(f"Gradio App: Full generated sequence:\n{full_generated_text}")

    result = _extract_assistant_reply(full_generated_text, tokenizer.eos_token)
    if result is None:
        # No assistant marker (e.g. the fallback prompt was used): decode only
        # the tokens that follow the prompt, dropping special tokens.
        result = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True).strip()
        if not result:
            result = "模型输出格式不符合预期,未能提取有效回复。"

    print(f"Gradio App: Extracted result: {result}")
    return result


def _extract_assistant_reply(full_text, eos_token):
    """Return the last assistant turn of a decoded Qwen sequence, or None.

    Args:
        full_text: Decoded sequence that still contains special tokens.
        eos_token: The tokenizer's EOS token string; may be None or empty.

    Returns:
        The text after the last ``<|im_start|>assistant`` marker, cut at the
        first ``<|im_end|>`` and with a trailing EOS token removed; None when
        the marker does not occur in ``full_text``.
    """
    _, marker, tail = full_text.rpartition("<|im_start|>assistant")
    if not marker:
        return None
    # Always cut at the end-of-turn marker first, then strip a trailing EOS.
    # Treating the two as alternatives left "<|im_end|>" in the reply whenever
    # the EOS token was a different string or an earlier marker was present.
    reply = tail.split("<|im_end|>", 1)[0].strip()
    if eos_token and reply.endswith(eos_token):
        reply = reply[:-len(eos_token)].strip()
    return reply
|
| 121 |
+
|
| 122 |
+
# --- 创建Gradio界面 ---
|
| 123 |
+
# --- Gradio UI ---
# Sample inputs shown under the form; each row supplies the single textbox.
_example_rows = [
    ["乙醇的检测方法包括以下几项: 1. 酸碱度检查:取20ml乙醇加20ml水,加2滴酚酞指示剂应无色,再加1ml 0.01mol/L氢氧化钠应显粉红色."],
    ["本公司今日发布了最新的财务业绩报告,数据显示本季度利润实现了显著增长。"],
]
# Input box for the formal text and output box for the converted text.
_formal_box = gr.Textbox(lines=5, label="输入书面文本 (Input Formal Text)")
_casual_box = gr.Textbox(lines=5, label="输出口语化文本 (Output Casual Text)")

# Single-function interface: chat() maps the input box to the output box.
iface = gr.Interface(
    fn=chat,
    inputs=_formal_box,
    outputs=_casual_box,
    title="文本风格转换器 (Text Style Converter)",
    description="输入一段书面化的中文文本,模型会尝试将其转换为更自然、口语化的表达方式。由Qwen模型微调。",
    examples=_example_rows,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|