File size: 4,669 Bytes
6628fd9
 
 
 
 
 
 
 
 
 
 
e777aaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6628fd9
4f50749
6628fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e777aaf
 
 
 
 
 
 
 
 
 
 
 
 
4f50749
6628fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f50749
e777aaf
 
6628fd9
 
 
e777aaf
 
 
 
6628fd9
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from abc import ABC, abstractmethod
from typing import Any

from llama_index import load_index_from_storage
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.indices.response import ResponseMode

from core.helper import LifecycleHelper
from core.lifecycle import Lifecycle
from llama.service_context import ServiceContextManager
from llama.storage_context import StorageContextManager
# from few_shot import get_few_shot_template

from langchain import PromptTemplate, FewShotPromptTemplate
# Few-shot Q&A examples (Chinese dress-code FAQ) rendered into the prompt.
# The answer strings are live prompt text — kept byte-for-byte.
examples = [
    {
        "question": "戴帽卫衣可以穿了吗?",
        "answer":
            """
    可以的,颜色需要符合上衣标准要求。
    """
    },
    {
        "question": "下装的标准是什么?",
        "answer":
            """
1.伙伴可以穿着长裤或及膝短裤,也可以穿裙子(包括连衣裙),但需要是纯色并且长度及膝或过膝。伙伴不应穿着颜色不均匀的牛仔裤,宽松下垂、破洞或者做旧效果的牛仔裤也不能穿。出于安全考虑,伙伴也不应穿着皮裤、瑜伽裤、弹力纤维裤和紧身裤(包括黑色连裤袜)。
2.颜色要求:卡其色、深蓝色、深灰色、黑色。
"""
    }
]


def get_few_shot_template() -> str:
    """Render every entry of ``examples`` as ``Question: ..., answer: ...``
    lines and join them with a blank-line separator for prompt injection."""
    return "\n".join(
        f"Question: {example['question']}, answer: {example['answer']}\n"
        for example in examples
    )


class FAQRobot(ABC):
    """Abstract interface for a question-answering (FAQ) robot."""

    @abstractmethod
    def ask(self, question: str) -> Any:
        """Answer *question*; concrete subclasses supply the strategy."""
        ...


class AzureOpenAIFAQWikiRobot(FAQRobot):
    """FAQ robot that delegates each question to an injected llama-index
    query engine and returns the response rendered as a string."""

    # Engine injected by the manager; performs retrieval + answer synthesis.
    query_engine: BaseQueryEngine

    def __init__(self, query_engine: BaseQueryEngine) -> None:
        super().__init__()
        self.query_engine = query_engine

    def ask(self, question: str) -> Any:
        """Query the engine with *question* and return the answer text.

        Returns the engine response converted with ``str()``; the concrete
        response type depends on the engine's configured response mode.
        """
        print("question: ", question)
        response = self.query_engine.query(question)
        print("response type: ", type(response))
        # str(response) is the idiomatic spelling of response.__str__().
        return str(response)


class FAQRobotManager(Lifecycle):
    # Abstract lifecycle-aware factory for FAQRobot instances.
    # NOTE(review): @abstractmethod only prevents instantiation when the
    # class's metaclass is ABCMeta — confirm that Lifecycle derives from
    # abc.ABC; otherwise this decorator is purely advisory here.
    @abstractmethod
    def get_robot(self) -> FAQRobot:
        """Return the managed FAQRobot. Concrete managers must override."""
        pass


# Prompt fragments wrapped around the rendered few-shot examples to build
# the final text-QA template (see AzureFAQRobotManager.do_start).
# {context_str} and {query_str} are substituted by llama_index at query time.
DEFAULT_QA_PROMPT_TMPL_PREFIX = (
    "Given examples below.\n"
    "---------------------\n"
)
# NOTE(review): "answer the function" reads like a typo for "answer the
# question" — left byte-identical because this string is live prompt text
# and changing it would alter model behavior; confirm with the author.
DEFAULT_QA_PROMPT_TMPL_SUFFIX = (
    "---------------------\n"
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "either say '不好意思,我从文档中无法找到答案' or answer the function: {query_str}\n"
)

class AzureFAQRobotManager(FAQRobotManager):
    """Lifecycle-managed factory that wires the service and storage context
    managers into a llama-index query engine and hands out FAQ robots
    backed by that engine.

    Lifecycle contract (driven by the Lifecycle base via LifecycleHelper):
    do_init -> do_start -> do_stop -> do_dispose. Init/start propagate to
    the two context managers in service-then-storage order; stop/dispose
    reverse that order (storage first).
    """

    service_context_manager: ServiceContextManager
    storage_context_manager: StorageContextManager
    # Assigned only in do_start(); calling get_robot() before the manager
    # is started raises AttributeError.
    query_engine: BaseQueryEngine

    def __init__(
        self,
        service_context_manager: ServiceContextManager,
        storage_context_manager: StorageContextManager,
    ) -> None:
        """Store the collaborating managers; all heavy initialization is
        deferred to the lifecycle hooks below."""
        super().__init__()
        self.service_context_manager = service_context_manager
        self.storage_context_manager = storage_context_manager

    def get_robot(self) -> FAQRobot:
        """Return a robot bound to the engine built in do_start()."""
        return AzureOpenAIFAQWikiRobot(self.query_engine)

    def do_init(self) -> None:
        # Initialize collaborators (service context first, then storage).
        LifecycleHelper.initialize_if_possible(self.service_context_manager)
        LifecycleHelper.initialize_if_possible(self.storage_context_manager)

    def do_start(self) -> None:
        # Start collaborators, then load the persisted index and build the
        # query engine with the few-shot-augmented QA prompt.
        LifecycleHelper.start_if_possible(self.service_context_manager)
        LifecycleHelper.start_if_possible(self.storage_context_manager)
        index = load_index_from_storage(
            storage_context=self.storage_context_manager.storage_context,
            service_context=self.service_context_manager.get_service_context(),
        )
        # Local import keeps the llama_index Prompt dependency out of module
        # import time; presumably avoids a heavier import path — TODO confirm.
        from llama_index import Prompt
        few_shot_examples = get_few_shot_template()

        self.query_engine = index.as_query_engine(
            service_context=self.service_context_manager.get_service_context(),
            response_mode=ResponseMode.REFINE,
            similarity_top_k=2,
            # Final template: prefix + rendered examples + context/query suffix.
            text_qa_template=Prompt("\n".join([DEFAULT_QA_PROMPT_TMPL_PREFIX,
                                               few_shot_examples,
                                               DEFAULT_QA_PROMPT_TMPL_SUFFIX]))
        )

    def do_stop(self) -> None:
        # Stop in reverse order of start (storage before service).
        LifecycleHelper.stop_if_possible(self.storage_context_manager)
        LifecycleHelper.stop_if_possible(self.service_context_manager)

    def do_dispose(self) -> None:
        # Dispose in reverse order of init (storage before service).
        LifecycleHelper.dispose_if_possible(self.storage_context_manager)
        LifecycleHelper.dispose_if_possible(self.service_context_manager)