# core/prompt_manager.py
"""
Dynamic prompt handling and template management.
This module provides a comprehensive prompt management system that handles
custom prompt injection, template management, variable substitution,
and output format enforcement for the topic segmentation API.
"""
import re
import json
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
from datetime import datetime
from models.input import (
PromptConfiguration, PromptTemplate, LanguageCode, TranscriptSentence
)
from templates.prompts import PromptTemplateManager
from config.logging import get_logger
logger = get_logger(__name__)
class PromptValidationError(Exception):
    """Raised when prompt validation fails.

    Emitted by PromptManager when a prompt configuration is invalid
    (e.g. a CUSTOM template without custom text) or processing fails.
    """
class TemplateVariableError(Exception):
    """Raised when template variable substitution fails.

    Emitted by PromptManager._substitute_variables when a placeholder
    replacement raises unexpectedly.
    """
@dataclass
class PromptValidationResult:
    """Result of prompt validation."""
    # True when no errors were found and safety_score >= 0.5.
    is_valid: bool
    # Non-fatal findings (length hints, format-compliance hints, missing fields).
    warnings: List[str]
    # Fatal findings; any entry makes the prompt invalid.
    errors: List[str]
    # Heuristic score in [0.0, 1.0]; 1.0 means no risky keywords detected.
    safety_score: float
    # Rough token estimate (~1.3 tokens per whitespace-separated word).
    estimated_tokens: int
@dataclass
class ProcessedPrompt:
    """A processed prompt ready for API consumption."""
    # Fully substituted user prompt, including output-format instructions.
    user_prompt: str
    # Fully substituted system prompt.
    system_prompt: str
    # Template the prompt was built from (CUSTOM for user-supplied text).
    template_used: PromptTemplate
    # Variable name -> substituted value mapping that was applied.
    variables_substituted: Dict[str, str]
    # Outcome of safety/format validation for the user prompt.
    validation_result: PromptValidationResult
    # Timing, token-estimate, and template bookkeeping for this run.
    processing_metadata: Dict[str, Any]
class PromptSafetyLevel(str, Enum):
    """Safety levels for prompt validation."""
    # NOTE(review): not referenced elsewhere in this module; presumably
    # consumed by API callers -- verify before removing.
    SAFE = "safe"          # no risky content detected
    WARNING = "warning"    # potentially risky content
    UNSAFE = "unsafe"      # high-risk content
