File size: 6,691 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
System prompt builder and related strategies
"""
from typing import Optional, Any
import logging

from lpm_kernel.api.domains.kernel2.dto.chat_dto import ChatRequest
from lpm_kernel.api.domains.kernel2.services.role_service import role_service
from lpm_kernel.api.domains.kernel2.services.knowledge_service import (
    default_retriever,
    default_l1_retriever,
)
from lpm_kernel.L2.training_prompt import CONTEXT_PROMPT, MEMORY_PROMPT, JUDGE_PROMPT


logger = logging.getLogger(__name__)


class SystemPromptStrategy:
    """Common interface shared by every system prompt building strategy.

    Concrete strategies override ``build_prompt``; the implementation here
    only raises ``NotImplementedError``.
    """

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Produce the system prompt for ``request``.

        Args:
            request: The incoming chat request.
            context: Optional extra data for the strategy; unused here.

        Raises:
            NotImplementedError: Always, because the base class provides no
                prompt of its own.
        """
        raise NotImplementedError


class BasePromptStrategy(SystemPromptStrategy):
    """Simplest strategy: reuse the system message carried by the request."""

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Return the content of the first system message, or an empty string.

        Args:
            request: Chat request whose ``messages`` are scanned in order.
            context: Accepted for interface compatibility; not used here.

        Returns:
            The ``content`` of the first message whose ``role`` is
            ``'system'`` (``''`` when that message has no ``content`` key),
            or ``""`` when there are no messages or no system message.
        """
        # Lazily walk the messages so scanning stops at the first system entry.
        system_contents = (
            entry.get('content', '')
            for entry in (request.messages or ())
            if entry.get('role') == 'system'
        )
        return next(system_contents, "")

class ContextEnhancedStrategy(SystemPromptStrategy):
    """Context-enhanced system prompt building strategy"""

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build context-enhanced system prompt.

        Args:
            request: The incoming chat request; not inspected by this strategy.
            context: Accepted so the signature matches the base class and
                callers that invoke ``build_prompt(request, context)`` do not
                fail with ``TypeError``; not used here.

        Returns:
            The module-level ``CONTEXT_PROMPT`` template, unchanged.
        """
        return CONTEXT_PROMPT

