#!/usr/bin/env python3
"""
Text Classification Capability - Example Implementation

This example demonstrates how to implement the text classification capability
for insurance claim categorization with full governance, explainability, and
audit trail support.

Capability ID: cap_text_classification
Version: 2.1.0
Compliance: GDPR, IFRS17
"""

import hashlib
import json
import os  # NOTE(review): unused here, kept for parity with the original imports
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import numpy as np  # NOTE(review): unused here, kept for parity with the original imports

# Mock imports (replace with actual libraries in production).  If the
# transformers stack is absent we fall back to a deterministic mock model so
# the example still runs end to end.
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import torch
except ImportError:
    print("Warning: transformers not installed. Using mock implementation.")
    AutoTokenizer = None
    AutoModelForSequenceClassification = None
    torch = None


@dataclass
class ClassificationResult:
    """Result of text classification.

    Attributes:
        predicted_class: Winning class label (or "uncertain" / "error").
        confidence: Score of the winning class, in [0, 1].
        all_scores: Normalized score per candidate class.
        explanation: Optional SHAP-style explanation payload.
        metadata: Optional processing metadata (versions, timing, compliance).
        audit_id: Optional identifier of the stored audit record.
    """
    predicted_class: str
    confidence: float
    all_scores: Dict[str, float]
    explanation: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None
    audit_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict (JSON-serializable) view of the result."""
        return asdict(self)


class TextClassificationCapability:
    """
    Text Classification Capability Implementation

    Categorizes insurance claim descriptions into predefined classes
    with explainability and audit trail support.
    """

    # Capability metadata
    CAPABILITY_ID = "cap_text_classification"
    VERSION = "2.1.0"
    MODEL_VERSION = "2.1.0-bert-large-20260103"

    # Insurance claim classes
    CLAIM_CLASSES = [
        "property_damage",
        "auto_accident",
        "health_claim",
        "liability",
        "workers_compensation",
        "life_insurance",
        "disability",
        "other",
    ]

    # Keywords associated with each class.  Shared by the mock classifier and
    # the explanation generator so the two can never drift apart.
    CLASS_KEYWORDS = {
        "property_damage": ["water", "fire", "damage", "basement", "roof", "storm", "flood"],
        "auto_accident": ["collision", "accident", "car", "vehicle", "highway", "crash", "rear-end"],
        "health_claim": ["medical", "hospital", "surgery", "treatment", "doctor", "patient"],
        "liability": ["slip", "fall", "injury", "lawsuit", "negligence", "premises"],
        "workers_compensation": ["workplace", "work", "job", "employee", "occupational"],
    }

    # Configuration
    MAX_INPUT_LENGTH = 10000
    DEFAULT_CONFIDENCE_THRESHOLD = 0.7

    def __init__(self, model_path: Optional[str] = None, enable_audit: bool = True):
        """
        Initialize text classification capability

        Args:
            model_path: Path to trained model (optional)
            enable_audit: Enable audit trail logging
        """
        self.model_path = model_path
        self.enable_audit = enable_audit
        self.audit_records: List[Dict[str, Any]] = []

        # Load model
        self._load_model()

        print(f"Initialized {self.CAPABILITY_ID} v{self.VERSION}")

    def _load_model(self):
        """Load the classification model, falling back to the mock model."""
        if AutoTokenizer and AutoModelForSequenceClassification and torch:
            # Production: Load actual BERT model
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_path or "bert-large-uncased"
                )
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    self.model_path or "bert-large-uncased",
                    num_labels=len(self.CLAIM_CLASSES),
                )
                self.model.eval()
                print("Loaded production BERT model")
            except Exception as e:
                print(f"Failed to load model: {e}. Using mock implementation.")
                self._use_mock_model()
        else:
            # Development: Use mock model
            self._use_mock_model()

    def _use_mock_model(self):
        """Use mock model for demonstration"""
        self.tokenizer = None
        self.model = None
        print("Using mock model for demonstration")

    def classify(
        self,
        text: str,
        classes: Optional[List[str]] = None,
        confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
        explain: bool = True,
        audit_trail: bool = True,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> ClassificationResult:
        """
        Classify insurance claim text

        Args:
            text: Claim description text
            classes: Optional list of classes to consider (default: all)
            confidence_threshold: Minimum confidence threshold
            explain: Generate explanation for prediction
            audit_trail: Create audit trail record
            request_id: Optional request identifier
            user_id: Optional user identifier

        Returns:
            ClassificationResult with prediction and metadata

        Raises:
            ValueError: If input is invalid
        """
        # Validate input
        self._validate_input(text)

        # Use default classes if not specified
        if classes is None:
            classes = self.CLAIM_CLASSES

        # Generate request ID if not provided
        if request_id is None:
            request_id = self._generate_request_id(text)

        # Perform classification (timezone-aware timestamps throughout)
        start_time = datetime.now(timezone.utc)

        if self.model is not None:
            # Production: Use actual model
            scores = self._classify_with_model(text, classes)
        else:
            # Development: Use mock classification
            scores = self._mock_classify(text, classes)

        # Get prediction
        predicted_class = max(scores, key=scores.get)
        confidence = scores[predicted_class]

        # Check confidence threshold.  NOTE: after this point predicted_class
        # may be "uncertain", which is NOT a key of `scores` — downstream code
        # must use scores.get(), not scores[].
        if confidence < confidence_threshold:
            predicted_class = "uncertain"

        # Generate explanation if requested
        explanation = None
        if explain:
            explanation = self._generate_explanation(text, predicted_class, scores)

        # Calculate processing time
        processing_time_ms = (datetime.now(timezone.utc) - start_time).total_seconds() * 1000

        # Create metadata
        metadata = {
            "capability_id": self.CAPABILITY_ID,
            "version": self.VERSION,
            "model_version": self.MODEL_VERSION,
            "processing_time_ms": processing_time_ms,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "request_id": request_id,
            "compliance_flags": {
                "explainable": explain,
                "auditable": audit_trail,
                "gdpr_compliant": True,
                "ifrs17_compliant": True,
            },
        }

        # Create audit trail if requested
        audit_id = None
        if audit_trail and self.enable_audit:
            audit_id = self._create_audit_trail(
                request_id=request_id,
                user_id=user_id,
                input_text=text,
                predicted_class=predicted_class,
                confidence=confidence,
                metadata=metadata,
            )

        # Create result
        return ClassificationResult(
            predicted_class=predicted_class,
            confidence=confidence,
            all_scores=scores,
            explanation=explanation,
            metadata=metadata,
            audit_id=audit_id,
        )

    def _validate_input(self, text: str) -> None:
        """Validate input text.

        Raises:
            ValueError: empty/non-string input, over-length input, or input
                containing HTML/JS injection markers.
        """
        if not text or not isinstance(text, str):
            raise ValueError("Input text must be a non-empty string")

        if len(text) > self.MAX_INPUT_LENGTH:
            raise ValueError(
                f"Input text exceeds maximum length of {self.MAX_INPUT_LENGTH} characters"
            )

        # Check for potentially malicious content.
        # NOTE(review): the original pattern list was corrupted in the source
        # (everything between "['" and the next method was lost, apparently to
        # an HTML stripper); these patterns are a reconstruction — confirm
        # against the original security requirements.
        malicious_patterns = ["<script", "javascript:", "onerror=", "onload="]
        lowered = text.lower()
        for pattern in malicious_patterns:
            if pattern in lowered:
                raise ValueError("Input text contains potentially malicious content")

    def _classify_with_model(self, text: str, classes: List[str]) -> Dict[str, float]:
        """Classify text with the loaded transformer model.

        NOTE(review): reconstructed — this method was called from classify()
        but its definition was missing from the corrupted source.

        Returns a score per requested class, renormalized over `classes`.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        )
        with torch.no_grad():
            logits = self.model(**inputs).logits[0]
        probs = torch.softmax(logits, dim=-1).tolist()

        # Map model output positions onto class labels, keep requested classes,
        # then renormalize so the returned scores sum to 1.
        label_probs = dict(zip(self.CLAIM_CLASSES, probs))
        selected = {cls: label_probs.get(cls, 0.0) for cls in classes}
        total = sum(selected.values()) or 1.0
        return {cls: p / total for cls, p in selected.items()}

    def _mock_classify(self, text: str, classes: List[str]) -> Dict[str, float]:
        """Deterministic keyword-count mock classifier.

        NOTE(review): reconstructed — this method was called from classify()
        but its definition was missing from the corrupted source.

        Each candidate class scores a small floor plus a fixed weight per
        class keyword found in the text; scores are normalized to sum to 1.
        """
        words = {w.strip(".,;:!?()\"'") for w in text.lower().split()}
        raw: Dict[str, float] = {}
        for cls in classes:
            hits = sum(1 for kw in self.CLASS_KEYWORDS.get(cls, []) if kw in words)
            # 0.02 floor keeps every class representable; 0.5/hit makes a
            # few keyword matches clear the 0.7 default confidence threshold.
            raw[cls] = 0.02 + 0.5 * hits
        total = sum(raw.values())
        return {cls: score / total for cls, score in raw.items()}

    @staticmethod
    def _stable_word_hash(word: str) -> int:
        """Process-independent hash for mock feature importances.

        Python's built-in hash() is salted per process (PYTHONHASHSEED), which
        would make 'explanations' differ between runs — unacceptable for an
        auditable capability.
        """
        return int(hashlib.sha256(word.encode("utf-8")).hexdigest()[:8], 16)

    def _generate_explanation(
        self, text: str, predicted_class: str, scores: Dict[str, float]
    ) -> Dict[str, Any]:
        """Generate explanation for prediction using SHAP-like approach."""
        # predicted_class may be "uncertain" (below threshold) and thus absent
        # from scores — fall back to the best available score.
        conf = scores.get(predicted_class, max(scores.values()) if scores else 0.0)

        # Extract key features (words) that influenced the decision
        words = [w.strip(".,;:!?()\"'") for w in text.lower().split()]

        # Mock feature importance (in production, use SHAP or LIME)
        feature_importance: List[Dict[str, Any]] = []
        keywords = self.CLASS_KEYWORDS.get(predicted_class, [])
        for word in words:
            if word in keywords:
                # Mock importance score, stable across processes
                importance = 0.3 + (self._stable_word_hash(word) % 30) / 100
                feature_importance.append({
                    "feature": word,
                    "importance": round(importance, 2),
                    "contribution": f"+{round(importance * conf, 2)}",
                })

        # Sort by importance and keep the top 5 features
        feature_importance.sort(key=lambda x: x["importance"], reverse=True)
        feature_importance = feature_importance[:5]

        return {
            "method": "SHAP",
            "global_explanation": {
                "model_type": "transformer",
                "training_data_size": 100000,
                "feature_importance_method": "attention_weights",
            },
            "local_explanation": {
                "input_text": text[:100] + "..." if len(text) > 100 else text,
                "prediction": predicted_class,
                "confidence": conf,
                "key_features": feature_importance,
                "counterfactual": self._generate_counterfactual(
                    text, predicted_class, scores
                ),
            },
            "human_readable_summary": self._generate_human_summary(
                predicted_class, conf, feature_importance
            ),
        }

    def _generate_counterfactual(
        self, text: str, predicted_class: str, scores: Dict[str, float]
    ) -> str:
        """Generate counterfactual explanation"""
        # Find second-best class
        sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        if len(sorted_scores) > 1:
            second_class, second_score = sorted_scores[1]
            return (
                f"If key terms were changed, the prediction might be '{second_class}' "
                f"with {second_score:.2f} confidence instead."
            )
        return "No alternative classification available."

    def _generate_human_summary(
        self, predicted_class: str, confidence: float, features: List[Dict]
    ) -> str:
        """Generate human-readable explanation summary"""
        if not features:
            return f"Classified as '{predicted_class}' with {confidence:.1%} confidence."

        feature_text = ", ".join(
            [f"'{f['feature']}' ({f['importance']:.0%})" for f in features[:3]]
        )
        return (
            f"The model classified this as '{predicted_class}' with {confidence:.1%} confidence "
            f"primarily because of the keywords: {feature_text}. These terms are strongly "
            f"associated with {predicted_class} claims in the training data."
        )

    def _generate_request_id(self, text: str) -> str:
        """Generate unique request ID"""
        timestamp = datetime.now(timezone.utc).isoformat()
        content = f"{timestamp}:{text[:100]}"
        hash_value = hashlib.sha256(content.encode()).hexdigest()[:16]
        return f"req_{hash_value}"

    def _create_audit_trail(
        self,
        request_id: str,
        user_id: Optional[str],
        input_text: str,
        predicted_class: str,
        confidence: float,
        metadata: Dict[str, Any],
    ) -> str:
        """Create audit trail record.

        Stores only hashes of the input/output (GDPR data-minimization) plus
        the prediction itself, and returns the generated audit ID.
        """
        # Generate audit ID
        audit_id = f"audit_{hashlib.sha256(request_id.encode()).hexdigest()[:16]}"

        # Create audit record
        audit_record = {
            "audit_id": audit_id,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "capability_id": self.CAPABILITY_ID,
            "version": self.VERSION,
            "request_id": request_id,
            "user_id": user_id or "anonymous",
            "input_hash": hashlib.sha256(input_text.encode()).hexdigest(),
            "output": {
                "predicted_class": predicted_class,
                "confidence": confidence,
            },
            "output_hash": hashlib.sha256(
                f"{predicted_class}:{confidence}".encode()
            ).hexdigest(),
            "metadata": metadata,
            "compliance_flags": metadata["compliance_flags"],
            "retention_until": self._calculate_retention_date(),
        }

        # Store audit record (in production, save to database)
        self.audit_records.append(audit_record)

        return audit_id

    def _calculate_retention_date(self) -> str:
        """Calculate data retention date (7 years for GDPR/IFRS17)"""
        retention_date = datetime.now(timezone.utc) + timedelta(days=2555)  # ~7 years
        return retention_date.isoformat()

    def get_audit_record(self, audit_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve audit record by ID, or None if unknown."""
        for record in self.audit_records:
            if record["audit_id"] == audit_id:
                return record
        return None

    def batch_classify(self, texts: List[str], **kwargs) -> List[ClassificationResult]:
        """Classify multiple texts in batch.

        Per-item failures are captured as 'error' results so one bad input
        cannot abort the whole batch.
        """
        results = []
        for text in texts:
            try:
                results.append(self.classify(text, **kwargs))
            except Exception as e:
                # Create error result
                results.append(ClassificationResult(
                    predicted_class="error",
                    confidence=0.0,
                    all_scores={},
                    explanation={"error": str(e)},
                ))
        return results


def main():
    """Example usage of text classification capability"""
    print("=" * 70)
    print("Text Classification Capability - Example Usage")
    print("=" * 70)
    print()

    # Initialize capability
    classifier = TextClassificationCapability(enable_audit=True)
    print()

    # Example 1: Property damage claim
    print("Example 1: Property Damage Claim")
    print("-" * 70)
    claim_text_1 = "Customer reported water damage to basement after heavy rain. The carpet is soaked and there's visible mold on the walls."
    print(f"Input: {claim_text_1}")
    print()

    result_1 = classifier.classify(
        text=claim_text_1,
        explain=True,
        audit_trail=True,
        user_id="user_123",
    )

    print(f"Predicted Class: {result_1.predicted_class}")
    print(f"Confidence: {result_1.confidence:.2%}")
    print(f"Processing Time: {result_1.metadata['processing_time_ms']:.2f}ms")
    print(f"Audit ID: {result_1.audit_id}")
    print()
    print("Explanation:")
    print(result_1.explanation["human_readable_summary"])
    print()
    print("Top Features:")
    for feature in result_1.explanation["local_explanation"]["key_features"][:3]:
        print(f"  - {feature['feature']}: {feature['importance']:.0%} importance")
    print()
    print()

    # Example 2: Auto accident claim
    print("Example 2: Auto Accident Claim")
    print("-" * 70)
    claim_text_2 = "Rear-end collision on I-5 highway during rush hour. Vehicle sustained damage to rear bumper and trunk."
    print(f"Input: {claim_text_2}")
    print()

    result_2 = classifier.classify(
        text=claim_text_2,
        explain=True,
        audit_trail=True,
        user_id="user_456",
    )

    print(f"Predicted Class: {result_2.predicted_class}")
    print(f"Confidence: {result_2.confidence:.2%}")
    print()
    print("All Scores:")
    for class_name, score in sorted(result_2.all_scores.items(), key=lambda x: x[1], reverse=True)[:5]:
        print(f"  - {class_name}: {score:.2%}")
    print()
    print()

    # Example 3: Batch classification
    print("Example 3: Batch Classification")
    print("-" * 70)
    batch_texts = [
        "Patient underwent surgery for knee replacement at local hospital.",
        "Employee injured on the job while operating machinery.",
        "Customer slipped and fell in grocery store parking lot.",
    ]
    print(f"Processing {len(batch_texts)} claims...")

    batch_results = classifier.batch_classify(
        texts=batch_texts,
        explain=False,
        audit_trail=True,
    )

    print()
    for i, result in enumerate(batch_results, 1):
        print(f"{i}. {result.predicted_class} ({result.confidence:.2%})")
    print()
    print()

    # Example 4: Retrieve audit record
    print("Example 4: Audit Trail Retrieval")
    print("-" * 70)
    audit_record = classifier.get_audit_record(result_1.audit_id)
    if audit_record:
        print(f"Audit ID: {audit_record['audit_id']}")
        print(f"Timestamp: {audit_record['timestamp']}")
        print(f"User ID: {audit_record['user_id']}")
        print(f"Input Hash: {audit_record['input_hash'][:32]}...")
        print(f"Output Hash: {audit_record['output_hash'][:32]}...")
        print(f"Retention Until: {audit_record['retention_until'][:10]}")
        print(f"GDPR Compliant: {audit_record['compliance_flags']['gdpr_compliant']}")
        print(f"IFRS17 Compliant: {audit_record['compliance_flags']['ifrs17_compliant']}")
    print()
    print()

    # Example 5: Export results as JSON
    print("Example 5: JSON Export")
    print("-" * 70)
    result_json = json.dumps(result_1.to_dict(), indent=2)
    print(result_json[:500] + "...")
    print()

    print("=" * 70)
    print("Examples completed successfully!")
    print("=" * 70)


if __name__ == "__main__":
    main()