Spaces:
Runtime error
Runtime error
| """ | |
| Dynamic prompt handling and template management. | |
| This module provides a comprehensive prompt management system that handles | |
| custom prompt injection, template management, variable substitution, | |
| and output format enforcement for the topic segmentation API. | |
| """ | |
| import re | |
| import json | |
| from typing import Dict, List, Optional, Any, Tuple | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from datetime import datetime | |
| from models.input import ( | |
| PromptConfiguration, PromptTemplate, LanguageCode, TranscriptSentence | |
| ) | |
| from templates.prompts import PromptTemplateManager | |
| from config.logging import get_logger | |
| logger = get_logger(__name__) | |
class PromptValidationError(Exception):
    """Signals that a prompt failed validation or could not be processed."""
class TemplateVariableError(Exception):
    """Signals that substituting variables into a template failed."""
@dataclass
class PromptValidationResult:
    """Result of prompt validation.

    Fix: the class only declared annotations but is instantiated with keyword
    arguments (e.g. in PromptManager._validate_prompt), which raises TypeError
    without a generated __init__ — the @dataclass decorator supplies it.

    Attributes:
        is_valid: True when no validation errors were recorded.
        warnings: Non-fatal issues found in the prompt.
        errors: Fatal issues; any entry makes the prompt invalid.
        safety_score: Heuristic score in [0.0, 1.0]; lower means riskier.
        estimated_tokens: Rough token-count estimate for the prompt.
    """
    is_valid: bool
    warnings: List[str]
    errors: List[str]
    safety_score: float
    estimated_tokens: int
@dataclass
class ProcessedPrompt:
    """A processed prompt ready for API consumption.

    Fix: the class is instantiated with keyword arguments in
    PromptManager.process_prompt_configuration, but declared only bare
    annotations; @dataclass generates the required __init__.

    Attributes:
        user_prompt: Fully substituted user-facing prompt text.
        system_prompt: Fully substituted system prompt text.
        template_used: The template enum value the prompt was built from.
        variables_substituted: Variable name -> value map applied to the template.
        validation_result: Safety/format validation outcome for the prompt.
        processing_metadata: Timing, token-estimate and bookkeeping details.
    """
    user_prompt: str
    system_prompt: str
    template_used: PromptTemplate
    variables_substituted: Dict[str, str]
    validation_result: PromptValidationResult
    processing_metadata: Dict[str, Any]
class PromptSafetyLevel(str, Enum):
    """Safety classification levels assigned to validated prompts."""

    SAFE = "safe"          # no risky content detected
    WARNING = "warning"    # potentially risky content; usable with caution
    UNSAFE = "unsafe"      # high-risk content; should be rejected
