# Preformu / layers / prompt_orchestrator.py
# Uploaded by Kevinshh — "Upload full project" (commit aecf8ce, verified)
"""
Prompt Orchestration Layer.
This is the CORE module of the application. It coordinates:
1. Selection of appropriate prompts based on input
2. Construction of analysis prompts for each dimension
3. Sequencing of multi-prompt analyses
4. Aggregation of results for output normalization
Design Philosophy:
- Prompts are treated as intellectual property
- Each analysis dimension is independent but can be combined
- The orchestrator manages the overall analysis flow
"""
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

from schemas.canonical_schema import (
    AnalysisRequest,
    AnalysisResult,
    BilingualText,
    RiskLevel,
)
from prompts.api_structure import APIStructurePrompts
from prompts.excipient_analysis import ExcipientAnalysisPrompts
from prompts.compatibility import CompatibilityPrompts
from prompts.stability import StabilityPrompts
from prompts.synthesis import SynthesisPrompts
@dataclass
class PromptPackage:
    """
    A package of prompts ready for model invocation.

    Each package represents one analysis dimension and contains:
    - The dimension identifier
    - System prompt for context
    - Main analysis prompt
    - Optional follow-up prompts

    ``follow_up_prompts`` defaults to a fresh empty list per instance.
    Passing ``None`` explicitly is still accepted and normalized to an
    empty list, for backward compatibility.
    """

    # Identifier of the analysis dimension (e.g. "api_structure", "synthesis").
    dimension: str
    # System/context prompt sent alongside the main prompt.
    system_prompt: str
    # Primary analysis prompt for this dimension.
    main_prompt: str
    # Extra prompts for this dimension; always a list after construction.
    follow_up_prompts: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Older callers may pass None explicitly; keep the field a list.
        if self.follow_up_prompts is None:
            self.follow_up_prompts = []
