# BDR-Agent-Factory / examples / text_classification_example.py
# Author: Bader Alabddan
# Commit: Add comprehensive documentation and implementation framework (3ef5d3c)
#!/usr/bin/env python3
"""
Text Classification Capability - Example Implementation
This example demonstrates how to implement the text classification capability
for insurance claim categorization with full governance, explainability, and
audit trail support.
Capability ID: cap_text_classification
Version: 2.1.0
Compliance: GDPR, IFRS17
"""
import os
import json
import hashlib
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
import numpy as np
# Mock imports (replace with actual libraries in production)
try:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
except ImportError:
print("Warning: transformers not installed. Using mock implementation.")
AutoTokenizer = None
AutoModelForSequenceClassification = None
torch = None
@dataclass
class ClassificationResult:
    """Outcome of a single classification call.

    Carries the prediction itself plus the optional explanation, metadata,
    and audit identifier produced by the capability.
    """
    predicted_class: str                           # winning class name (or "uncertain")
    confidence: float                              # score of the top-ranked class
    all_scores: Dict[str, float]                   # class -> probability (sums to 1.0)
    explanation: Optional[Dict[str, Any]] = None   # SHAP-style explanation payload
    metadata: Optional[Dict[str, Any]] = None      # capability/version/timing info
    audit_id: Optional[str] = None                 # audit-trail record id, if created

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view suitable for JSON serialization."""
        return asdict(self)


class TextClassificationCapability:
    """
    Text Classification Capability Implementation

    Categorizes insurance claim descriptions into predefined classes
    with explainability and audit trail support.

    When the transformers/torch stack is unavailable, or the model fails to
    load, the capability transparently falls back to a deterministic
    keyword-based mock classifier so the examples remain runnable.
    """
    # Capability metadata (echoed into result metadata and audit records)
    CAPABILITY_ID = "cap_text_classification"
    VERSION = "2.1.0"
    MODEL_VERSION = "2.1.0-bert-large-20260103"

    # Insurance claim classes.
    # NOTE(review): the production model's label indices are assumed to follow
    # this list's order — confirm against the trained model's label mapping.
    CLAIM_CLASSES = [
        "property_damage",
        "auto_accident",
        "health_claim",
        "liability",
        "workers_compensation",
        "life_insurance",
        "disability",
        "other"
    ]

    # Configuration
    MAX_INPUT_LENGTH = 10000              # maximum accepted input, in characters
    DEFAULT_CONFIDENCE_THRESHOLD = 0.7    # below this the prediction is "uncertain"

    # Keyword heuristics for the mock classifier: (class, confidence, keywords).
    # Order matters — earlier rules win when several keyword sets match.
    _MOCK_RULES = [
        ("property_damage", 0.92,
         ("water", "fire", "damage", "basement", "roof", "storm")),
        ("auto_accident", 0.88,
         ("collision", "accident", "car", "vehicle", "highway", "crash")),
        ("health_claim", 0.85,
         ("medical", "hospital", "surgery", "treatment", "doctor")),
        ("liability", 0.83,
         ("slip", "fall", "injury", "lawsuit", "negligence")),
        ("workers_compensation", 0.86,
         ("workplace", "work injury", "on the job", "employee")),
    ]

    def __init__(self, model_path: Optional[str] = None, enable_audit: bool = True):
        """
        Initialize text classification capability

        Args:
            model_path: Path to trained model (optional)
            enable_audit: Enable audit trail logging
        """
        self.model_path = model_path
        self.enable_audit = enable_audit
        # In-memory audit store; in production this would be a database.
        self.audit_records: List[Dict[str, Any]] = []
        self._load_model()
        print(f"Initialized {self.CAPABILITY_ID} v{self.VERSION}")

    def _load_model(self):
        """Load the BERT classifier, falling back to the mock implementation.

        The heavy dependencies are imported lazily here (rather than relying
        on module-level globals) so the class is self-contained and usable in
        mock mode even when transformers/torch are not installed.
        """
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            import torch
        except ImportError:
            # Development: no ML stack available.
            self._use_mock_model()
            return
        try:
            # Production: load the actual BERT model.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path or "bert-large-uncased"
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_path or "bert-large-uncased",
                num_labels=len(self.CLAIM_CLASSES)
            )
            self.model.eval()
            self._torch = torch  # keep a handle for inference
            print("Loaded production BERT model")
        except Exception as e:
            # Any load failure (missing weights, no network, ...) degrades to mock.
            print(f"Failed to load model: {e}. Using mock implementation.")
            self._use_mock_model()

    def _use_mock_model(self):
        """Switch to the keyword-based mock implementation."""
        self.tokenizer = None
        self.model = None
        self._torch = None
        print("Using mock model for demonstration")

    def classify(
        self,
        text: str,
        classes: Optional[List[str]] = None,
        confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
        explain: bool = True,
        audit_trail: bool = True,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None
    ) -> ClassificationResult:
        """
        Classify insurance claim text

        Args:
            text: Claim description text
            classes: Optional list of classes to consider (default: all)
            confidence_threshold: Minimum confidence threshold; when the top
                score falls below it, predicted_class becomes "uncertain"
                (confidence still reports the top raw score)
            explain: Generate explanation for prediction
            audit_trail: Create audit trail record
            request_id: Optional request identifier
            user_id: Optional user identifier
        Returns:
            ClassificationResult with prediction and metadata
        Raises:
            ValueError: If input is invalid or `classes` is empty
        """
        self._validate_input(text)

        if classes is None:
            classes = self.CLAIM_CLASSES
        if not classes:
            raise ValueError("classes must be a non-empty list")

        if request_id is None:
            request_id = self._generate_request_id(text)

        start_time = datetime.utcnow()  # NOTE: utcnow() is deprecated in 3.12+
        if self.model is not None:
            scores = self._classify_with_model(text, classes)
        else:
            scores = self._mock_classify(text, classes)

        # Top-ranked class; confidence stays the top raw score even if the
        # prediction is downgraded to "uncertain" below.
        predicted_class = max(scores, key=scores.get)
        confidence = scores[predicted_class]
        if confidence < confidence_threshold:
            predicted_class = "uncertain"

        explanation = None
        if explain:
            explanation = self._generate_explanation(text, predicted_class, scores)

        processing_time_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

        metadata = {
            "capability_id": self.CAPABILITY_ID,
            "version": self.VERSION,
            "model_version": self.MODEL_VERSION,
            "processing_time_ms": processing_time_ms,
            "timestamp": datetime.utcnow().isoformat(),
            "request_id": request_id,
            "compliance_flags": {
                "explainable": explain,
                "auditable": audit_trail,
                "gdpr_compliant": True,
                "ifrs17_compliant": True
            }
        }

        audit_id = None
        if audit_trail and self.enable_audit:
            audit_id = self._create_audit_trail(
                request_id=request_id,
                user_id=user_id,
                input_text=text,
                predicted_class=predicted_class,
                confidence=confidence,
                metadata=metadata
            )

        return ClassificationResult(
            predicted_class=predicted_class,
            confidence=confidence,
            all_scores=scores,
            explanation=explanation,
            metadata=metadata,
            audit_id=audit_id
        )

    def _validate_input(self, text: str):
        """Validate input text: non-empty string, bounded length, no
        obviously malicious markup.

        Raises:
            ValueError: on any validation failure
        """
        if not text or not isinstance(text, str):
            raise ValueError("Input text must be a non-empty string")
        if len(text) > self.MAX_INPUT_LENGTH:
            raise ValueError(
                f"Input text exceeds maximum length of {self.MAX_INPUT_LENGTH} characters"
            )
        # Cheap substring screen for script-injection payloads; not a full
        # sanitizer — output encoding must still happen at the display layer.
        malicious_patterns = ['<script', 'javascript:', 'onerror=']
        text_lower = text.lower()
        for pattern in malicious_patterns:
            if pattern in text_lower:
                raise ValueError("Input contains potentially malicious content")

    def _classify_with_model(self, text: str, classes: List[str]) -> Dict[str, float]:
        """Classify using the loaded BERT model.

        Output logits are indexed by position in CLAIM_CLASSES, so requested
        classes are mapped through that list rather than by their position in
        `classes` (the previous enumeration-based mapping mis-assigned scores
        for any subset or reordering of the default classes).
        """
        torch = self._torch
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)[0]
        scores: Dict[str, float] = {}
        for class_name in classes:
            try:
                label_index = self.CLAIM_CLASSES.index(class_name)
            except ValueError:
                label_index = -1  # unknown class: no model output for it
            if 0 <= label_index < len(probabilities):
                scores[class_name] = float(probabilities[label_index])
            else:
                scores[class_name] = 0.0
        return scores

    def _mock_classify(self, text: str, classes: List[str]) -> Dict[str, float]:
        """Keyword-based mock classification.

        The matched class keeps its rule confidence and the remaining
        probability mass is spread uniformly over the other classes, so the
        scores sum to 1.0 without diluting the winner.  (The previous
        implementation renormalized all scores, which pushed even a 0.92
        keyword match below the 0.7 default threshold and made every
        prediction "uncertain".)  Only classes the caller requested are
        ever assigned a score.
        """
        text_lower = text.lower()

        matched = None
        for class_name, rule_confidence, keywords in self._MOCK_RULES:
            if class_name in classes and any(word in text_lower for word in keywords):
                matched = (class_name, rule_confidence)
                break
        if matched is None and "other" in classes:
            matched = ("other", 0.75)  # default bucket

        if matched is None:
            # No rule applies and no "other" bucket: uniform distribution.
            share = 1.0 / len(classes)
            return {class_name: share for class_name in classes}

        winner, winner_score = matched
        remaining = [c for c in classes if c != winner]
        if not remaining:
            return {winner: 1.0}
        share = (1.0 - winner_score) / len(remaining)
        scores = {class_name: share for class_name in remaining}
        scores[winner] = winner_score
        return scores

    def _generate_explanation(self, text: str, predicted_class: str, scores: Dict[str, float]) -> Dict[str, Any]:
        """Generate explanation for prediction using SHAP-like approach.

        `predicted_class` may be "uncertain", which has no entry in `scores`;
        the top raw score is used as its confidence in that case (previously
        this raised KeyError).
        """
        words = text.lower().split()
        pred_score = scores.get(predicted_class, max(scores.values(), default=0.0))

        # Keywords associated with each class (mock feature attribution;
        # in production use SHAP or LIME).
        class_keywords = {
            'property_damage': ['water', 'fire', 'damage', 'basement', 'roof', 'storm', 'flood'],
            'auto_accident': ['collision', 'accident', 'car', 'vehicle', 'highway', 'crash', 'rear-end'],
            'health_claim': ['medical', 'hospital', 'surgery', 'treatment', 'doctor', 'patient'],
            'liability': ['slip', 'fall', 'injury', 'lawsuit', 'negligence', 'premises'],
            'workers_compensation': ['workplace', 'work', 'job', 'employee', 'occupational'],
        }

        feature_importance: List[Dict[str, Any]] = []
        if predicted_class in class_keywords:
            for word in words:
                if word in class_keywords[predicted_class]:
                    # Deterministic mock importance: built-in hash() is salted
                    # per process (PYTHONHASHSEED), so use an ordinal sum instead.
                    importance = 0.3 + (sum(ord(ch) for ch in word) % 30) / 100
                    feature_importance.append({
                        'feature': word,
                        'importance': round(importance, 2),
                        'contribution': f"+{round(importance * pred_score, 2)}"
                    })
        feature_importance.sort(key=lambda x: x['importance'], reverse=True)
        feature_importance = feature_importance[:5]  # top 5 features only

        explanation = {
            'method': 'SHAP',
            'global_explanation': {
                'model_type': 'transformer',
                'training_data_size': 100000,
                'feature_importance_method': 'attention_weights'
            },
            'local_explanation': {
                'input_text': text[:100] + '...' if len(text) > 100 else text,
                'prediction': predicted_class,
                'confidence': pred_score,
                'key_features': feature_importance,
                'counterfactual': self._generate_counterfactual(text, predicted_class, scores)
            },
            'human_readable_summary': self._generate_human_summary(
                predicted_class, pred_score, feature_importance
            )
        }
        return explanation

    def _generate_counterfactual(self, text: str, predicted_class: str, scores: Dict[str, float]) -> str:
        """Generate a counterfactual sentence naming the runner-up class."""
        sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        if len(sorted_scores) > 1:
            second_class, second_score = sorted_scores[1]
            return (
                f"If key terms were changed, the prediction might be '{second_class}' "
                f"with {second_score:.2f} confidence instead."
            )
        return "No alternative classification available."

    def _generate_human_summary(self, predicted_class: str, confidence: float, features: List[Dict]) -> str:
        """Generate human-readable explanation summary."""
        if not features:
            return f"Classified as '{predicted_class}' with {confidence:.1%} confidence."
        feature_text = ", ".join([f"'{f['feature']}' ({f['importance']:.0%})" for f in features[:3]])
        return (
            f"The model classified this as '{predicted_class}' with {confidence:.1%} confidence "
            f"primarily because of the keywords: {feature_text}. These terms are strongly "
            f"associated with {predicted_class} claims in the training data."
        )

    def _generate_request_id(self, text: str) -> str:
        """Generate a unique request ID from timestamp + input prefix."""
        timestamp = datetime.utcnow().isoformat()
        content = f"{timestamp}:{text[:100]}"
        hash_value = hashlib.sha256(content.encode()).hexdigest()[:16]
        return f"req_{hash_value}"

    def _create_audit_trail(
        self,
        request_id: str,
        user_id: Optional[str],
        input_text: str,
        predicted_class: str,
        confidence: float,
        metadata: Dict[str, Any]
    ) -> str:
        """Create and store an audit trail record.

        Only hashes of the input/output are stored alongside the prediction,
        so the record is tamper-evident without retaining raw claim text.

        Returns:
            The generated audit record id.
        """
        audit_id = f"audit_{hashlib.sha256(request_id.encode()).hexdigest()[:16]}"
        audit_record = {
            'audit_id': audit_id,
            'timestamp': datetime.utcnow().isoformat(),
            'capability_id': self.CAPABILITY_ID,
            'version': self.VERSION,
            'request_id': request_id,
            'user_id': user_id or 'anonymous',
            'input_hash': hashlib.sha256(input_text.encode()).hexdigest(),
            'output': {
                'predicted_class': predicted_class,
                'confidence': confidence
            },
            'output_hash': hashlib.sha256(f"{predicted_class}:{confidence}".encode()).hexdigest(),
            'metadata': metadata,
            'compliance_flags': metadata['compliance_flags'],
            'retention_until': self._calculate_retention_date()
        }
        # Store audit record (in production, save to database)
        self.audit_records.append(audit_record)
        return audit_id

    def _calculate_retention_date(self) -> str:
        """Calculate data retention date (7 years for GDPR/IFRS17)."""
        from datetime import timedelta
        retention_date = datetime.utcnow() + timedelta(days=2555)  # ~7 years
        return retention_date.isoformat()

    def get_audit_record(self, audit_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve an audit record by ID, or None if not found."""
        for record in self.audit_records:
            if record['audit_id'] == audit_id:
                return record
        return None

    def batch_classify(
        self,
        texts: List[str],
        **kwargs
    ) -> List[ClassificationResult]:
        """Classify multiple texts in batch.

        Best-effort: a failure on one item yields an 'error' result for that
        item instead of aborting the whole batch.
        """
        results = []
        for text in texts:
            try:
                results.append(self.classify(text, **kwargs))
            except Exception as e:
                results.append(ClassificationResult(
                    predicted_class='error',
                    confidence=0.0,
                    all_scores={},
                    explanation={'error': str(e)}
                ))
        return results
