# app.py
# Gradio UI for PromptEnhancerV2
import logging
import os
import re
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
# Try to import qwen_vl_utils; if it is unavailable, fall back to a stub
# that returns empty image/video inputs.
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None
def replace_single_quotes(text):
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text
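# The \B anchors require a non-word character on both sides of each quote, so
# apostrophes inside words (e.g. "it's") are left untouched:
#   replace_single_quotes("a 'quoted' word") -> 'a "quoted" word'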
class PromptEnhancerV2:
    def __init__(self, models_root_path, device_map="cuda", torch_dtype="bfloat16"):
        # The Space runs on a single GPU, so pin the whole model to it.
        device_map = "cuda:0"
        if not logging.getLogger(__name__).handlers:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        # Kept for the tokenizer fallback in predict_stream.
        self.models_root_path = models_root_path
        # Map the dtype string to a torch dtype.
        if torch_dtype == "bfloat16":
            dtype = torch.bfloat16
        elif torch_dtype == "float16":
            dtype = torch.float16
        else:
            dtype = torch.float32
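        # NOTE: attn_implementation="flash_attention_2" requires the flash-attn
        # package; if it is not installed, remove the kwarg to fall back to
        # transformers' default attention implementation.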
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            models_root_path,
            torch_dtype=dtype,
            attn_implementation="flash_attention_2",
            device_map=device_map,
        )
        self.processor = AutoProcessor.from_pretrained(models_root_path)
    # @torch.inference_mode()
    @spaces.GPU
    def predict(
        self,
        prompt_cot,
        # sys_prompt (zh): "Based on the user's input, generate a chain-of-thought and rewrite the prompt:"
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.0,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda:0",
    ):
        org_prompt_cot = prompt_cot
        try:
            user_prompt_format = sys_prompt + "\n" + org_prompt_cot
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt_format},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)
            # NOTE: kept consistent with the original code: do_sample=False
            # (greedy) with fixed top_k=5 and top_p=0.9, and a hard-coded
            # 2048-token budget (the max_new_tokens argument is unused here).
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=float(temperature),
                do_sample=False,
                top_k=5,
                top_p=0.9,
            )
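            # Slice off the prompt tokens so only newly generated tokens are decoded.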
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_res = output_text[0]
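            # The model is expected to emit "<think>...</think>" followed by the
            # rewritten prompt; everything after the closing tag is the result.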
            assert output_res.count("think>") == 2
            prompt_cot = output_res.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = replace_single_quotes(prompt_cot)
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
        return prompt_cot
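    # Minimal usage sketch (illustrative; assumes a local GPU):
    #   enhancer = PromptEnhancerV2(DEFAULT_MODEL_PATH)
    #   rewritten = enhancer.predict("a cat sitting on a windowsill")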
    # @torch.inference_mode()
    @spaces.GPU
    def predict_stream(
        self,
        prompt_cot,
        # sys_prompt (zh): "Based on the user's input, generate a chain-of-thought and rewrite the prompt:"
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.1,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda:0",
    ):
        org_prompt_cot = prompt_cot
        # Assemble the inputs, same as in predict.
        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
        messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt_format}]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)
        # Get the tokenizer (processor.tokenizer usually exists; keep a
        # from_pretrained fallback just in case).
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.models_root_path, trust_remote_code=True)
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            do_sample=True,  # sampled decoding for streaming (predict uses greedy)
            top_k=5,
            top_p=0.9,
            streamer=streamer,
        )
        # Run generation in a worker thread; the main thread consumes the streamer.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()
        buffer = ""   # full accumulated output (including the think section)
        emitted = ""  # portion of the rewritten prompt already sent to the frontend
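        # With skip_prompt left at its default (False), the streamer replays the
        # prompt text first; the last "assistant" role marker precedes the reply,
        # so splitting on it isolates the newly generated output.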
        try:
            for piece in streamer:
                buffer += piece
                part = buffer.split('assistant')[-1]
                delta = part[len(emitted):]
                if delta:
                    emitted = part
                    yield emitted  # push the partial result to the frontend
        finally:
            thread.join()
        # If the second think> never arrives, fall back to the original prompt.
        # if emitted.strip() == "":
        #     yield replace_single_quotes(org_prompt_cot)
        try:
            assert emitted.count("think>") == 2
            prompt_cot = emitted.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = emitted.split('assistant')[-1] + '\n \n Recaption:' + replace_single_quotes(prompt_cot)
            # prompt_cot = replace_single_quotes(prompt_cot)
            yield prompt_cot
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
            yield prompt_cot
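    # Streaming usage sketch (illustrative):
    #   for partial in enhancer.predict_stream("a city street at night"):
    #       print(partial)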
# -------------------------
# Gradio app helpers
# -------------------------
DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")
def ensure_enhancer(state, model_path, device_map, torch_dtype):
    """
    Lazily build (or rebuild) the PromptEnhancerV2 instance.

    state: dict or None. The model is reloaded whenever the model path,
    device_map, or torch_dtype differs from the cached values.
    Returns the (possibly new) state dict.
    """
    need_reload = False
    if state is None or not isinstance(state, dict):
        need_reload = True
    else:
        prev_path = state.get("model_path")
        prev_map = state.get("device_map")
        prev_dtype = state.get("torch_dtype")
        if prev_path != model_path or prev_map != device_map or prev_dtype != torch_dtype:
            need_reload = True
    if need_reload:
        enhancer = PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype)
        return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
    return state
def stream_single(prompt, sys_prompt, temperature, max_new_tokens, device,
                  model_path, device_map, torch_dtype, state):
    if not prompt or not str(prompt).strip():
        yield "", "Please enter a prompt first.", state
        return
    t0 = time.time()
    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
    enhancer = state["enhancer"]
    emitted = ""
    try:
        for chunk in enhancer.predict_stream(
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device,
        ):
            emitted = chunk
            info = f"Received {len(emitted)} characters in {time.time()-t0:.2f}s"
            yield emitted, info, state
        # Emit one final status update when generation finishes (optional).
        yield emitted, f"Done. Total time: {time.time()-t0:.2f}s", state
    except Exception as e:
        yield "", f"Inference failed: {e}", state
# Example prompts (the Chinese ones are kept verbatim since they are model inputs).
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]
with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## Prompt Rewriter")
    with gr.Row():
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="Model path (local or HF repo id)",
                value=DEFAULT_MODEL_PATH,
                placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
            )
            device_map = gr.Dropdown(
                choices=["auto", "cuda", "cpu"],
                value="auto",
                label="device_map (model loading placement)"
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype"
            )
        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="System prompt (usually no need to change)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
                device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")
    state = gr.State(value=None)
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Rewrite", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="Examples"
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten result", lines=10)
                out_info = gr.Markdown("Ready.")
        run_btn.click(
            stream_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state]
        )
    gr.Markdown(
        "Tip: for any questions, email linqing1995@buaa.edu.cn"
    )
# Limit concurrency to keep concurrent requests from exhausting GPU memory.
# demo.queue(concurrency_count=1, max_size=10)
if __name__ == "__main__":
    # demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
    demo.launch(ssr_mode=False, show_error=True, share=True)