# app.py
# Gradio UI for PromptEnhancerV2

import os
import time
import logging
import re

import torch
import gradio as gr
import spaces


# Try to import qwen_vl_utils; if it is unavailable, fall back to a stub that
# returns no image/video inputs (this app only sends text).
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None

def replace_single_quotes(text):
    """Convert straight 'single-quoted' spans to "double-quoted" ones, and curly
    single quotes to curly double quotes. Apostrophes inside words are left alone."""
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text
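
# For reference, a quick (hypothetical) example of the normalization above:
#   replace_single_quotes("a 'red' car")  ->  'a "red" car'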

class PromptEnhancerV2:
    @spaces.GPU 
    def __init__(self, models_root_path, device_map="auto", torch_dtype="bfloat16"):
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
        if not logging.getLogger(__name__).handlers:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Map the requested dtype string to a torch dtype.
        if torch_dtype == "bfloat16":
            dtype = torch.bfloat16
        elif torch_dtype == "float16":
            dtype = torch.float16
        else:
            dtype = torch.float32

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            models_root_path,
            torch_dtype=dtype,
            device_map=device_map,
        )
        self.processor = AutoProcessor.from_pretrained(models_root_path)

    @spaces.GPU
    def predict(
        self,
        prompt_cot,
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.1,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda",
    ):
        org_prompt_cot = prompt_cot
        try:
            user_prompt_format = sys_prompt + "\n" + org_prompt_cot
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt_format},
                    ],
                }
            ]

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)

            # Note: decoding is greedy (do_sample=False), matching the original
            # pipeline, so temperature/top_k/top_p have no effect on the output;
            # they are kept only for parity with the original generation config.
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                do_sample=False,
                top_k=5,
                top_p=0.9,
            )
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_res = output_text[0]
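            # The model wraps its chain-of-thought in a single <think>...</think>
            # block; the rewritten prompt is everything after the closing tag.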
            assert output_res.count("think>") == 2
            prompt_cot = output_res.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = replace_single_quotes(prompt_cot)
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")

        return prompt_cot
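
# A minimal usage sketch outside Gradio (assumes the weights below are
# available locally or on the HF Hub; not executed here):
#
#   enhancer = PromptEnhancerV2("PromptEnhancer/PromptEnhancer-32B")
#   print(enhancer.predict("A cat sitting on a windowsill at dusk."))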
# -------------------------
# Gradio app helpers
# -------------------------

DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")
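# The checkpoint can be overridden via the environment, e.g.:
#   MODEL_OUTPUT_PATH=/path/to/local/weights python app.py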

def ensure_enhancer(state, model_path, device_map, torch_dtype):
    """
    state: dict or None
    Returns: (state_dict)
    """
    need_reload = False
    if state is None or not isinstance(state, dict):
        need_reload = True
    else:
        prev_path = state.get("model_path")
        prev_map = state.get("device_map")
        prev_dtype = state.get("torch_dtype")
        if prev_path != model_path or prev_map != device_map or prev_dtype != torch_dtype:
            need_reload = True

    if need_reload:
        enhancer = PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype)
        return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
    return state


def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
               model_path, device_map, torch_dtype, state):
    if not prompt or not str(prompt).strip():
        return "", "请先输入提示词。", state

    t0 = time.time()
    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
    enhancer = state["enhancer"]
    try:
        out = enhancer.predict(
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device
        )
        dt = time.time() - t0
        return out, f"耗时:{dt:.2f}s", state
    except Exception as e:
        return "", f"推理失败:{e}", state
        
# Example prompts (Chinese and English; kept verbatim as model inputs)
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]

with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## 提示词重写器")
    with gr.Row():
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="Model path (local dir or HF repo id)",
                value=DEFAULT_MODEL_PATH,
                placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
            )
            device_map = gr.Dropdown(
                choices=["auto", "cuda", "cpu"],
                value="cuda",
                label="device_map (model placement)"
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype"
            )

        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="系统提示词(默认无需修改)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
                device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")

    state = gr.State(value=None)

    with gr.Tab("推理"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Rewrite", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="示例"
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten prompt", lines=10)
                out_info = gr.Markdown("Ready.")

        run_btn.click(
            run_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state]
        )

    gr.Markdown(
        "Note: for any questions, email linqing1995@buaa.edu.cn"
    )

# To avoid GPU OOM under concurrent requests, limit queue concurrency
# (Gradio 3.x used concurrency_count; Gradio 4+ uses default_concurrency_limit):
# demo.queue(max_size=10, default_concurrency_limit=1)
if __name__ == "__main__":
    # demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
    demo.launch(ssr_mode=False, show_error=True, share=True)