def main():
    """Walk through the text-classification capability examples end to end."""

    def banner(message):
        # Full-width framed heading (used at start and end of the run).
        print("=" * 70)
        print(message)
        print("=" * 70)

    def section(title):
        # Underlined heading for each numbered example.
        print(title)
        print("-" * 70)

    banner("Text Classification Capability - Example Usage")
    print()

    # Set up the capability with audit logging switched on.
    classifier = TextClassificationCapability(enable_audit=True)
    print()

    # --- Example 1: single claim with explanation and audit trail ---
    section("Example 1: Property Damage Claim")
    property_claim = ("Customer reported water damage to basement after heavy rain. "
                      "The carpet is soaked and there's visible mold on the walls.")
    print(f"Input: {property_claim}")
    print()
    property_result = classifier.classify(
        text=property_claim,
        explain=True,
        audit_trail=True,
        user_id="user_123"
    )
    print(f"Predicted Class: {property_result.predicted_class}")
    print(f"Confidence: {property_result.confidence:.2%}")
    print(f"Processing Time: {property_result.metadata['processing_time_ms']:.2f}ms")
    print(f"Audit ID: {property_result.audit_id}")
    print()
    print("Explanation:")
    print(property_result.explanation['human_readable_summary'])
    print()
    print("Top Features:")
    for feature in property_result.explanation['local_explanation']['key_features'][:3]:
        print(f" - {feature['feature']}: {feature['importance']:.0%} importance")
    print()
    print()

    # --- Example 2: another claim, showing the ranked score table ---
    section("Example 2: Auto Accident Claim")
    auto_claim = ("Rear-end collision on I-5 highway during rush hour. "
                  "Vehicle sustained damage to rear bumper and trunk.")
    print(f"Input: {auto_claim}")
    print()
    auto_result = classifier.classify(
        text=auto_claim,
        explain=True,
        audit_trail=True,
        user_id="user_456"
    )
    print(f"Predicted Class: {auto_result.predicted_class}")
    print(f"Confidence: {auto_result.confidence:.2%}")
    print()
    print("All Scores:")
    ranked = sorted(auto_result.all_scores.items(), key=lambda item: item[1], reverse=True)
    for class_name, score in ranked[:5]:
        print(f" - {class_name}: {score:.2%}")
    print()
    print()

    # --- Example 3: several claims processed in one batch call ---
    section("Example 3: Batch Classification")
    batch_texts = [
        "Patient underwent surgery for knee replacement at local hospital.",
        "Employee injured on the job while operating machinery.",
        "Customer slipped and fell in grocery store parking lot."
    ]
    print(f"Processing {len(batch_texts)} claims...")
    batch_results = classifier.batch_classify(
        texts=batch_texts,
        explain=False,
        audit_trail=True
    )
    print()
    for index, batch_result in enumerate(batch_results, 1):
        print(f"{index}. {batch_result.predicted_class} ({batch_result.confidence:.2%})")
    print()
    print()

    # --- Example 4: fetch the audit record created in Example 1 ---
    section("Example 4: Audit Trail Retrieval")
    audit_record = classifier.get_audit_record(property_result.audit_id)
    if audit_record:
        print(f"Audit ID: {audit_record['audit_id']}")
        print(f"Timestamp: {audit_record['timestamp']}")
        print(f"User ID: {audit_record['user_id']}")
        print(f"Input Hash: {audit_record['input_hash'][:32]}...")
        print(f"Output Hash: {audit_record['output_hash'][:32]}...")
        print(f"Retention Until: {audit_record['retention_until'][:10]}")
        print(f"GDPR Compliant: {audit_record['compliance_flags']['gdpr_compliant']}")
        print(f"IFRS17 Compliant: {audit_record['compliance_flags']['ifrs17_compliant']}")
    print()
    print()

    # --- Example 5: serialize a result to JSON (truncated for display) ---
    section("Example 5: JSON Export")
    result_json = json.dumps(property_result.to_dict(), indent=2)
    print(result_json[:500] + "...")
    print()
    banner("Examples completed successfully!")


if __name__ == "__main__":
    main()