Spaces:
Sleeping
Sleeping
File size: 6,691 Bytes
01d5a5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
"""
System prompt builder and related strategies
"""
from typing import Optional, Any
import logging
from lpm_kernel.api.domains.kernel2.dto.chat_dto import ChatRequest
from lpm_kernel.api.domains.kernel2.services.role_service import role_service
from lpm_kernel.api.domains.kernel2.services.knowledge_service import (
default_retriever,
default_l1_retriever,
)
from lpm_kernel.L2.training_prompt import CONTEXT_PROMPT, MEMORY_PROMPT, JUDGE_PROMPT
logger = logging.getLogger(__name__)
class SystemPromptStrategy:
    """Interface for objects that turn a ChatRequest into a system prompt.

    Concrete strategies override ``build_prompt``; calling it on this base
    class raises ``NotImplementedError``.
    """

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Return the system prompt for ``request`` (optionally using ``context``).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class BasePromptStrategy(SystemPromptStrategy):
    """Strategy that reuses the request's own system message, if it has one."""

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Return the content of the first ``system`` message in ``request.messages``.

        Messages are read as dicts via ``.get('role')`` / ``.get('content', '')``.
        Returns "" when ``request.messages`` is empty/falsy or contains no
        message whose role is ``system``. ``context`` is not used.
        """
        system_contents = (
            msg.get('content', '')
            for msg in (request.messages or ())
            if msg.get('role') == 'system'
        )
        # next() with a default yields the first match, or "" when there is none.
        return next(system_contents, "")
class ContextEnhancedStrategy(SystemPromptStrategy):
    """Context-enhanced system prompt building strategy.

    Always produces ``CONTEXT_PROMPT``; the request is not inspected.
    """

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build context-enhanced system prompt.

        Args:
            request: The chat request (unused).
            context: Accepted and ignored. It is required so that this strategy
                matches ``SystemPromptStrategy.build_prompt``: wrappers such as
                ``RoleBasedStrategy`` and ``SystemPromptBuilder`` call
                ``build_prompt(request, context)`` and would otherwise raise
                ``TypeError``.

        Returns:
            ``CONTEXT_PROMPT``.
        """
        return CONTEXT_PROMPT
