Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Text Classification Capability - Example Implementation | |
| This example demonstrates how to implement the text classification capability | |
| for insurance claim categorization with full governance, explainability, and | |
| audit trail support. | |
| Capability ID: cap_text_classification | |
| Version: 2.1.0 | |
| Compliance: GDPR, IFRS17 | |
| """ | |
| import os | |
| import json | |
| import hashlib | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Any | |
| from dataclasses import dataclass, asdict | |
| import numpy as np | |
# Mock imports (replace with actual libraries in production)
# The ML stack (transformers + torch) is optional: when it cannot be
# imported, the names are bound to None sentinels, which
# TextClassificationCapability._load_model checks before deciding
# between the real BERT model and the mock implementation.
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import torch
except ImportError:
    print("Warning: transformers not installed. Using mock implementation.")
    AutoTokenizer = None
    AutoModelForSequenceClassification = None
    torch = None
@dataclass
class ClassificationResult:
    """Result of text classification.

    Fix: the original class declared field annotations and relied on
    keyword construction and ``asdict`` but was missing the ``@dataclass``
    decorator, so ``ClassificationResult(predicted_class=...)`` raised
    ``TypeError`` and ``to_dict`` failed on non-dataclass instances.

    Attributes:
        predicted_class: Winning class label (or "uncertain"/"error").
        confidence: Score of the predicted class, in [0, 1].
        all_scores: Per-class probability scores.
        explanation: Optional explainability payload.
        metadata: Optional request/processing metadata.
        audit_id: Optional identifier of the audit trail record.
    """
    predicted_class: str
    confidence: float
    all_scores: Dict[str, float]
    explanation: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = None
    audit_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict (JSON-serializable) view of the result."""
        return asdict(self)