class ContextCriticStrategy(SystemPromptStrategy):
    """Context-critic system prompt building strategy"""

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build context-critic system prompt.

        Args:
            request: The incoming chat request; not inspected by this strategy.
            context: Accepted so the signature matches the base class and
                callers that invoke ``build_prompt(request, context)`` do not
                fail with ``TypeError``; not used here.

        Returns:
            The module-level ``JUDGE_PROMPT`` template, unchanged.
        """
        return JUDGE_PROMPT

class RoleBasedStrategy(SystemPromptStrategy):
    """Strategy that prefers a role's own system prompt over a wrapped strategy."""

    def __init__(self, base_strategy: SystemPromptStrategy):
        # Strategy consulted when the request does not resolve to a role.
        self.base_strategy = base_strategy

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Return the role's system prompt, or delegate to the base strategy.

        Args:
            request: Chat request; ``metadata['role_id']`` selects the role
                when the request carries metadata.
            context: Passed through to the base strategy on fallback.

        Returns:
            ``role.system_prompt`` when ``role_service`` finds the role,
            otherwise whatever the base strategy builds.
        """
        metadata = getattr(request, 'metadata', None)
        role_id = metadata.get('role_id') if metadata else None

        # Only hit the role service when the request actually names a role.
        role = role_service.get_role_by_uuid(role_id) if role_id else None
        if role:
            prompt = role.system_prompt
            logger.info(f"RoleBasedStrategy (from role): {prompt}")
            return prompt

        return self.base_strategy.build_prompt(request, context)


class KnowledgeEnhancedStrategy(SystemPromptStrategy):
    """Knowledge-enhanced system prompt building strategy.

    Wraps another strategy and appends retrieved knowledge sections (L0 via
    ``default_retriever``, L1 via ``default_l1_retriever``) to the prompt that
    strategy produces.
    """

    def __init__(self, base_strategy: SystemPromptStrategy):
        # Strategy that supplies the prompt the knowledge is appended to.
        self.base_strategy = base_strategy

    def get_user_message(self, request: ChatRequest) -> str:
        """
        Get the last user message from messages field.

        Returns the ``content`` of the most recent message whose ``role`` is
        ``'user'``, or ``''`` when there is no such message.
        """
        # Walk backwards so the most recent user message wins.
        for message in reversed(request.messages or ()):
            if message.get('role') == 'user':
                return message.get('content', '')

        return ''

    def _collect_knowledge(
        self,
        user_message: str,
        enable_l0_retrieval: bool,
        enable_l1_retrieval: bool,
        l0_label: str,
    ) -> list:
        """Run the enabled retrievers and return the non-empty prompt sections."""
        knowledge_sections = []

        if enable_l0_retrieval:
            l0_knowledge = default_retriever.retrieve(user_message)
            if l0_knowledge:
                knowledge_sections.append(f"{l0_label}:\n{l0_knowledge}")

        if enable_l1_retrieval:
            l1_knowledge = default_l1_retriever.retrieve(user_message)
            if l1_knowledge:
                knowledge_sections.append(f"Reference shades:\n{l1_knowledge}")

        return knowledge_sections

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build knowledge-enhanced system prompt.

        Retrieval flags come from the role when ``metadata['role_id']``
        resolves to a role; otherwise (no role id, or a role id that matches
        no role) they come from ``metadata['enable_l0_retrieval']`` and
        ``metadata['enable_l1_retrieval']``.

        Args:
            request: Chat request providing the messages and optional metadata.
            context: Passed through to the base strategy.

        Returns:
            The base prompt followed by the retrieved knowledge sections,
            separated by blank lines; just the sections when the base prompt
            is empty; or the base prompt unchanged when nothing was retrieved.
        """
        base_prompt = self.base_strategy.build_prompt(request, context)

        logger.info(f"KnowledgeEnhancedStrategy request: {request}")
        logger.info(f"KnowledgeEnhancedStrategy (from base): {base_prompt}")

        user_message = self.get_user_message(request)

        # Request-level configuration is the default source of the flags.
        metadata = getattr(request, 'metadata', None) or {}
        enable_l0_retrieval = metadata.get('enable_l0_retrieval', False)
        enable_l1_retrieval = metadata.get('enable_l1_retrieval', False)
        l0_label = "Reference knowledge"

        # If the role exists, its config has priority. An unknown role_id
        # keeps the request-level flags instead of silently disabling
        # retrieval altogether.
        role_id = metadata.get('role_id')
        role = role_service.get_role_by_uuid(role_id) if role_id else None
        if role:
            enable_l0_retrieval = role.enable_l0_retrieval
            enable_l1_retrieval = role.enable_l1_retrieval
            l0_label = "Role knowledge"

        knowledge_sections = self._collect_knowledge(
            user_message, enable_l0_retrieval, enable_l1_retrieval, l0_label
        )
        if not knowledge_sections:
            return base_prompt

        knowledge_text = "\n\n".join(knowledge_sections)
        # Truthiness (not len()) so a None base prompt cannot raise TypeError.
        prompt = f"{base_prompt}\n\n{knowledge_text}" if base_prompt else knowledge_text
        logger.info(f"KnowledgeEnhancedStrategy (with knowledge): {prompt}")
        return prompt


class SystemPromptBuilder:
    """Builds system prompts by delegating to a pluggable strategy."""

    def __init__(self):
        # No strategy is configured until set_strategy() is called.
        self.strategy: Optional[SystemPromptStrategy] = None

    def set_strategy(self, strategy: SystemPromptStrategy):
        """Select the strategy used by subsequent ``build_prompt`` calls."""
        self.strategy = strategy

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build the system prompt with the configured strategy.

        Args:
            request: Chat request forwarded to the strategy.
            context: Optional extra data forwarded to the strategy.

        Returns:
            Whatever the strategy's ``build_prompt`` returns.

        Raises:
            ValueError: If no strategy has been set.
        """
        active_strategy = self.strategy
        if not active_strategy:
            raise ValueError("No strategy set for SystemPromptBuilder")
        return active_strategy.build_prompt(request, context)