File size: 1,474 Bytes
96a5f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
from transformers import pipeline

# Shared text-generation backend used by every step of the chain.
# Replace "gpt2" with whichever model you prefer.
llm = pipeline(task="text-generation", model="gpt2")

def run_chain(prompts, *, max_new_tokens=100, generator=None):
    """Run prompts in sequence, feeding each step's output into the next.

    Every occurrence of the literal placeholder ``{input}`` in a prompt is
    replaced with the (stripped) text generated by the previous step; for the
    first prompt the placeholder is replaced with an empty string.

    Args:
        prompts: Iterable of prompt templates, processed in order.
        max_new_tokens: Maximum number of new tokens generated per step.
        generator: Callable with the Hugging Face text-generation pipeline
            calling convention (``generator(prompt, **kwargs)`` returning a
            list of dicts with a ``"generated_text"`` key). Defaults to the
            module-level ``llm`` pipeline.

    Returns:
        A list containing one stripped continuation per prompt, in order.
    """
    # Resolve the default lazily so callers can inject a different generator.
    generate = llm if generator is None else generator
    results = []
    context = ""
    for template in prompts:
        prompt = template.replace("{input}", context)
        # Request only the newly generated text rather than slicing the echoed
        # prompt off the full output, which assumes the output begins with the
        # prompt verbatim.
        output = generate(
            prompt, max_new_tokens=max_new_tokens, return_full_text=False
        )
        context = output[0]["generated_text"].strip()
        results.append(context)
    return results

with gr.Blocks() as demo:
    # All prompts live in a single textbox; a line containing only "---"
    # separates one prompt from the next.
    add_btn = gr.Button("➕ 添加一个 Prompt")
    run_btn = gr.Button("🚀 运行 Chain")

    prompts_display = gr.Textbox(lines=15, label="Prompt Chain (用{input}引用上一轮输出)", interactive=True, elem_id="prompt_chain")
    outputs_display = gr.Textbox(lines=15, label="模型输出", interactive=False)

    def add_prompt(prompts_str):
        """Append a prompt separator to the textbox so the next prompt can be typed.

        Args:
            prompts_str: Current contents of the prompt textbox (may be None).

        Returns:
            The contents with trailing newlines trimmed and "\\n---\\n" appended.
        """
        return (prompts_str or "").rstrip("\n") + "\n---\n"

    def run_all(prompts_str):
        """Split the textbox into prompts, run them as a chain, and join the outputs.

        A line whose only non-whitespace content is ``---`` separates prompts.
        Each prompt is stripped, and empty prompts are skipped so stray blank
        segments are never sent to the model.

        Args:
            prompts_str: Contents of the prompt textbox (may be None).

        Returns:
            The chain's outputs joined with "\\n---\\n", or "" when there are
            no non-empty prompts.
        """
        segments = []
        current = []
        for line in (prompts_str or "").splitlines():
            if line.strip() == "---":
                segments.append("\n".join(current).strip())
                current = []
            else:
                current.append(line)
        segments.append("\n".join(current).strip())

        prompts = [segment for segment in segments if segment]
        if not prompts:
            return ""
        outputs = run_chain(prompts)
        return "\n---\n".join(outputs)

    # The add button edits the visible prompt textbox directly.
    add_btn.click(fn=add_prompt, inputs=prompts_display, outputs=prompts_display)
    run_btn.click(fn=run_all, inputs=prompts_display, outputs=outputs_display)

# Launch only when run as a script so importing the module has no server side effect.
if __name__ == "__main__":
    demo.launch()