File size: 3,572 Bytes
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
文本生成 ActionRunner 基类模块

提供统一的文本生成功能,支持交互式、固定 prompts、随机 prompts 三种模式。
"""

from abc import ABC, abstractmethod

from env.runner import ActionRunner
from pipeline.base.generation import TextGenerator
from pipeline.base.prompts_strategy import random_prompts


class BaseGenerationRunner(ActionRunner, ABC):
    """
    文本生成 ActionRunner 基类。

    子类必须实现:
    - _build_generator(): 构建并返回生成器实例

    子类必须设置:
    - fixed_prompts: 固定 prompts 列表(类属性)

    子类可配置:
    - title: 显示标题
    - random_config: random_prompts 参数字典
    """

    title = "文本生成器"
    fixed_prompts = []
    random_config = {"num_text": 10, "text_length": 20}

    def __init__(self, resolve_pipeline_func):
        self.pipeline = resolve_pipeline_func()
        self.generator: TextGenerator = self._build_generator()

    @abstractmethod
    def _build_generator(self) -> TextGenerator:
        """子类实现:构建并返回生成器实例"""
        pass

    def run_interactive(self):
        """交互式文本生成"""
        self.pipeline.log_config()
        print("\n" + "=" * 60)
        print(self.title)
        print("=" * 60)
        print("输入提示文本,模型将生成续写内容。")
        print("输入 'quit', 'exit' 或 'q' 退出程序。")
        print("=" * 60 + "\n")

        while True:
            try:
                prompt = input("提示: ").strip()

                if prompt.lower() in ["quit", "exit", "q"]:
                    print("退出程序。")
                    break

                if not prompt:
                    print("提示不能为空,请重新输入。")
                    continue

                print("正在生成...")
                result = self.generator.generate_text(prompt)

                print("\n" + "-" * 60)
                print(f"提示: {prompt}")
                print(f"生成: {result.text}{result.stop_reason}")
                print("-" * 60 + "\n")

            except KeyboardInterrupt:
                print("\n\n检测到中断信号,退出程序。")
                break
            except Exception as e:
                print(f"生成过程中出现错误: {e}")
                print("请重新输入提示。\n")

    def run_fixed(self):
        """固定 prompts 文本生成"""
        self.pipeline.log_config()
        print(f"{self.title} - 固定提示生成启动...")

        print("\n" + "=" * 60)
        print(f"{self.title} 固定提示生成结果")
        print("=" * 60 + "\n")

        for i, prompt in enumerate(self.fixed_prompts):
            print(f"提示 {i + 1:2}: {prompt}")
            result = self.generator.generate_text(prompt)
            print(f"生成: {result.text}{result.stop_reason}\n")

    def run_random(self):
        """随机 prompts 文本生成"""
        self.pipeline.log_config()
        print(f"{self.title} - Random Prompts 生成器启动...")

        docs_ds = self.pipeline.dataset.doc_ds()
        prompts_generator = random_prompts(**self.random_config)
        prompts = prompts_generator(docs_ds)

        print("\n" + "=" * 60)
        print(f"{self.title} Random Prompts 生成结果")
        print("=" * 60 + "\n")

        for i, prompt in enumerate(prompts):
            print(f"提示 {i + 1:2}: {prompt}")
            result = self.generator.generate_text(prompt)
            print(f"生成: {result.text}{result.stop_reason}\n")