class PromptManager:
    """
    Comprehensive prompt management system.

    Handles custom prompt injection, template management, variable substitution,
    validation, and output format enforcement for topic segmentation.
    """

    # Safety keywords that might indicate prompt injection attempts.
    # Matching is substring-based on the lowercased prompt; each hit lowers
    # the safety score by a per-tier penalty (see _calculate_safety_score).
    SAFETY_KEYWORDS = {
        "high_risk": [
            "ignore", "forget", "disregard", "override", "bypass", "jailbreak",
            "pretend", "roleplay", "act as", "simulate", "emulate"
        ],
        "medium_risk": [
            "system", "admin", "root", "execute", "run", "eval", "exec",
            "script", "code", "function", "method", "class"
        ],
        "format_breaking": [
            "don't use json", "ignore format", "plain text", "no structure",
            "free form", "unstructured", "raw output"
        ]
    }

    # Output fields every prompt is expected to request from the model.
    REQUIRED_FORMAT_ELEMENTS = [
        "topic_name", "topic_type", "topic_detail", "start_sentence_index",
        "end_sentence_index", "primary_speaker", "confidence_score"
    ]

    # Template variables that can be substituted, with human-readable descriptions.
    TEMPLATE_VARIABLES = {
        "language": "Target language for processing",
        "business_domain": "Business domain or industry context",
        "speaker_count": "Number of unique speakers",
        "sentence_count": "Total number of sentences",
        "duration": "Total duration in minutes",
        "additional_instructions": "Additional custom instructions",
        "timestamp": "Current timestamp",
        "categories": "Available business categories"
    }

    def __init__(self):
        """Initialize the prompt manager with a template manager, logger, cache and stats."""
        self.template_manager = PromptTemplateManager()
        self.logger = get_logger(f"{__name__}.{self.__class__.__name__}")
        # Cache for processed templates (populated externally; cleared via clear_cache).
        self._template_cache: Dict[str, str] = {}
        # Running counters reported by get_processing_stats.
        self.stats = {
            "prompts_processed": 0,
            "templates_used": {},
            "validation_failures": 0,
            "safety_warnings": 0
        }

    def process_prompt_configuration(
        self,
        config: PromptConfiguration,
        context: Dict[str, Any],
        sentences: Optional[List[TranscriptSentence]] = None
    ) -> ProcessedPrompt:
        """
        Process a prompt configuration into a ready-to-use prompt.

        Args:
            config: Prompt configuration from the request
            context: Additional context for variable substitution
            sentences: Optional transcript sentences for context

        Returns:
            ProcessedPrompt with user and system prompts

        Raises:
            PromptValidationError: If the configuration is invalid or any
                processing step fails.
        """
        start_time = datetime.now()
        try:
            # Step 1: Get base template (custom text or a managed template).
            if config.template == PromptTemplate.CUSTOM:
                if not config.custom_prompt:
                    raise PromptValidationError("Custom prompt is required when template is CUSTOM")
                base_prompt = config.custom_prompt
                system_prompt = self._get_default_system_prompt(config, context)
            else:
                base_prompt = self.template_manager.get_template(
                    config.template,
                    config.language,
                    config.business_domain,
                    config.additional_instructions
                )
                system_prompt = self.template_manager.get_system_prompt(
                    config.template,
                    config.language,
                    config.business_domain
                )

            # Step 2: Prepare template variables.
            template_vars = self._prepare_template_variables(config, context, sentences)

            # Step 3: Substitute variables into both prompts.
            processed_user_prompt = self._substitute_variables(base_prompt, template_vars)
            processed_system_prompt = self._substitute_variables(system_prompt, template_vars)

            # Step 4: Append output format instructions if missing.
            processed_user_prompt = self._add_output_format_instructions(
                processed_user_prompt, config
            )

            # Step 5: Validate the processed prompt (safety + format checks).
            validation_result = self._validate_prompt(processed_user_prompt, config)

            # Step 6: Create processing metadata.
            processing_time = (datetime.now() - start_time).total_seconds()
            metadata = {
                "processing_time": processing_time,
                "template_used": config.template.value,
                "variables_count": len(template_vars),
                "estimated_tokens": validation_result.estimated_tokens,
                "safety_score": validation_result.safety_score,
                "timestamp": datetime.now().isoformat()
            }

            # Update statistics.
            self.stats["prompts_processed"] += 1
            self.stats["templates_used"][config.template.value] = (
                self.stats["templates_used"].get(config.template.value, 0) + 1
            )
            if not validation_result.is_valid:
                self.stats["validation_failures"] += 1
            if validation_result.safety_score < 0.8:
                self.stats["safety_warnings"] += 1

            result = ProcessedPrompt(
                user_prompt=processed_user_prompt,
                system_prompt=processed_system_prompt,
                template_used=config.template,
                variables_substituted=template_vars,
                validation_result=validation_result,
                processing_metadata=metadata
            )

            self.logger.info(
                f"Prompt processed successfully: template={config.template.value}, "
                f"safety_score={validation_result.safety_score:.2f}, "
                f"tokens={validation_result.estimated_tokens}"
            )
            return result
        except PromptValidationError:
            # Fix: previously our own PromptValidationError was caught by the
            # generic handler below and re-wrapped, doubling the message.
            raise
        except Exception as e:
            self.logger.error(f"Error processing prompt configuration: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise PromptValidationError(f"Failed to process prompt: {str(e)}") from e

    def _prepare_template_variables(
        self,
        config: PromptConfiguration,
        context: Dict[str, Any],
        sentences: Optional[List[TranscriptSentence]] = None
    ) -> Dict[str, str]:
        """Prepare template variables for substitution.

        Returns a str->str map combining configuration values, scalar context
        entries, sentence-derived metrics, and the business category list.
        """
        variables = {}

        # Basic configuration variables.
        variables["language"] = config.language.value
        variables["business_domain"] = config.business_domain or "General"
        variables["additional_instructions"] = config.additional_instructions or ""
        variables["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Context variables — only scalar values are carried over, stringified.
        variables.update({
            str(k): str(v) for k, v in context.items()
            if isinstance(v, (str, int, float, bool))
        })

        # Sentence-based variables. (Fix: removed a redundant nested
        # `if sentences:` check that was always true inside this branch.)
        if sentences:
            variables["sentence_count"] = str(len(sentences))
            variables["speaker_count"] = str(len(set(s.speaker for s in sentences)))
            duration_minutes = (sentences[-1].end_time - sentences[0].start_time) / 60
            variables["duration"] = f"{duration_minutes:.1f}"
        else:
            variables["sentence_count"] = "0"
            variables["speaker_count"] = "0"
            variables["duration"] = "0.0"

        # Business categories (imported locally to avoid a circular import —
        # presumably; confirm against the models package layout).
        from models.output import TopicCategory
        categories = [cat.value.replace('_', ' ').title() for cat in TopicCategory]
        variables["categories"] = ", ".join(categories)

        return variables

    def _substitute_variables(self, template: str, variables: Dict[str, str]) -> str:
        """Substitute template variables in the prompt.

        Supports both `{{name}}` and `{name}` placeholder styles; double-brace
        placeholders are replaced first so they are not half-consumed by the
        single-brace pass.

        Raises:
            TemplateVariableError: If substitution fails unexpectedly.
        """
        try:
            result = template
            # Pass 1: replace {{variable}} placeholders.
            for var_name, var_value in variables.items():
                double_brace_pattern = f"{{{{{var_name}}}}}"
                if double_brace_pattern in result:
                    result = result.replace(double_brace_pattern, str(var_value))
            # Pass 2: replace {variable} placeholders.
            for var_name, var_value in variables.items():
                single_brace_pattern = f"{{{var_name}}}"
                if single_brace_pattern in result:
                    result = result.replace(single_brace_pattern, str(var_value))
            # Warn about anything that still looks like a placeholder.
            # NOTE(review): this also matches literal JSON braces in example
            # output, so warnings here may be false positives.
            unsubstituted = re.findall(r'\{([^}]+)\}', result)
            if unsubstituted:
                self.logger.warning(f"Unsubstituted variables found: {unsubstituted}")
            return result
        except Exception as e:
            raise TemplateVariableError(f"Variable substitution failed: {str(e)}") from e

    def _add_output_format_instructions(
        self,
        prompt: str,
        config: PromptConfiguration
    ) -> str:
        """Append output format instructions unless the prompt already mentions
        both "JSON" and "FORMAT" (case-insensitive)."""
        format_instructions = self.template_manager.get_output_format_instructions(
            config.language
        )
        # Appended when either indicator is missing.
        if "JSON" not in prompt.upper() or "FORMAT" not in prompt.upper():
            prompt += f"\n\n{format_instructions}"
        return prompt

    def _validate_prompt(
        self,
        prompt: str,
        config: PromptConfiguration
    ) -> PromptValidationResult:
        """Validate a processed prompt for safety and format compliance.

        Returns:
            PromptValidationResult; invalid only when errors exist (i.e. the
            safety score drops below 0.5). Length/format issues are warnings.
        """
        warnings = []
        errors = []
        safety_score = 1.0

        # Length sanity checks (character-based heuristics).
        if len(prompt) < 50:
            warnings.append("Prompt is very short and may not provide sufficient context")
        elif len(prompt) > 10000:
            warnings.append("Prompt is very long and may exceed token limits")

        # Safety checks.
        safety_score = self._calculate_safety_score(prompt)
        if safety_score < 0.5:
            errors.append("Prompt contains high-risk content that may compromise output quality")
        elif safety_score < 0.8:
            warnings.append("Prompt contains potentially risky content")

        # Format compliance checks.
        format_compliance = self._check_format_compliance(prompt)
        if not format_compliance:
            warnings.append("Prompt may not enforce proper JSON output format")

        # Required elements check.
        missing_elements = self._check_required_elements(prompt)
        if missing_elements:
            warnings.append(f"Prompt may not request required fields: {missing_elements}")

        # Estimate token count.
        estimated_tokens = self._estimate_token_count(prompt)

        # Valid when no errors were recorded and safety is acceptable.
        is_valid = len(errors) == 0 and safety_score >= 0.5

        return PromptValidationResult(
            is_valid=is_valid,
            warnings=warnings,
            errors=errors,
            safety_score=safety_score,
            estimated_tokens=estimated_tokens
        )

    def _calculate_safety_score(self, prompt: str) -> float:
        """Calculate a heuristic safety score in [0.0, 1.0] for a prompt.

        Each distinct keyword found (substring match, case-insensitive)
        deducts a tier-specific penalty: high-risk 0.3, medium-risk 0.1,
        format-breaking 0.2.
        """
        prompt_lower = prompt.lower()
        score = 1.0

        high_risk_count = sum(
            1 for keyword in self.SAFETY_KEYWORDS["high_risk"]
            if keyword in prompt_lower
        )
        score -= high_risk_count * 0.3

        medium_risk_count = sum(
            1 for keyword in self.SAFETY_KEYWORDS["medium_risk"]
            if keyword in prompt_lower
        )
        score -= medium_risk_count * 0.1

        format_breaking_count = sum(
            1 for keyword in self.SAFETY_KEYWORDS["format_breaking"]
            if keyword in prompt_lower
        )
        score -= format_breaking_count * 0.2

        # Clamp to [0, 1].
        return max(0.0, min(1.0, score))

    def _check_format_compliance(self, prompt: str) -> bool:
        """Return True when the prompt mentions at least 3 JSON-format indicators."""
        format_indicators = [
            "json", "format", "structure", "array", "object",
            "topic_name", "topic_type", "confidence_score"
        ]
        prompt_lower = prompt.lower()
        found_indicators = sum(1 for indicator in format_indicators if indicator in prompt_lower)
        return found_indicators >= 3

    def _check_required_elements(self, prompt: str) -> List[str]:
        """Return required output field names not mentioned in the prompt."""
        prompt_lower = prompt.lower()
        missing_elements = []
        for element in self.REQUIRED_FORMAT_ELEMENTS:
            if element not in prompt_lower:
                missing_elements.append(element)
        return missing_elements

    def _estimate_token_count(self, prompt: str) -> int:
        """Estimate token count for a prompt (~1.3 tokens per whitespace word)."""
        word_count = len(prompt.split())
        return int(word_count * 1.3)

    def _get_default_system_prompt(
        self,
        config: PromptConfiguration,
        context: Dict[str, Any]
    ) -> str:
        """Build the default system prompt used with CUSTOM user prompts."""
        return f"""You are an expert business analyst specializing in topic extraction from business conversations.
Your task is to analyze transcript content and extract meaningful business topics with actionable insights.
Language: {config.language.value}
Business Domain: {config.business_domain or "General"}
Always respond with valid JSON format containing the required fields for each topic."""

    def validate_custom_prompt(self, custom_prompt: str) -> PromptValidationResult:
        """Validate a custom prompt before processing, using a default configuration."""
        return self._validate_prompt(custom_prompt, PromptConfiguration())

    def get_available_templates(self) -> List[Dict[str, str]]:
        """Get list of available prompt templates (CUSTOM excluded)."""
        templates = []
        for template in PromptTemplate:
            if template != PromptTemplate.CUSTOM:
                templates.append({
                    "name": template.value,
                    "description": self.template_manager.get_template_description(template),
                    "supported_languages": [lang.value for lang in LanguageCode if lang != LanguageCode.AUTO_DETECT]
                })
        return templates

    def get_template_variables(self) -> Dict[str, str]:
        """Get available template variables and their descriptions (a copy)."""
        return self.TEMPLATE_VARIABLES.copy()

    def preview_template(
        self,
        template: PromptTemplate,
        language: LanguageCode = LanguageCode.ENGLISH,
        business_domain: Optional[str] = None,
        variables: Optional[Dict[str, str]] = None
    ) -> str:
        """Preview a template with optional variable substitution.

        Raises:
            PromptValidationError: If the template cannot be retrieved or rendered.
        """
        try:
            base_template = self.template_manager.get_template(
                template, language, business_domain
            )
            if variables:
                base_template = self._substitute_variables(base_template, variables)
            return base_template
        except Exception as e:
            self.logger.error(f"Error previewing template: {str(e)}")
            raise PromptValidationError(f"Failed to preview template: {str(e)}") from e

    def get_processing_stats(self) -> Dict[str, Any]:
        """Get prompt processing statistics (rates guarded against division by zero)."""
        return {
            "total_prompts_processed": self.stats["prompts_processed"],
            "templates_usage": self.stats["templates_used"],
            "validation_failure_rate": (
                self.stats["validation_failures"] / max(1, self.stats["prompts_processed"])
            ),
            "safety_warning_rate": (
                self.stats["safety_warnings"] / max(1, self.stats["prompts_processed"])
            ),
            "cache_size": len(self._template_cache)
        }

    def clear_cache(self) -> None:
        """Clear the template cache."""
        self._template_cache.clear()
        self.logger.info("Template cache cleared")

    def reset_stats(self) -> None:
        """Reset processing statistics to their initial zeroed state."""
        self.stats = {
            "prompts_processed": 0,
            "templates_used": {},
            "validation_failures": 0,
            "safety_warnings": 0
        }
        self.logger.info("Processing statistics reset")
# Lazily created process-wide prompt manager.
_prompt_manager: Optional[PromptManager] = None


def get_prompt_manager() -> PromptManager:
    """Return the shared PromptManager, building it on first access."""
    global _prompt_manager
    if _prompt_manager is not None:
        return _prompt_manager
    _prompt_manager = PromptManager()
    return _prompt_manager
def reset_prompt_manager() -> None:
    """Drop the shared prompt manager so the next access rebuilds it."""
    global _prompt_manager
    _prompt_manager = None