Spaces:
Runtime error
Runtime error
File size: 9,554 Bytes
aa6e6b1 714c472 aa6e6b1 cfb21b4 e219ec6 62b1e28 aa6e6b1 714c472 cfb21b4 aa6e6b1 bd177b1 aa6e6b1 bd177b1 aa6e6b1 327846c aa6e6b1 327846c aa6e6b1 ccae5de aa6e6b1 327846c aa6e6b1 327846c aa6e6b1 327846c aa6e6b1 327846c aa6e6b1 e1e36f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
# app.py
# Gradio UI for PromptEnhancerV2
import os
from threading import Thread
from transformers import TextIteratorStreamer, AutoTokenizer
import time
import logging
import re
import torch
import gradio as gr
import spaces
# Try to import qwen_vl_utils; if unavailable, fall back to a stub that
# reports no image/video inputs (text-only operation).
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        # Degraded mode: return (image_inputs, video_inputs) as (None, None).
        return None, None
def replace_single_quotes(text):
    """Rewrite single-quoted spans and curly single quotes as double quotes.

    ``'span'`` becomes ``"span"`` only when both quotes sit at non-word
    boundaries (so apostrophes inside words like ``don't`` are untouched);
    curly single quotes are mapped to their curly double-quote forms.
    """
    # Straight quotes: require \B on both sides so in-word apostrophes survive.
    converted = re.sub(r"\B'([^']*)'\B", r'"\1"', text)
    # Curly quotes: one C-level pass over the string.
    return converted.translate(str.maketrans({"’": "”", "‘": "“"}))
class PromptEnhancerV2:
    """Prompt rewriter backed by a Qwen2.5-VL conditional-generation model.

    Loads the model and processor once in ``__init__`` and exposes
    ``predict``, which asks the model to produce a chain-of-thought
    (``<think>...</think>``) followed by the rewritten prompt, then returns
    only the rewritten prompt.
    """

    @spaces.GPU
    def __init__(self, models_root_path, device_map="auto", torch_dtype="bfloat16"):
        """Load model weights and processor.

        Args:
            models_root_path: Local path or HF Hub id of the model.
            device_map: Forwarded to ``from_pretrained`` (e.g. "auto", "cuda").
            torch_dtype: "bfloat16" or "float16"; any other value falls back
                to float32.
        """
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

        # Configure root logging only if nothing else has.
        if not logging.getLogger(__name__).handlers:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Map the dtype name to a torch dtype; unknown names mean float32.
        dtype = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }.get(torch_dtype, torch.float32)

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            models_root_path,
            torch_dtype=dtype,
            device_map=device_map,
        )
        self.processor = AutoProcessor.from_pretrained(models_root_path)

    @spaces.GPU
    def predict(
        self,
        prompt_cot,
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.1,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda",
    ):
        """Return an enhanced version of ``prompt_cot``.

        On any failure (generation error, malformed model output) the
        original prompt is returned unchanged, so callers never lose input.

        Args:
            prompt_cot: The user prompt to rewrite.
            sys_prompt: Instruction prepended before the user prompt.
            temperature: Forwarded to ``generate`` (note: decoding is greedy,
                ``do_sample=False``, so this has no effect on token choice).
            top_p: Unused; kept for interface compatibility (generation uses
                a fixed top_p=0.9, matching the original implementation).
            max_new_tokens: Generation budget. Fix: previously accepted but
                ignored (2048 was hard-coded); now honored.
            device: Device the tokenized inputs are moved to.
        """
        org_prompt_cot = prompt_cot
        try:
            user_prompt_format = sys_prompt + "\n" + org_prompt_cot
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt_format},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)
            # Greedy decoding with fixed top_k/top_p, as in the original code.
            # Fix: use the caller-supplied max_new_tokens instead of a
            # hard-coded 2048 (the UI slider was previously a no-op).
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                do_sample=False,
                top_k=5,
                top_p=0.9,
            )
            # Strip the prompt tokens so only newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_res = output_text[0]
            # Expect exactly one <think>...</think> pair. This was an assert,
            # which `python -O` strips; an explicit raise keeps the fallback
            # path working under all run modes.
            if output_res.count("think>") != 2:
                raise ValueError("model output missing <think>...</think> block")
            # Everything after the closing think tag is the rewritten prompt.
            prompt_cot = output_res.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = replace_single_quotes(prompt_cot)
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
        return prompt_cot
# -------------------------
# Gradio app helpers
# -------------------------
# Default model location; overridable via the MODEL_OUTPUT_PATH env var.
DEFAULT_MODEL_PATH = os.getenv("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")
def ensure_enhancer(state, model_path, device_map, torch_dtype):
    """Return a state dict holding a loaded enhancer.

    Reuses ``state`` untouched when it is a dict whose recorded
    model_path/device_map/torch_dtype all match the requested ones;
    otherwise loads a fresh ``PromptEnhancerV2`` and returns a new dict.
    """
    if isinstance(state, dict):
        unchanged = (
            state.get("model_path") == model_path
            and state.get("device_map") == device_map
            and state.get("torch_dtype") == torch_dtype
        )
        if unchanged:
            return state
    # Config changed (or no prior state): (re)load the model.
    return {
        "enhancer": PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype),
        "model_path": model_path,
        "device_map": device_map,
        "torch_dtype": torch_dtype,
    }
def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
               model_path, device_map, torch_dtype, state):
    """Run one enhancement pass for the Gradio button.

    Returns ``(output_text, status_markdown, state)``; on failure the output
    is empty and the status carries the error message.
    """
    # Guard: blank input short-circuits before any model work.
    if not prompt or not str(prompt).strip():
        return "", "请先输入提示词。", state

    started = time.time()
    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
    enhancer = state["enhancer"]
    try:
        result = enhancer.predict(
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device,
        )
    except Exception as e:
        return "", f"推理失败:{e}", state
    elapsed = time.time() - started
    return result, f"耗时:{elapsed:.2f}s", state
# Example data for the Gradio Examples widget.
# Chinese sample prompts.
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
# English sample prompts.
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]
# Build the Gradio UI: model/config inputs on top, an inference tab below.
with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## 提示词重写器")
    with gr.Row():
        # Left column: model loading configuration.
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="模型路径(本地或HF地址)",
                value=DEFAULT_MODEL_PATH,
                placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
            )
            device_map = gr.Dropdown(
                choices=["cuda", "cpu"],
                value="cuda",
                label="device_map(模型加载映射)"
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype"
            )
        # Right column: generation parameters.
        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="系统提示词(默认无需修改)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens(原代码未使用该参数)")
                device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="推理device")
    # Per-session cache of the loaded model (see ensure_enhancer).
    state = gr.State(value=None)
    with gr.Tab("推理"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="输入提示词", lines=6, placeholder="在此粘贴要改写的提示词...")
                run_btn = gr.Button("生成重写", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="示例"
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="重写结果", lines=10)
                out_info = gr.Markdown("准备就绪。")
        # NOTE(review): streaming variant kept for reference; stream_single is
        # not defined in this file.
        # run_btn.click(
        #     stream_single,
        #     inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
        #             model_path, device_map, torch_dtype, state],
        #     outputs=[out_text, out_info, state]
        # )
        run_btn.click(
            run_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state]
        )
    gr.Markdown(
        "提示:如有任何问题可email联系:linqing1995@buaa.edu.cn"
    )
# Limit concurrency to avoid GPU OOM from parallel requests.
# demo.queue(concurrency_count=1, max_size=10)
# Script entry point: launch the Gradio app (public share link enabled).
if __name__ == "__main__":
    # demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
    demo.launch(ssr_mode=False, show_error=True, share=True)