class PromptOrchestrator:
    """
    Orchestrates the construction and sequencing of analysis prompts.

    This class is responsible for:
    1. Analyzing the input request to determine required analyses
    2. Constructing appropriate prompts for each dimension
    3. Managing the analysis sequence
    4. Preparing final synthesis prompts
    """

    # Number of leading SMILES characters used to label an unnamed API.
    _SMILES_EXCERPT_LEN = 20

    def __init__(self):
        """Initialize the orchestrator with one prompt-template object per dimension."""
        self.api_prompts = APIStructurePrompts()
        self.excipient_prompts = ExcipientAnalysisPrompts()
        self.compatibility_prompts = CompatibilityPrompts()
        self.stability_prompts = StabilityPrompts()
        self.synthesis_prompts = SynthesisPrompts()

    @classmethod
    def _display_api_name(cls, api, fallback: str) -> str:
        """
        Return a human-readable label for an API.

        Preference order: the explicit ``api.name``; otherwise an excerpt of
        ``api.smiles`` ("..." is appended only when the SMILES string was
        actually truncated); otherwise ``fallback``.

        The checks are spelled out on purpose. The one-line form
        ``name or f"..." if smiles else fallback`` parses as
        ``(name or f"...") if smiles else fallback``. That form would
        discard an explicit name whenever SMILES is missing.
        """
        if api.name:
            return api.name
        if api.smiles:
            limit = cls._SMILES_EXCERPT_LEN
            suffix = "..." if len(api.smiles) > limit else ""
            return f"Compound ({api.smiles[:limit]}{suffix})"
        return fallback

    def orchestrate(self, request: AnalysisRequest) -> List[PromptPackage]:
        """
        Orchestrate the full analysis flow for a given request.

        Packages are produced in this order:
        1. API structure (only when SMILES is present).
        2. One excipient package per excipient.
        3. One compatibility package per excipient (only when SMILES is
           present).
        4. Stability (only when stability data is present).

        All excipient packages share the dimension "excipient_analysis", and
        all compatibility packages share "compatibility".

        Args:
            request: The normalized analysis request

        Returns:
            List of PromptPackages ready for model invocation
        """
        packages = []
        # Treat a missing (None) excipient list the same as an empty one.
        excipients = request.excipients or []
        has_smiles = bool(request.api.smiles)

        # Step 1: API structure analysis (requires SMILES)
        if has_smiles:
            packages.append(self._build_api_structure_package(request))

        # Step 2: Excipient analysis, one package per excipient
        for excipient in excipients:
            packages.append(self._build_excipient_package(excipient))

        # Step 3: Compatibility analysis, API paired with each excipient
        if has_smiles:
            for excipient in excipients:
                packages.append(
                    self._build_compatibility_package(request.api, excipient)
                )

        # Step 4: Stability data interpretation (only if data provided)
        if request.stability_data:
            packages.append(self._build_stability_package(request))

        return packages

    def build_synthesis_prompt(
        self,
        request: AnalysisRequest,
        dimension_results: Dict[str, str]
    ) -> PromptPackage:
        """
        Build the final synthesis prompt after all dimensions are analyzed.

        Dimensions missing from ``dimension_results`` are replaced with
        placeholder text. A missing stability result becomes an empty string.

        Args:
            request: The original analysis request
            dimension_results: Results keyed by dimension identifier
                ("api_structure", "excipient_analysis", "compatibility",
                "stability")

        Returns:
            PromptPackage for final synthesis
        """
        # Placeholders read "API structure analysis / excipient analysis /
        # compatibility analysis not performed".
        api_analysis = dimension_results.get("api_structure", "未进行API结构分析")
        excipient_analysis = dimension_results.get("excipient_analysis", "未进行辅料分析")
        compatibility_analysis = dimension_results.get("compatibility", "未进行相容性分析")
        stability_analysis = dimension_results.get("stability", "")

        # Prefer the explicit name, then a SMILES excerpt, then a generic label.
        api_name = self._display_api_name(request.api, "Unknown API")

        # Only the first excipient is named; the fallback reads
        # "no excipient specified".
        excipient_name = request.excipients[0].name if request.excipients else "未指定辅料"

        return PromptPackage(
            dimension="synthesis",
            system_prompt=self.synthesis_prompts.get_system_prompt(),
            main_prompt=self.synthesis_prompts.get_synthesis_prompt(
                api_name=api_name,
                excipient_name=excipient_name,
                api_analysis=api_analysis,
                excipient_analysis=excipient_analysis,
                compatibility_analysis=compatibility_analysis,
                stability_analysis=stability_analysis,
            )
        )

    def _build_api_structure_package(
        self,
        request: AnalysisRequest
    ) -> PromptPackage:
        """
        Build the prompt package for API structure analysis.

        The main prompt is the structure-analysis prompt for the API's SMILES.
        The single follow-up is the physicochemical prompt for the same SMILES.
        """
        return PromptPackage(
            dimension="api_structure",
            system_prompt=self.api_prompts.get_system_prompt(),
            main_prompt=self.api_prompts.get_structure_analysis_prompt(
                smiles=request.api.smiles,
                additional_info=request.api.additional_info or "",
            ),
            follow_up_prompts=[
                self.api_prompts.get_physicochemical_prompt(request.api.smiles),
            ]
        )

    def _build_excipient_package(
        self,
        excipient
    ) -> PromptPackage:
        """
        Build the prompt package for one excipient's analysis.

        Args:
            excipient: An excipient object exposing ``name``, ``grade`` and
                ``additional_info``. ``grade`` and ``additional_info`` may be
                falsy; they then default to "".
        """
        return PromptPackage(
            dimension="excipient_analysis",
            system_prompt=self.excipient_prompts.get_system_prompt(),
            main_prompt=self.excipient_prompts.get_excipient_profile_prompt(
                excipient_name=excipient.name,
                grade=excipient.grade or "",
                additional_info=excipient.additional_info or "",
            ),
            follow_up_prompts=[
                self.excipient_prompts.get_impurity_risk_prompt(excipient.name),
            ]
        )

    def _build_compatibility_package(
        self,
        api,
        excipient
    ) -> PromptPackage:
        """
        Build the prompt package for API–excipient compatibility analysis.

        Args:
            api: The request's API object (``name`` / ``smiles`` are read).
            excipient: The excipient object (``name`` is read).
        """
        api_name = self._display_api_name(api, "Unknown")
        return PromptPackage(
            dimension="compatibility",
            system_prompt=self.compatibility_prompts.get_system_prompt(),
            main_prompt=self.compatibility_prompts.get_comprehensive_compatibility_prompt(
                api_smiles=api.smiles,
                api_name=api_name,
                excipient_name=excipient.name,
                reactive_groups="",  # Will be filled from API analysis
                excipient_properties="",  # Will be filled from excipient analysis
            )
        )

    def _build_stability_package(
        self,
        request: AnalysisRequest
    ) -> PromptPackage:
        """
        Build the prompt package for stability data interpretation.

        Falsy stability fields fall back to placeholders:
        - conditions: "not specified"
        - observations: "no description of observations"
        - duration: an empty string
        """
        api_name = request.api.name or "Study Sample"
        stability = request.stability_data
        return PromptPackage(
            dimension="stability",
            system_prompt=self.stability_prompts.get_system_prompt(),
            main_prompt=self.stability_prompts.get_data_interpretation_prompt(
                api_name=api_name,
                conditions=stability.conditions or "未指定",
                observations=stability.observations or "无观察结果描述",
                duration=stability.duration or "",
            )
        )

    def get_analysis_dimensions(self, request: AnalysisRequest) -> List[str]:
        """
        Determine which analysis dimensions are applicable.

        Uses the same conditions as ``orchestrate``. A trailing "synthesis"
        entry is added whenever at least one other dimension applies.

        Returns:
            Dimension identifiers in execution order.
        """
        dimensions = []
        if request.api.smiles:
            dimensions.append("api_structure")
        if request.excipients:
            dimensions.append("excipient_analysis")
        if request.api.smiles and request.excipients:
            dimensions.append("compatibility")
        if request.stability_data:
            dimensions.append("stability")
        # Always include synthesis if we have something to synthesize
        if dimensions:
            dimensions.append("synthesis")
        return dimensions