| class TextClassificationCapability: | |
| """ | |
| Text Classification Capability Implementation | |
| Categorizes insurance claim descriptions into predefined classes | |
| with explainability and audit trail support. | |
| """ | |
| # Capability metadata | |
| CAPABILITY_ID = "cap_text_classification" | |
| VERSION = "2.1.0" | |
| MODEL_VERSION = "2.1.0-bert-large-20260103" | |
| # Insurance claim classes | |
| CLAIM_CLASSES = [ | |
| "property_damage", | |
| "auto_accident", | |
| "health_claim", | |
| "liability", | |
| "workers_compensation", | |
| "life_insurance", | |
| "disability", | |
| "other" | |
| ] | |
| # Configuration | |
| MAX_INPUT_LENGTH = 10000 | |
| DEFAULT_CONFIDENCE_THRESHOLD = 0.7 | |
| def __init__(self, model_path: Optional[str] = None, enable_audit: bool = True): | |
| """ | |
| Initialize text classification capability | |
| Args: | |
| model_path: Path to trained model (optional) | |
| enable_audit: Enable audit trail logging | |
| """ | |
| self.model_path = model_path | |
| self.enable_audit = enable_audit | |
| self.audit_records = [] | |
| # Load model | |
| self._load_model() | |
| print(f"Initialized {self.CAPABILITY_ID} v{self.VERSION}") | |
| def _load_model(self): | |
| """Load the classification model""" | |
| if AutoTokenizer and AutoModelForSequenceClassification and torch: | |
| # Production: Load actual BERT model | |
| try: | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| self.model_path or "bert-large-uncased" | |
| ) | |
| self.model = AutoModelForSequenceClassification.from_pretrained( | |
| self.model_path or "bert-large-uncased", | |
| num_labels=len(self.CLAIM_CLASSES) | |
| ) | |
| self.model.eval() | |
| print("Loaded production BERT model") | |
| except Exception as e: | |
| print(f"Failed to load model: {e}. Using mock implementation.") | |
| self._use_mock_model() | |
| else: | |
| # Development: Use mock model | |
| self._use_mock_model() | |
| def _use_mock_model(self): | |
| """Use mock model for demonstration""" | |
| self.tokenizer = None | |
| self.model = None | |
| print("Using mock model for demonstration") | |
| def classify( | |
| self, | |
| text: str, | |
| classes: Optional[List[str]] = None, | |
| confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD, | |
| explain: bool = True, | |
| audit_trail: bool = True, | |
| request_id: Optional[str] = None, | |
| user_id: Optional[str] = None | |
| ) -> ClassificationResult: | |
| """ | |
| Classify insurance claim text | |
| Args: | |
| text: Claim description text | |
| classes: Optional list of classes to consider (default: all) | |
| confidence_threshold: Minimum confidence threshold | |
| explain: Generate explanation for prediction | |
| audit_trail: Create audit trail record | |
| request_id: Optional request identifier | |
| user_id: Optional user identifier | |
| Returns: | |
| ClassificationResult with prediction and metadata | |
| Raises: | |
| ValueError: If input is invalid | |
| """ | |
| # Validate input | |
| self._validate_input(text) | |
| # Use default classes if not specified | |
| if classes is None: | |
| classes = self.CLAIM_CLASSES | |
| # Generate request ID if not provided | |
| if request_id is None: | |
| request_id = self._generate_request_id(text) | |
| # Perform classification | |
| start_time = datetime.utcnow() | |
| if self.model is not None: | |
| # Production: Use actual model | |
| scores = self._classify_with_model(text, classes) | |
| else: | |
| # Development: Use mock classification | |
| scores = self._mock_classify(text, classes) | |
| # Get prediction | |
| predicted_class = max(scores, key=scores.get) | |
| confidence = scores[predicted_class] | |
| # Check confidence threshold | |
| if confidence < confidence_threshold: | |
| predicted_class = "uncertain" | |
| # Generate explanation if requested | |
| explanation = None | |
| if explain: | |
| explanation = self._generate_explanation(text, predicted_class, scores) | |
| # Calculate processing time | |
| processing_time_ms = (datetime.utcnow() - start_time).total_seconds() * 1000 | |
| # Create metadata | |
| metadata = { | |
| "capability_id": self.CAPABILITY_ID, | |
| "version": self.VERSION, | |
| "model_version": self.MODEL_VERSION, | |
| "processing_time_ms": processing_time_ms, | |
| "timestamp": datetime.utcnow().isoformat(), | |
| "request_id": request_id, | |
| "compliance_flags": { | |
| "explainable": explain, | |
| "auditable": audit_trail, | |
| "gdpr_compliant": True, | |
| "ifrs17_compliant": True | |
| } | |
| } | |
| # Create audit trail if requested | |
| audit_id = None | |
| if audit_trail and self.enable_audit: | |
| audit_id = self._create_audit_trail( | |
| request_id=request_id, | |
| user_id=user_id, | |
| input_text=text, | |
| predicted_class=predicted_class, | |
| confidence=confidence, | |
| metadata=metadata | |
| ) | |
| # Create result | |
| result = ClassificationResult( | |
| predicted_class=predicted_class, | |
| confidence=confidence, | |
| all_scores=scores, | |
| explanation=explanation, | |
| metadata=metadata, | |
| audit_id=audit_id | |
| ) | |
| return result | |
| def _validate_input(self, text: str): | |
| """Validate input text""" | |
| if not text or not isinstance(text, str): | |
| raise ValueError("Input text must be a non-empty string") | |
| if len(text) > self.MAX_INPUT_LENGTH: | |
| raise ValueError( | |
| f"Input text exceeds maximum length of {self.MAX_INPUT_LENGTH} characters" | |
| ) | |
| # Check for potentially malicious content | |
| malicious_patterns = ['<script', 'javascript:', 'onerror='] | |
| text_lower = text.lower() | |
| for pattern in malicious_patterns: | |
| if pattern in text_lower: | |
| raise ValueError("Input contains potentially malicious content") | |
| def _classify_with_model(self, text: str, classes: List[str]) -> Dict[str, float]: | |
| """Classify using actual BERT model""" | |
| # Tokenize input | |
| inputs = self.tokenizer( | |
| text, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=512, | |
| padding=True | |
| ) | |
| # Get predictions | |
| with torch.no_grad(): | |
| outputs = self.model(**inputs) | |
| logits = outputs.logits | |
| probabilities = torch.softmax(logits, dim=1)[0] | |
| # Create scores dictionary | |
| scores = {} | |
| for i, class_name in enumerate(classes): | |
| if i < len(probabilities): | |
| scores[class_name] = float(probabilities[i]) | |
| else: | |
| scores[class_name] = 0.0 | |
| return scores | |
| def _mock_classify(self, text: str, classes: List[str]) -> Dict[str, float]: | |
| """Mock classification for demonstration""" | |
| text_lower = text.lower() | |
| # Simple keyword-based classification | |
| scores = {class_name: 0.1 for class_name in classes} | |
| # Property damage keywords | |
| if any(word in text_lower for word in ['water', 'fire', 'damage', 'basement', 'roof', 'storm']): | |
| scores['property_damage'] = 0.92 | |
| # Auto accident keywords | |
| elif any(word in text_lower for word in ['collision', 'accident', 'car', 'vehicle', 'highway', 'crash']): | |
| scores['auto_accident'] = 0.88 | |
| # Health claim keywords | |
| elif any(word in text_lower for word in ['medical', 'hospital', 'surgery', 'treatment', 'doctor']): | |
| scores['health_claim'] = 0.85 | |
| # Liability keywords | |
| elif any(word in text_lower for word in ['slip', 'fall', 'injury', 'lawsuit', 'negligence']): | |
| scores['liability'] = 0.83 | |
| # Workers compensation keywords | |
| elif any(word in text_lower for word in ['workplace', 'work injury', 'on the job', 'employee']): | |
| scores['workers_compensation'] = 0.86 | |
| # Default to other | |
| else: | |
| scores['other'] = 0.75 | |
| # Normalize scores to sum to 1.0 | |
| total = sum(scores.values()) | |
| scores = {k: v / total for k, v in scores.items()} | |
| return scores | |
| def _generate_explanation(self, text: str, predicted_class: str, scores: Dict[str, float]) -> Dict[str, Any]: | |
| """Generate explanation for prediction using SHAP-like approach""" | |
| # Extract key features (words) that influenced the decision | |
| words = text.lower().split() | |
| # Mock feature importance (in production, use SHAP or LIME) | |
| feature_importance = [] | |
| # Keywords associated with each class | |
| class_keywords = { | |
| 'property_damage': ['water', 'fire', 'damage', 'basement', 'roof', 'storm', 'flood'], | |
| 'auto_accident': ['collision', 'accident', 'car', 'vehicle', 'highway', 'crash', 'rear-end'], | |
| 'health_claim': ['medical', 'hospital', 'surgery', 'treatment', 'doctor', 'patient'], | |
| 'liability': ['slip', 'fall', 'injury', 'lawsuit', 'negligence', 'premises'], | |
| 'workers_compensation': ['workplace', 'work', 'job', 'employee', 'occupational'], | |
| } | |
| # Find matching keywords | |
| if predicted_class in class_keywords: | |
| for word in words: | |
| if word in class_keywords[predicted_class]: | |
| # Mock importance score | |
| importance = 0.3 + (hash(word) % 30) / 100 | |
| feature_importance.append({ | |
| 'feature': word, | |
| 'importance': round(importance, 2), | |
| 'contribution': f"+{round(importance * scores[predicted_class], 2)}" | |
| }) | |
| # Sort by importance | |
| feature_importance.sort(key=lambda x: x['importance'], reverse=True) | |
| # Take top 5 features | |
| feature_importance = feature_importance[:5] | |
| explanation = { | |
| 'method': 'SHAP', | |
| 'global_explanation': { | |
| 'model_type': 'transformer', | |
| 'training_data_size': 100000, | |
| 'feature_importance_method': 'attention_weights' | |
| }, | |
| 'local_explanation': { | |
| 'input_text': text[:100] + '...' if len(text) > 100 else text, | |
| 'prediction': predicted_class, | |
| 'confidence': scores[predicted_class], | |
| 'key_features': feature_importance, | |
| 'counterfactual': self._generate_counterfactual(text, predicted_class, scores) | |
| }, | |
| 'human_readable_summary': self._generate_human_summary( | |
| predicted_class, scores[predicted_class], feature_importance | |
| ) | |
| } | |
| return explanation | |
| def _generate_counterfactual(self, text: str, predicted_class: str, scores: Dict[str, float]) -> str: | |
| """Generate counterfactual explanation""" | |
| # Find second-best class | |
| sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) | |
| if len(sorted_scores) > 1: | |
| second_class, second_score = sorted_scores[1] | |
| return ( | |
| f"If key terms were changed, the prediction might be '{second_class}' " | |
| f"with {second_score:.2f} confidence instead." | |
| ) | |
| return "No alternative classification available." | |
| def _generate_human_summary(self, predicted_class: str, confidence: float, features: List[Dict]) -> str: | |
| """Generate human-readable explanation summary""" | |
| if not features: | |
| return f"Classified as '{predicted_class}' with {confidence:.1%} confidence." | |
| feature_text = ", ".join([f"'{f['feature']}' ({f['importance']:.0%})" for f in features[:3]]) | |
| return ( | |
| f"The model classified this as '{predicted_class}' with {confidence:.1%} confidence " | |
| f"primarily because of the keywords: {feature_text}. These terms are strongly " | |
| f"associated with {predicted_class} claims in the training data." | |
| ) | |
| def _generate_request_id(self, text: str) -> str: | |
| """Generate unique request ID""" | |
| timestamp = datetime.utcnow().isoformat() | |
| content = f"{timestamp}:{text[:100]}" | |
| hash_value = hashlib.sha256(content.encode()).hexdigest()[:16] | |
| return f"req_{hash_value}" | |
| def _create_audit_trail( | |
| self, | |
| request_id: str, | |
| user_id: Optional[str], | |
| input_text: str, | |
| predicted_class: str, | |
| confidence: float, | |
| metadata: Dict[str, Any] | |
| ) -> str: | |
| """Create audit trail record""" | |
| # Generate audit ID | |
| audit_id = f"audit_{hashlib.sha256(request_id.encode()).hexdigest()[:16]}" | |
| # Create audit record | |
| audit_record = { | |
| 'audit_id': audit_id, | |
| 'timestamp': datetime.utcnow().isoformat(), | |
| 'capability_id': self.CAPABILITY_ID, | |
| 'version': self.VERSION, | |
| 'request_id': request_id, | |
| 'user_id': user_id or 'anonymous', | |
| 'input_hash': hashlib.sha256(input_text.encode()).hexdigest(), | |
| 'output': { | |
| 'predicted_class': predicted_class, | |
| 'confidence': confidence | |
| }, | |
| 'output_hash': hashlib.sha256(f"{predicted_class}:{confidence}".encode()).hexdigest(), | |
| 'metadata': metadata, | |
| 'compliance_flags': metadata['compliance_flags'], | |
| 'retention_until': self._calculate_retention_date() | |
| } | |
| # Store audit record (in production, save to database) | |
| self.audit_records.append(audit_record) | |
| return audit_id | |
| def _calculate_retention_date(self) -> str: | |
| """Calculate data retention date (7 years for GDPR/IFRS17)""" | |
| from datetime import timedelta | |
| retention_date = datetime.utcnow() + timedelta(days=2555) # ~7 years | |
| return retention_date.isoformat() | |
| def get_audit_record(self, audit_id: str) -> Optional[Dict[str, Any]]: | |
| """Retrieve audit record by ID""" | |
| for record in self.audit_records: | |
| if record['audit_id'] == audit_id: | |
| return record | |
| return None | |
| def batch_classify( | |
| self, | |
| texts: List[str], | |
| **kwargs | |
| ) -> List[ClassificationResult]: | |
| """Classify multiple texts in batch""" | |
| results = [] | |
| for text in texts: | |
| try: | |
| result = self.classify(text, **kwargs) | |
| results.append(result) | |
| except Exception as e: | |
| # Create error result | |
| results.append(ClassificationResult( | |
| predicted_class='error', | |
| confidence=0.0, | |
| all_scores={}, | |
| explanation={'error': str(e)} | |
| )) | |
| return results | |
def main():
    """Walk through five usage examples of the classification capability."""
    banner = "=" * 70
    rule = "-" * 70

    print(banner)
    print("Text Classification Capability - Example Usage")
    print(banner)
    print()

    # Set up the capability with audit logging switched on.
    classifier = TextClassificationCapability(enable_audit=True)
    print()

    # Example 1: single property-damage claim with a full explanation.
    print("Example 1: Property Damage Claim")
    print(rule)
    property_claim = (
        "Customer reported water damage to basement after heavy rain. "
        "The carpet is soaked and there's visible mold on the walls."
    )
    print(f"Input: {property_claim}")
    print()
    property_result = classifier.classify(
        text=property_claim,
        explain=True,
        audit_trail=True,
        user_id="user_123",
    )
    print(f"Predicted Class: {property_result.predicted_class}")
    print(f"Confidence: {property_result.confidence:.2%}")
    print(f"Processing Time: {property_result.metadata['processing_time_ms']:.2f}ms")
    print(f"Audit ID: {property_result.audit_id}")
    print()
    print("Explanation:")
    print(property_result.explanation['human_readable_summary'])
    print()
    print("Top Features:")
    for feature in property_result.explanation['local_explanation']['key_features'][:3]:
        print(f" - {feature['feature']}: {feature['importance']:.0%} importance")
    print()
    print()

    # Example 2: auto-accident claim, showing the ranked score table.
    print("Example 2: Auto Accident Claim")
    print(rule)
    auto_claim = (
        "Rear-end collision on I-5 highway during rush hour. "
        "Vehicle sustained damage to rear bumper and trunk."
    )
    print(f"Input: {auto_claim}")
    print()
    auto_result = classifier.classify(
        text=auto_claim,
        explain=True,
        audit_trail=True,
        user_id="user_456",
    )
    print(f"Predicted Class: {auto_result.predicted_class}")
    print(f"Confidence: {auto_result.confidence:.2%}")
    print()
    print("All Scores:")
    ranked = sorted(auto_result.all_scores.items(), key=lambda item: item[1], reverse=True)
    for class_name, score in ranked[:5]:
        print(f" - {class_name}: {score:.2%}")
    print()
    print()

    # Example 3: several claims processed in one batch call.
    print("Example 3: Batch Classification")
    print(rule)
    batch_claims = [
        "Patient underwent surgery for knee replacement at local hospital.",
        "Employee injured on the job while operating machinery.",
        "Customer slipped and fell in grocery store parking lot.",
    ]
    print(f"Processing {len(batch_claims)} claims...")
    batch_results = classifier.batch_classify(
        texts=batch_claims,
        explain=False,
        audit_trail=True,
    )
    print()
    for index, outcome in enumerate(batch_results, 1):
        print(f"{index}. {outcome.predicted_class} ({outcome.confidence:.2%})")
    print()
    print()

    # Example 4: fetch the audit record created for the first claim.
    print("Example 4: Audit Trail Retrieval")
    print(rule)
    audit_record = classifier.get_audit_record(property_result.audit_id)
    if audit_record:
        print(f"Audit ID: {audit_record['audit_id']}")
        print(f"Timestamp: {audit_record['timestamp']}")
        print(f"User ID: {audit_record['user_id']}")
        print(f"Input Hash: {audit_record['input_hash'][:32]}...")
        print(f"Output Hash: {audit_record['output_hash'][:32]}...")
        print(f"Retention Until: {audit_record['retention_until'][:10]}")
        print(f"GDPR Compliant: {audit_record['compliance_flags']['gdpr_compliant']}")
        print(f"IFRS17 Compliant: {audit_record['compliance_flags']['ifrs17_compliant']}")
    print()
    print()

    # Example 5: serialize the first result to JSON (truncated preview).
    print("Example 5: JSON Export")
    print(rule)
    result_json = json.dumps(property_result.to_dict(), indent=2)
    print(result_json[:500] + "...")
    print()
    print(banner)
    print("Examples completed successfully!")
    print(banner)


if __name__ == "__main__":
    main()