class ContextCriticStrategy(SystemPromptStrategy):
    """Context-critic system prompt building strategy.

    Always produces ``JUDGE_PROMPT``; the request is not inspected.
    """

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build context-critic system prompt.

        Args:
            request: The chat request (unused).
            context: Accepted and ignored. It is required so that this strategy
                matches ``SystemPromptStrategy.build_prompt``: wrappers such as
                ``RoleBasedStrategy`` and ``SystemPromptBuilder`` call
                ``build_prompt(request, context)`` and would otherwise raise
                ``TypeError``.

        Returns:
            ``JUDGE_PROMPT``.
        """
        return JUDGE_PROMPT
class RoleBasedStrategy(SystemPromptStrategy):
    """Prefer a stored role's system prompt; otherwise defer to a wrapped strategy."""

    def __init__(self, base_strategy: SystemPromptStrategy):
        # Consulted when the request names no role, or the role lookup finds nothing.
        self.base_strategy = base_strategy

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Return the role's system prompt when ``metadata['role_id']`` resolves.

        The role is looked up with ``role_service.get_role_by_uuid``. If the
        request has no truthy ``metadata``, no ``role_id``, or the lookup
        returns a falsy value, the wrapped strategy builds the prompt instead.
        """
        metadata = getattr(request, 'metadata', None)
        role_id = metadata.get('role_id') if metadata else None
        role = role_service.get_role_by_uuid(role_id) if role_id else None

        if not role:
            return self.base_strategy.build_prompt(request, context)

        prompt = role.system_prompt
        logger.info(f"RoleBasedStrategy (from role): {prompt}")
        return prompt
class KnowledgeEnhancedStrategy(SystemPromptStrategy):
    """Knowledge-enhanced system prompt building strategy.

    Wraps another strategy and appends knowledge retrieved for the latest user
    message: L0 via ``default_retriever`` and L1 via ``default_l1_retriever``.
    """

    def __init__(self, base_strategy: SystemPromptStrategy):
        # Strategy that produces the prompt the knowledge sections are appended to.
        self.base_strategy = base_strategy

    def get_user_message(self, request: ChatRequest) -> str:
        """
        Get the last user message from messages field.

        Returns '' when there are no messages or none has role 'user'.
        """
        if request.messages:
            # Find the last message with role='user'
            for message in reversed(request.messages):
                if message.get('role') == 'user':
                    return message.get('content', '')
        return ''

    def _collect_knowledge_sections(self, request: ChatRequest, user_message: str) -> list:
        """Run the enabled retrievers for ``user_message`` and return labelled sections.

        Retrieval flags come from ``request.metadata`` (``enable_l0_retrieval``,
        ``enable_l1_retrieval``). If ``metadata['role_id']`` resolves to a role,
        that role's own flags take priority. If the role_id does not resolve,
        the metadata flags are used.
        """
        metadata = getattr(request, 'metadata', None) or {}
        enable_l0 = metadata.get('enable_l0_retrieval', False)
        enable_l1 = metadata.get('enable_l1_retrieval', False)
        l0_label = "Reference knowledge"

        role_id = metadata.get('role_id')
        role = role_service.get_role_by_uuid(role_id) if role_id else None
        if role:
            # If role exists, role config has priority over request metadata.
            enable_l0 = role.enable_l0_retrieval
            enable_l1 = role.enable_l1_retrieval
            l0_label = "Role knowledge"

        sections = []
        if enable_l0:
            l0_knowledge = default_retriever.retrieve(user_message)
            if l0_knowledge:
                sections.append(f"{l0_label}:\n{l0_knowledge}")
        if enable_l1:
            l1_knowledge = default_l1_retriever.retrieve(user_message)
            if l1_knowledge:
                sections.append(f"Reference shades:\n{l1_knowledge}")
        return sections

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build knowledge-enhanced system prompt.

        Args:
            request: The chat request; its last ``user`` message is the
                retrieval query and its ``metadata`` controls retrieval.
            context: Passed through to the wrapped strategy.

        Returns:
            The base prompt, followed by any non-empty knowledge sections
            separated by blank lines. If no knowledge is found, the base
            prompt is returned unchanged.
        """
        base_prompt = self.base_strategy.build_prompt(request, context)
        logger.info("KnowledgeEnhancedStrategy request: %s", request)
        logger.info("KnowledgeEnhancedStrategy (from base): %s", base_prompt)

        user_message = self.get_user_message(request)
        knowledge_sections = self._collect_knowledge_sections(request, user_message)
        if not knowledge_sections:
            return base_prompt

        knowledge = "\n\n".join(knowledge_sections)
        # `if base_prompt` treats both "" and None as "no base prompt" (len(None) would raise).
        prompt = f"{base_prompt}\n\n{knowledge}" if base_prompt else knowledge
        logger.info("KnowledgeEnhancedStrategy (with knowledge): %s", prompt)
        return prompt
class SystemPromptBuilder:
    """Delegates system-prompt construction to a pluggable strategy."""

    def __init__(self):
        # No strategy until set_strategy is called; build_prompt refuses to run without one.
        self.strategy: Optional[SystemPromptStrategy] = None

    def set_strategy(self, strategy: SystemPromptStrategy):
        """Install the strategy used by subsequent ``build_prompt`` calls."""
        self.strategy = strategy

    def build_prompt(self, request: ChatRequest, context: Optional[Any] = None) -> str:
        """Build the system prompt with the configured strategy.

        Raises:
            ValueError: if no strategy has been set.
        """
        active = self.strategy
        if active:
            return active.build_prompt(request, context)
        raise ValueError("No strategy set for SystemPromptBuilder")
|