Spaces:
Build error
Build error
File size: 1,474 Bytes
96a5f85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import re

import gradio as gr
from transformers import pipeline
llm = pipeline(task="text-generation", model="gpt2")  # swap "gpt2" for whichever text-generation model you prefer
def run_chain(prompts, *, generator=None, max_new_tokens=100):
    """Run a chain of prompts, feeding each step's output into the next step.

    Every occurrence of the literal placeholder ``{input}`` in a prompt is
    replaced by the (stripped) text generated for the previous prompt; for
    the first prompt it is replaced by an empty string.

    Args:
        prompts: Iterable of prompt templates, processed in order.
        generator: Callable following the Hugging Face text-generation
            pipeline calling convention: ``generator(prompt, **kwargs)``
            returns a list whose first element is a dict with a
            ``"generated_text"`` key. Defaults to the module-level ``llm``.
        max_new_tokens: Maximum number of new tokens generated per step.

    Returns:
        A list with one generated continuation per prompt, each stripped of
        surrounding whitespace.
    """
    generate = llm if generator is None else generator
    outputs = []
    previous = ""
    for template in prompts:
        prompt = template.replace("{input}", previous)
        # Ask the pipeline for the continuation only, instead of slicing the
        # echoed full text by len(prompt): that slice silently corrupts the
        # output whenever the echoed prompt is not a character-exact copy.
        response = generate(
            prompt, max_new_tokens=max_new_tokens, return_full_text=False
        )
        previous = response[0]["generated_text"].strip()
        outputs.append(previous)
    return outputs
# Prompts in the input box (and outputs in the result box) are separated by a
# line containing only "---".
_SEPARATOR = "\n---\n"
_SEPARATOR_LINE = re.compile(r"^[ \t]*---[ \t]*$", re.MULTILINE)


def _split_prompts(text):
    """Split the prompt-chain textbox contents into non-blank prompts.

    Segments are delimited by lines consisting only of ``---`` (optionally
    padded with spaces/tabs). Each segment is stripped and blank segments are
    dropped, so a trailing separator (e.g. left by the "add prompt" button)
    never sends an empty prompt or a stray "---" to the model.
    """
    segments = _SEPARATOR_LINE.split(text or "")
    return [segment.strip() for segment in segments if segment.strip()]


with gr.Blocks() as demo:
    with gr.Row():
        add_btn = gr.Button("➕ 添加一个 Prompt")
        run_btn = gr.Button("🚀 运行 Chain")
    with gr.Row():
        prompts_display = gr.Textbox(
            lines=15,
            label="Prompt Chain (用{input}引用上一轮输出)",
            placeholder="第一个 Prompt\n---\n第二个 Prompt（可用 {input} 引用上一轮输出）",
            interactive=True,
            elem_id="prompt_chain",
        )
        outputs_display = gr.Textbox(lines=15, label="模型输出", interactive=False)

    def add_prompt(prompts_str):
        """Append a separator line so the user can type the next prompt.

        The previous implementation only appended to hidden ``gr.State``
        lists that nothing displayed or read, so the button had no effect.
        """
        return (prompts_str or "").rstrip() + _SEPARATOR

    def run_all(prompts_str):
        """Run the chain over the textbox prompts; join outputs with '---' lines."""
        outputs = run_chain(_split_prompts(prompts_str))
        return _SEPARATOR.join(outputs)

    add_btn.click(fn=add_prompt, inputs=prompts_display, outputs=prompts_display)
    run_btn.click(fn=run_all, inputs=prompts_display, outputs=outputs_display)

demo.launch()
|