class PromptChain:
    """
    Represents a chain of prompts to be executed in sequence.

    Some analyses benefit from chained prompts, where the output of one
    informs the next. This class hands out packages in order and collects
    results keyed by dimension identifier.
    """

    def __init__(self, packages: List[PromptPackage]):
        """Initialize with a list of prompt packages (kept by reference, not copied)."""
        self.packages = packages
        # Index of the next package to hand out.
        self.current_index = 0
        # Collected results keyed by dimension identifier.
        self.results: Dict[str, str] = {}

    def get_next(self) -> Optional[PromptPackage]:
        """Return the next prompt package and advance, or None when exhausted."""
        if self.current_index >= len(self.packages):
            return None
        package = self.packages[self.current_index]
        self.current_index += 1
        return package

    def store_result(self, dimension: str, result: str, *, append: bool = False):
        """
        Store the result for a dimension.

        By default, a later result for the same dimension replaces the
        earlier one. Several packages can share a dimension; for example,
        ``PromptOrchestrator.orchestrate`` emits one "excipient_analysis"
        package per excipient. Pass ``append=True`` to keep every result.
        New text is then joined to any existing non-empty result with a
        blank line.

        Args:
            dimension: Dimension identifier of the package that produced
                ``result``.
            result: The result text to store.
            append: If True, accumulate instead of overwriting.
        """
        previous = self.results.get(dimension)
        if append and previous:
            self.results[dimension] = f"{previous}\n\n{result}"
        else:
            self.results[dimension] = result

    def is_complete(self) -> bool:
        """Return True once every package has been handed out by get_next()."""
        return self.current_index >= len(self.packages)

    def get_all_results(self) -> Dict[str, str]:
        """Return a shallow copy of all collected results."""
        return self.results.copy()