# Prompt-Chain / app.py
# Uploaded by PepperKnight in commit 96a5f85 ("Create app.py", verified).
import gradio as gr
from transformers import pipeline
# Shared text-generation pipeline used by every step of the prompt chain.
# Replace "gpt2" with whichever Hugging Face text-generation model you prefer.
llm = pipeline(
    "text-generation",
    model="gpt2",
)
def run_chain(prompts, generator=None, max_new_tokens=100):
    """Run prompts in sequence, feeding each step's output into the next prompt.

    Every occurrence of the literal placeholder ``{input}`` in a prompt is
    replaced by the stripped output of the previous step. For the first prompt
    it is replaced by an empty string.

    Args:
        prompts: Iterable of prompt template strings, run in order.
        generator: Callable with the Hugging Face text-generation pipeline
            call convention (prompt plus keyword options, returning a list of
            ``{"generated_text": ...}`` dicts). Defaults to the module-level
            ``llm`` pipeline.
        max_new_tokens: Maximum number of tokens generated per step.

    Returns:
        A list with one stripped model output per prompt, in input order.
    """
    if generator is None:
        # Resolved at call time so callers (and tests) can inject a generator.
        generator = llm
    results = []
    context = ""
    for template in prompts:
        prompt = template.replace("{input}", context)
        # return_full_text=False asks the pipeline for only the newly generated
        # continuation, instead of slicing the prompt off by character length.
        response = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            return_full_text=False,
        )
        new_output = response[0]["generated_text"].strip()
        results.append(new_output)
        context = new_output
    return results
# Separator shown between prompts in the chain textbox and between outputs.
PROMPT_SEPARATOR = "\n---\n"


def _split_prompts(prompts_str):
    """Split the chain textbox into individual, non-blank prompts.

    A line consisting only of ``---`` (surrounding whitespace ignored) starts a
    new prompt. Each prompt is stripped, and blank segments are dropped so that
    they never reach the model. Blank segments can come from a trailing
    separator added by the "add prompt" button.
    """
    segments = [[]]
    # splitlines() also handles "\r\n" line endings.
    for line in (prompts_str or "").splitlines():
        if line.strip() == "---":
            segments.append([])
        else:
            segments[-1].append(line)
    prompts = ("\n".join(lines).strip() for lines in segments)
    return [p for p in prompts if p]


def add_prompt(prompts_str):
    """Append a separator to the chain textbox so the next prompt can be typed.

    An empty textbox is returned unchanged because there is no previous prompt
    to separate.
    """
    text = (prompts_str or "").rstrip()
    return text + PROMPT_SEPARATOR if text else text


def run_all(prompts_str):
    """Run every prompt in the textbox as one chain and join the outputs."""
    prompts = _split_prompts(prompts_str)
    if not prompts:
        # Nothing to run; don't send an empty prompt to the model.
        return ""
    return PROMPT_SEPARATOR.join(run_chain(prompts))


with gr.Blocks() as demo:
    with gr.Row():
        add_btn = gr.Button("➕ 添加一个 Prompt")
        run_btn = gr.Button("🚀 运行 Chain")
    with gr.Row():
        prompts_display = gr.Textbox(
            lines=15,
            label="Prompt Chain (用{input}引用上一轮输出)",
            interactive=True,
            elem_id="prompt_chain",
        )
        outputs_display = gr.Textbox(lines=15, label="模型输出", interactive=False)
    # The add button edits the visible textbox that run_all reads, so the user
    # actually sees a new prompt slot appear.
    add_btn.click(fn=add_prompt, inputs=prompts_display, outputs=prompts_display)
    run_btn.click(fn=run_all, inputs=prompts_display, outputs=outputs_display)

demo.launch()