# app.py
# Gradio UI for PromptEnhancerV2
import os
from threading import Thread
from transformers import TextIteratorStreamer, AutoTokenizer
import time
import logging
import re
import torch
import gradio as gr
import spaces
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Try to import qwen_vl_utils; if that fails, fall back to a stub that
# returns empty image/video inputs (text-only usage still works).
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None


def replace_single_quotes(text):
    # Convert 'single-quoted' spans to "double-quoted" ones, then normalize
    # curly single quotes to curly double quotes.
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text
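
# Illustrative behavior (hypothetical input, not from the original file):
#   replace_single_quotes("a 'red' hat") returns 'a "red" hat'; the \B anchors
#   leave in-word apostrophes such as "it's" untouched, since a word character
#   borders that quote.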


class PromptEnhancerV2:
    def __init__(self, models_root_path, device_map="cuda", torch_dtype="bfloat16"):
        if not logging.getLogger(__name__).handlers:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.models_root_path = models_root_path  # kept for the tokenizer fallback below
        # dtype compatibility handling
        if torch_dtype == "bfloat16":
            dtype = torch.bfloat16
        elif torch_dtype == "float16":
            dtype = torch.float16
        else:
            dtype = torch.float32
        # Note: flash_attention_2 requires the flash-attn package; switch to
        # "sdpa" if it is not installed in the environment.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            models_root_path,
            torch_dtype=dtype,
            attn_implementation="flash_attention_2",
            device_map=device_map,
        )
        self.processor = AutoProcessor.from_pretrained(models_root_path)
    # @torch.inference_mode()
    def predict(
        self,
        prompt_cot,
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",  # model-facing Chinese instruction, kept verbatim
        temperature=0.0,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda:0",
    ):
        org_prompt_cot = prompt_cot
        try:
            user_prompt_format = sys_prompt + "\n" + org_prompt_cot
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt_format},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)
            # Greedy decoding, as in the upstream code. It also passed
            # temperature, top_k=5, and top_p=0.9, but generate ignores those
            # when do_sample=False, so they are omitted here.
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_res = output_text[0]
            # Expect exactly one "<think>...</think>" block; the rewritten
            # prompt follows the closing tag.
            assert output_res.count("think>") == 2
            prompt_cot = output_res.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = replace_single_quotes(prompt_cot)
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
        return prompt_cot
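
    # Illustrative non-streaming call (hypothetical prompt, not part of the
    # original file):
    #   enhancer = PromptEnhancerV2("PromptEnhancer/PromptEnhancer-32B")
    #   rewritten = enhancer.predict("a cat on a sofa")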
    # @torch.inference_mode()
    def predict_stream(
        self,
        prompt_cot,
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.1,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda:0",
    ):
        org_prompt_cot = prompt_cot
        # Assemble the inputs, same as in predict.
        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
        messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt_format}]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)
        # Get a tokenizer (processor.tokenizer usually exists; keep a
        # fallback just in case).
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.models_root_path, trust_remote_code=True)
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            do_sample=True,  # the streaming path samples, unlike predict (greedy)
            top_k=5,
            top_p=0.9,
            streamer=streamer,
        )
        # Run generation in a worker thread; the main thread consumes the streamer.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()
        buffer = ""   # full accumulated output (prompt echo + thinking + rewrite)
        emitted = ""  # portion already surfaced to the caller
        try:
            for piece in streamer:
                buffer += piece
                # The streamer re-emits the prompt (skip_prompt defaults to
                # False), so keep only what follows the chat template's final
                # "assistant" marker.
                part = buffer.split('assistant')[-1]
                delta = part[len(emitted):]
                if delta:
                    emitted = part
                    yield emitted  # push the intermediate result to the frontend
        finally:
            thread.join()
        # If the closing think> never arrived, fall back to the original prompt.
        try:
            assert emitted.count("think>") == 2
            prompt_cot = emitted.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            # Final yield: the full output (thinking included) plus the
            # cleaned rewrite under a "Recaption:" marker.
            prompt_cot = emitted + '\n \n Recaption:' + replace_single_quotes(prompt_cot)
            yield prompt_cot
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
            yield prompt_cot
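
# Streaming usage sketch (hypothetical, outside Gradio): predict_stream is a
# generator yielding cumulative text, so each item supersedes the previous one:
#   enhancer = PromptEnhancerV2("PromptEnhancer/PromptEnhancer-32B")
#   for partial in enhancer.predict_stream("a cat on a sofa"):
#       print(partial)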

# -------------------------
# Gradio app helpers
# -------------------------
DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")


def ensure_enhancer(state, model_path, device_map, torch_dtype):
    """Build (or rebuild) the PromptEnhancerV2 held in the session state.

    state: dict or None. The model is reloaded only when the model path,
    device_map, or dtype differs from what the state was built with.
    Returns the (possibly new) state dict.
    """
    need_reload = False
    if state is None or not isinstance(state, dict):
        need_reload = True
    else:
        prev_path = state.get("model_path")
        prev_map = state.get("device_map")
        prev_dtype = state.get("torch_dtype")
        if prev_path != model_path or prev_map != device_map or prev_dtype != torch_dtype:
            need_reload = True
    if need_reload:
        enhancer = PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype)
        return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
    return state
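
# Caching sketch (hypothetical call pattern): feeding the returned dict back in
# with unchanged settings skips reloading the 32B weights:
#   state = ensure_enhancer(None, DEFAULT_MODEL_PATH, "auto", "bfloat16")   # loads
#   state = ensure_enhancer(state, DEFAULT_MODEL_PATH, "auto", "bfloat16")  # reuses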

# On ZeroGPU Spaces, GPU work must run inside a function decorated with
# @spaces.GPU (the `spaces` import was otherwise unused); the decorator is a
# no-op when running elsewhere.
@spaces.GPU
def stream_single(prompt, sys_prompt, temperature, max_new_tokens, device,
                  model_path, device_map, torch_dtype, state):
    if not prompt or not str(prompt).strip():
        yield "", "Please enter a prompt first.", state
        return
    t0 = time.time()
    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
    enhancer = state["enhancer"]
    emitted = ""
    try:
        for chunk in enhancer.predict_stream(
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device
        ):
            emitted = chunk
            info = f"Received {len(emitted)} characters in {time.time()-t0:.2f}s"
            yield emitted, info, state
        # Emit one final status when done (optional).
        yield emitted, f"Done. Total time: {time.time()-t0:.2f}s", state
    except Exception as e:
        yield "", f"Inference failed: {e}", state

# Example prompts (Chinese and English)
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]

with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## Prompt Rewriter")
    with gr.Row():
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="Model path (local path or HF repo id)",
                value=DEFAULT_MODEL_PATH,
                placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
            )
            device_map = gr.Dropdown(
                choices=["auto", "cuda", "cpu"],
                value="auto",
                label="device_map (model loading placement)"
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype"
            )
        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="System prompt (usually no need to change)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
                device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")
    state = gr.State(value=None)

    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Rewrite", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="Examples"
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten result", lines=10)
                out_info = gr.Markdown("Ready.")
        run_btn.click(
            stream_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state]
        )
    gr.Markdown(
        "Tip: for any questions, contact linqing1995@buaa.edu.cn"
    )

# Limit concurrency to avoid exhausting GPU memory under parallel requests.
# (Gradio 4.x removed `concurrency_count`; `default_concurrency_limit` is the
# current equivalent on queue().)
# demo.queue(default_concurrency_limit=1, max_size=10)

if __name__ == "__main__":
    # demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
    demo.launch(ssr_mode=False, show_error=True, share=True)