class PromptManager:
    """
    Comprehensive prompt management system.

    Handles custom prompt injection, template management, variable substitution,
    validation, and output format enforcement for topic segmentation.
    """

    # Keyword lists consumed by _calculate_safety_score.  Each occurrence
    # subtracts a weighted penalty (high: 0.3, medium: 0.1, format: 0.2).
    SAFETY_KEYWORDS = {
        "high_risk": [
            "ignore", "forget", "disregard", "override", "bypass", "jailbreak",
            "pretend", "roleplay", "act as", "simulate", "emulate"
        ],
        "medium_risk": [
            "system", "admin", "root", "execute", "run", "eval", "exec",
            "script", "code", "function", "method", "class"
        ],
        "format_breaking": [
            "don't use json", "ignore format", "plain text", "no structure",
            "free form", "unstructured", "raw output"
        ]
    }

    # Output fields every segmentation prompt should instruct the model to emit.
    REQUIRED_FORMAT_ELEMENTS = [
        "topic_name", "topic_type", "topic_detail", "start_sentence_index",
        "end_sentence_index", "primary_speaker", "confidence_score"
    ]

    # Template variables that can be substituted, with the descriptions
    # surfaced to clients via get_template_variables().
    TEMPLATE_VARIABLES = {
        "language": "Target language for processing",
        "business_domain": "Business domain or industry context",
        "speaker_count": "Number of unique speakers",
        "sentence_count": "Total number of sentences",
        "duration": "Total duration in minutes",
        "additional_instructions": "Additional custom instructions",
        "timestamp": "Current timestamp",
        "categories": "Available business categories"
    }

    def __init__(self):
        """Initialize the prompt manager."""
        self.template_manager = PromptTemplateManager()
        self.logger = get_logger(f"{__name__}.{self.__class__.__name__}")
        # Cache for processed templates; only its size is reported via
        # get_processing_stats() and it is emptied by clear_cache().
        self._template_cache: Dict[str, str] = {}
        # Running counters; see get_processing_stats() / reset_stats().
        self.stats = {
            "prompts_processed": 0,
            "templates_used": {},
            "validation_failures": 0,
            "safety_warnings": 0
        }

    def process_prompt_configuration(
        self,
        config: PromptConfiguration,
        context: Dict[str, Any],
        sentences: Optional[List[TranscriptSentence]] = None
    ) -> ProcessedPrompt:
        """
        Process a prompt configuration into a ready-to-use prompt.

        Args:
            config: Prompt configuration from the request
            context: Additional context for variable substitution
            sentences: Optional transcript sentences for context

        Returns:
            ProcessedPrompt with user and system prompts

        Raises:
            PromptValidationError: If the configuration is invalid or any
                processing step fails.
        """
        # Naive local time is fine here: only used for elapsed-time deltas
        # and an informational timestamp.
        start_time = datetime.now()
        try:
            # Step 1: Get base template (custom text or a managed template).
            if config.template == PromptTemplate.CUSTOM:
                if not config.custom_prompt:
                    raise PromptValidationError("Custom prompt is required when template is CUSTOM")
                base_prompt = config.custom_prompt
                system_prompt = self._get_default_system_prompt(config, context)
            else:
                base_prompt = self.template_manager.get_template(
                    config.template,
                    config.language,
                    config.business_domain,
                    config.additional_instructions
                )
                system_prompt = self.template_manager.get_system_prompt(
                    config.template,
                    config.language,
                    config.business_domain
                )

            # Step 2: Prepare template variables.
            template_vars = self._prepare_template_variables(config, context, sentences)

            # Step 3: Substitute variables in both prompts.
            processed_user_prompt = self._substitute_variables(base_prompt, template_vars)
            processed_system_prompt = self._substitute_variables(system_prompt, template_vars)

            # Step 4: Ensure JSON output-format instructions are present.
            processed_user_prompt = self._add_output_format_instructions(
                processed_user_prompt, config
            )

            # Step 5: Validate the processed prompt.
            validation_result = self._validate_prompt(processed_user_prompt, config)

            # Step 6: Create processing metadata.
            processing_time = (datetime.now() - start_time).total_seconds()
            metadata = {
                "processing_time": processing_time,
                "template_used": config.template.value,
                "variables_count": len(template_vars),
                "estimated_tokens": validation_result.estimated_tokens,
                "safety_score": validation_result.safety_score,
                "timestamp": datetime.now().isoformat()
            }

            # Update statistics.
            self.stats["prompts_processed"] += 1
            self.stats["templates_used"][config.template.value] = (
                self.stats["templates_used"].get(config.template.value, 0) + 1
            )
            if not validation_result.is_valid:
                self.stats["validation_failures"] += 1
            if validation_result.safety_score < 0.8:
                self.stats["safety_warnings"] += 1

            result = ProcessedPrompt(
                user_prompt=processed_user_prompt,
                system_prompt=processed_system_prompt,
                template_used=config.template,
                variables_substituted=template_vars,
                validation_result=validation_result,
                processing_metadata=metadata
            )
            self.logger.info(
                f"Prompt processed successfully: template={config.template.value}, "
                f"safety_score={validation_result.safety_score:.2f}, "
                f"tokens={validation_result.estimated_tokens}"
            )
            return result
        except PromptValidationError:
            # Already a domain error with a precise message -- re-raise as-is
            # instead of re-wrapping it into a generic "Failed to process"
            # message (the original behavior obscured the real cause).
            raise
        except Exception as e:
            self.logger.error(f"Error processing prompt configuration: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise PromptValidationError(f"Failed to process prompt: {str(e)}") from e

    def _prepare_template_variables(
        self,
        config: PromptConfiguration,
        context: Dict[str, Any],
        sentences: Optional[List[TranscriptSentence]] = None
    ) -> Dict[str, str]:
        """Prepare the name -> value mapping used for template substitution.

        All values are stringified; non-scalar context entries are skipped.
        """
        # Basic configuration variables.
        variables: Dict[str, str] = {
            "language": config.language.value,
            "business_domain": config.business_domain or "General",
            "additional_instructions": config.additional_instructions or "",
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

        # Fold in scalar context values only (str/int/float/bool).
        variables.update({
            str(k): str(v) for k, v in context.items()
            if isinstance(v, (str, int, float, bool))
        })

        # Sentence-derived variables.  (The original code had a second,
        # redundant `if sentences:` check here whose else-branch was
        # unreachable; it has been removed.)
        if sentences:
            variables["sentence_count"] = str(len(sentences))
            variables["speaker_count"] = str(len({s.speaker for s in sentences}))
            # Span from first sentence start to last sentence end,
            # converted from seconds to minutes.
            duration_minutes = (sentences[-1].end_time - sentences[0].start_time) / 60
            variables["duration"] = f"{duration_minutes:.1f}"
        else:
            variables["sentence_count"] = "0"
            variables["speaker_count"] = "0"
            variables["duration"] = "0.0"

        # Business categories.  Imported lazily, presumably to avoid a
        # circular import with models.output -- TODO confirm.
        from models.output import TopicCategory
        categories = [cat.value.replace('_', ' ').title() for cat in TopicCategory]
        variables["categories"] = ", ".join(categories)
        return variables

    def _substitute_variables(self, template: str, variables: Dict[str, str]) -> str:
        """Substitute template variables in the prompt.

        Supports both {{variable}} and {variable} placeholder styles; the
        double-brace form is replaced first so it cannot be misread as a
        single-brace placeholder.

        Raises:
            TemplateVariableError: If any replacement raises unexpectedly.
        """
        try:
            result = template
            # Replace {{variable}} placeholders first.
            for var_name, var_value in variables.items():
                double_brace_pattern = f"{{{{{var_name}}}}}"
                if double_brace_pattern in result:
                    result = result.replace(double_brace_pattern, str(var_value))
            # Then replace {variable} placeholders.
            for var_name, var_value in variables.items():
                single_brace_pattern = f"{{{var_name}}}"
                if single_brace_pattern in result:
                    result = result.replace(single_brace_pattern, str(var_value))
            # Warn about leftover {...} spans.  NOTE(review): this regex also
            # matches literal JSON braces in format examples, so the warning
            # is advisory only.
            unsubstituted = re.findall(r'\{([^}]+)\}', result)
            if unsubstituted:
                self.logger.warning(f"Unsubstituted variables found: {unsubstituted}")
            return result
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise TemplateVariableError(f"Variable substitution failed: {str(e)}") from e

    def _add_output_format_instructions(
        self,
        prompt: str,
        config: PromptConfiguration
    ) -> str:
        """Append JSON output-format instructions unless the prompt already
        mentions both "JSON" and "FORMAT" (case-insensitive)."""
        format_instructions = self.template_manager.get_output_format_instructions(
            config.language
        )
        # Append unless BOTH markers are already present.
        if "JSON" not in prompt.upper() or "FORMAT" not in prompt.upper():
            prompt += f"\n\n{format_instructions}"
        return prompt

    def _validate_prompt(
        self,
        prompt: str,
        config: PromptConfiguration
    ) -> PromptValidationResult:
        """Validate a processed prompt for safety and format compliance."""
        warnings = []
        errors = []

        # Length heuristics (characters, not tokens).
        if len(prompt) < 50:
            warnings.append("Prompt is very short and may not provide sufficient context")
        elif len(prompt) > 10000:
            warnings.append("Prompt is very long and may exceed token limits")

        # Safety checks.  (Dead `safety_score = 1.0` pre-assignment removed;
        # the score is computed here before any read.)
        safety_score = self._calculate_safety_score(prompt)
        if safety_score < 0.5:
            errors.append("Prompt contains high-risk content that may compromise output quality")
        elif safety_score < 0.8:
            warnings.append("Prompt contains potentially risky content")

        # Format compliance checks.
        format_compliance = self._check_format_compliance(prompt)
        if not format_compliance:
            warnings.append("Prompt may not enforce proper JSON output format")

        # Required elements check.
        missing_elements = self._check_required_elements(prompt)
        if missing_elements:
            warnings.append(f"Prompt may not request required fields: {missing_elements}")

        # Estimate token count.
        estimated_tokens = self._estimate_token_count(prompt)

        # Valid only when error-free and sufficiently safe.
        is_valid = len(errors) == 0 and safety_score >= 0.5
        return PromptValidationResult(
            is_valid=is_valid,
            warnings=warnings,
            errors=errors,
            safety_score=safety_score,
            estimated_tokens=estimated_tokens
        )

    def _calculate_safety_score(self, prompt: str) -> float:
        """Calculate a heuristic safety score in [0.0, 1.0] for a prompt.

        Each distinct keyword hit subtracts a fixed penalty; the score is
        clamped to the unit interval.
        """
        prompt_lower = prompt.lower()
        score = 1.0
        # High-risk keywords: -0.3 each.
        high_risk_count = sum(
            1 for keyword in self.SAFETY_KEYWORDS["high_risk"]
            if keyword in prompt_lower
        )
        score -= high_risk_count * 0.3
        # Medium-risk keywords: -0.1 each.
        medium_risk_count = sum(
            1 for keyword in self.SAFETY_KEYWORDS["medium_risk"]
            if keyword in prompt_lower
        )
        score -= medium_risk_count * 0.1
        # Format-breaking keywords: -0.2 each.
        format_breaking_count = sum(
            1 for keyword in self.SAFETY_KEYWORDS["format_breaking"]
            if keyword in prompt_lower
        )
        score -= format_breaking_count * 0.2
        # Clamp to [0.0, 1.0].
        return max(0.0, min(1.0, score))

    def _check_format_compliance(self, prompt: str) -> bool:
        """Return True when the prompt mentions at least 3 format indicators."""
        format_indicators = [
            "json", "format", "structure", "array", "object",
            "topic_name", "topic_type", "confidence_score"
        ]
        prompt_lower = prompt.lower()
        found_indicators = sum(1 for indicator in format_indicators if indicator in prompt_lower)
        return found_indicators >= 3

    def _check_required_elements(self, prompt: str) -> List[str]:
        """Return required output field names not mentioned in the prompt."""
        prompt_lower = prompt.lower()
        missing_elements = []
        for element in self.REQUIRED_FORMAT_ELEMENTS:
            if element not in prompt_lower:
                missing_elements.append(element)
        return missing_elements

    def _estimate_token_count(self, prompt: str) -> int:
        """Estimate token count for a prompt.

        Simple heuristic: ~1.3 tokens per whitespace-separated word
        (roughly calibrated for English text).
        """
        word_count = len(prompt.split())
        return int(word_count * 1.3)

    def _get_default_system_prompt(
        self,
        config: PromptConfiguration,
        context: Dict[str, Any]
    ) -> str:
        """Get the default system prompt used with CUSTOM user prompts.

        Note: `context` is currently unused but kept for interface symmetry
        with the template-based path.
        """
        return f"""You are an expert business analyst specializing in topic extraction from business conversations.
Your task is to analyze transcript content and extract meaningful business topics with actionable insights.
Language: {config.language.value}
Business Domain: {config.business_domain or "General"}
Always respond with valid JSON format containing the required fields for each topic."""

    def validate_custom_prompt(self, custom_prompt: str) -> PromptValidationResult:
        """Validate a custom prompt before processing.

        NOTE(review): assumes PromptConfiguration() is constructible with
        defaults for every field -- confirm against models.input.
        """
        return self._validate_prompt(custom_prompt, PromptConfiguration())

    def get_available_templates(self) -> List[Dict[str, str]]:
        """Get the list of available (non-CUSTOM) prompt templates."""
        templates = []
        for template in PromptTemplate:
            if template != PromptTemplate.CUSTOM:
                templates.append({
                    "name": template.value,
                    "description": self.template_manager.get_template_description(template),
                    "supported_languages": [lang.value for lang in LanguageCode if lang != LanguageCode.AUTO_DETECT]
                })
        return templates

    def get_template_variables(self) -> Dict[str, str]:
        """Get available template variables and their descriptions."""
        # Return a copy so callers cannot mutate the class-level mapping.
        return self.TEMPLATE_VARIABLES.copy()

    def preview_template(
        self,
        template: PromptTemplate,
        language: LanguageCode = LanguageCode.ENGLISH,
        business_domain: Optional[str] = None,
        variables: Optional[Dict[str, str]] = None
    ) -> str:
        """Preview a template with optional variable substitution.

        Raises:
            PromptValidationError: If the template cannot be rendered.
        """
        try:
            base_template = self.template_manager.get_template(
                template, language, business_domain
            )
            if variables:
                base_template = self._substitute_variables(base_template, variables)
            return base_template
        except Exception as e:
            self.logger.error(f"Error previewing template: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise PromptValidationError(f"Failed to preview template: {str(e)}") from e

    def get_processing_stats(self) -> Dict[str, Any]:
        """Get prompt processing statistics (rates guard against div-by-zero)."""
        return {
            "total_prompts_processed": self.stats["prompts_processed"],
            "templates_usage": self.stats["templates_used"],
            "validation_failure_rate": (
                self.stats["validation_failures"] / max(1, self.stats["prompts_processed"])
            ),
            "safety_warning_rate": (
                self.stats["safety_warnings"] / max(1, self.stats["prompts_processed"])
            ),
            "cache_size": len(self._template_cache)
        }

    def clear_cache(self) -> None:
        """Clear the template cache."""
        self._template_cache.clear()
        self.logger.info("Template cache cleared")

    def reset_stats(self) -> None:
        """Reset processing statistics to their initial zeroed state."""
        self.stats = {
            "prompts_processed": 0,
            "templates_used": {},
            "validation_failures": 0,
            "safety_warnings": 0
        }
        self.logger.info("Processing statistics reset")
# Lazily-created module-wide PromptManager singleton.
_prompt_manager: Optional[PromptManager] = None
def get_prompt_manager() -> PromptManager:
    """Return the process-wide PromptManager, creating it on first use."""
    global _prompt_manager
    if _prompt_manager is not None:
        return _prompt_manager
    _prompt_manager = PromptManager()
    return _prompt_manager
def reset_prompt_manager() -> None:
    """Discard the cached global instance so the next get creates a fresh one."""
    global _prompt_manager
    _prompt_manager = None