general-deep-learning / pipeline /base /generation_runner.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
"""
文本生成 ActionRunner 基类模块
提供统一的文本生成功能,支持交互式、固定 prompts、随机 prompts 三种模式。
"""
from abc import ABC, abstractmethod
from env.runner import ActionRunner
from pipeline.base.generation import TextGenerator
from pipeline.base.prompts_strategy import random_prompts
class BaseGenerationRunner(ActionRunner, ABC):
"""
文本生成 ActionRunner 基类。
子类必须实现:
- _build_generator(): 构建并返回生成器实例
子类必须设置:
- fixed_prompts: 固定 prompts 列表(类属性)
子类可配置:
- title: 显示标题
- random_config: random_prompts 参数字典
"""
title = "文本生成器"
fixed_prompts = []
random_config = {"num_text": 10, "text_length": 20}
def __init__(self, resolve_pipeline_func):
self.pipeline = resolve_pipeline_func()
self.generator: TextGenerator = self._build_generator()
@abstractmethod
def _build_generator(self) -> TextGenerator:
"""子类实现:构建并返回生成器实例"""
pass
def run_interactive(self):
"""交互式文本生成"""
self.pipeline.log_config()
print("\n" + "=" * 60)
print(self.title)
print("=" * 60)
print("输入提示文本,模型将生成续写内容。")
print("输入 'quit', 'exit' 或 'q' 退出程序。")
print("=" * 60 + "\n")
while True:
try:
prompt = input("提示: ").strip()
if prompt.lower() in ["quit", "exit", "q"]:
print("退出程序。")
break
if not prompt:
print("提示不能为空,请重新输入。")
continue
print("正在生成...")
result = self.generator.generate_text(prompt)
print("\n" + "-" * 60)
print(f"提示: {prompt}")
print(f"生成: {result.text}{result.stop_reason}")
print("-" * 60 + "\n")
except KeyboardInterrupt:
print("\n\n检测到中断信号,退出程序。")
break
except Exception as e:
print(f"生成过程中出现错误: {e}")
print("请重新输入提示。\n")
def run_fixed(self):
"""固定 prompts 文本生成"""
self.pipeline.log_config()
print(f"{self.title} - 固定提示生成启动...")
print("\n" + "=" * 60)
print(f"{self.title} 固定提示生成结果")
print("=" * 60 + "\n")
for i, prompt in enumerate(self.fixed_prompts):
print(f"提示 {i + 1:2}: {prompt}")
result = self.generator.generate_text(prompt)
print(f"生成: {result.text}{result.stop_reason}\n")
def run_random(self):
"""随机 prompts 文本生成"""
self.pipeline.log_config()
print(f"{self.title} - Random Prompts 生成器启动...")
docs_ds = self.pipeline.dataset.doc_ds()
prompts_generator = random_prompts(**self.random_config)
prompts = prompts_generator(docs_ds)
print("\n" + "=" * 60)
print(f"{self.title} Random Prompts 生成结果")
print("=" * 60 + "\n")
for i, prompt in enumerate(prompts):
print(f"提示 {i + 1:2}: {prompt}")
result = self.generator.generate_text(prompt)
print(f"生成: {result.text}{result.stop_reason}\n")