""" 文本生成 ActionRunner 基类模块 提供统一的文本生成功能,支持交互式、固定 prompts、随机 prompts 三种模式。 """ from abc import ABC, abstractmethod from env.runner import ActionRunner from pipeline.base.generation import TextGenerator from pipeline.base.prompts_strategy import random_prompts class BaseGenerationRunner(ActionRunner, ABC): """ 文本生成 ActionRunner 基类。 子类必须实现: - _build_generator(): 构建并返回生成器实例 子类必须设置: - fixed_prompts: 固定 prompts 列表(类属性) 子类可配置: - title: 显示标题 - random_config: random_prompts 参数字典 """ title = "文本生成器" fixed_prompts = [] random_config = {"num_text": 10, "text_length": 20} def __init__(self, resolve_pipeline_func): self.pipeline = resolve_pipeline_func() self.generator: TextGenerator = self._build_generator() @abstractmethod def _build_generator(self) -> TextGenerator: """子类实现:构建并返回生成器实例""" pass def run_interactive(self): """交互式文本生成""" self.pipeline.log_config() print("\n" + "=" * 60) print(self.title) print("=" * 60) print("输入提示文本,模型将生成续写内容。") print("输入 'quit', 'exit' 或 'q' 退出程序。") print("=" * 60 + "\n") while True: try: prompt = input("提示: ").strip() if prompt.lower() in ["quit", "exit", "q"]: print("退出程序。") break if not prompt: print("提示不能为空,请重新输入。") continue print("正在生成...") result = self.generator.generate_text(prompt) print("\n" + "-" * 60) print(f"提示: {prompt}") print(f"生成: {result.text}{result.stop_reason}") print("-" * 60 + "\n") except KeyboardInterrupt: print("\n\n检测到中断信号,退出程序。") break except Exception as e: print(f"生成过程中出现错误: {e}") print("请重新输入提示。\n") def run_fixed(self): """固定 prompts 文本生成""" self.pipeline.log_config() print(f"{self.title} - 固定提示生成启动...") print("\n" + "=" * 60) print(f"{self.title} 固定提示生成结果") print("=" * 60 + "\n") for i, prompt in enumerate(self.fixed_prompts): print(f"提示 {i + 1:2}: {prompt}") result = self.generator.generate_text(prompt) print(f"生成: {result.text}{result.stop_reason}\n") def run_random(self): """随机 prompts 文本生成""" self.pipeline.log_config() print(f"{self.title} - Random Prompts 生成器启动...") docs_ds = self.pipeline.dataset.doc_ds() prompts_generator = random_prompts(**self.random_config) prompts = prompts_generator(docs_ds) print("\n" + "=" * 60) print(f"{self.title} Random Prompts 生成结果") print("=" * 60 + "\n") for i, prompt in enumerate(prompts): print(f"提示 {i + 1:2}: {prompt}") result = self.generator.generate_text(prompt) print(f"生成: {result.text}{result.stop_reason}\n")