""" Inference Script for Legal-BERT Risk Analysis Run trained model on new legal clauses """ import torch import json from typing import List, Dict, Any import argparse from model import HierarchicalLegalBERT, LegalBertTokenizer from config import LegalBertConfig def load_trained_model(checkpoint_path: str, config: LegalBertConfig) -> HierarchicalLegalBERT: """Load trained model from checkpoint""" print(f"šŸ“„ Loading model from: {checkpoint_path}") # PyTorch 2.6+ requires weights_only=False for custom classes # This is safe since we control the checkpoint creation checkpoint = torch.load(checkpoint_path, map_location=config.device, weights_only=False) # Get number of risk patterns num_risks = len(checkpoint.get('discovered_patterns', {})) print(f" Model has {num_risks} discovered risk patterns") # CRITICAL FIX: Use the config from checkpoint to get correct architecture parameters # This ensures the model architecture matches the trained model if 'config' in checkpoint: saved_config = checkpoint['config'] hidden_dim = saved_config.hierarchical_hidden_dim num_lstm_layers = saved_config.hierarchical_num_lstm_layers print(f" Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}") else: # Fallback to current config (for backward compatibility) hidden_dim = config.hierarchical_hidden_dim num_lstm_layers = config.hierarchical_num_lstm_layers print(f" āš ļø Warning: No config in checkpoint, using current config") # Initialize model with correct architecture parameters model = HierarchicalLegalBERT( config=config, num_discovered_risks=num_risks, hidden_dim=hidden_dim, num_lstm_layers=num_lstm_layers ) model.load_state_dict(checkpoint['model_state_dict']) model.to(config.device) model.eval() print(f" āœ… Model loaded successfully") return model, checkpoint.get('discovered_patterns', {}) def predict_single_clause( model: HierarchicalLegalBERT, tokenizer: LegalBertTokenizer, clause: str, config: LegalBertConfig ) -> Dict[str, Any]: """Predict risk for a single clause""" # Tokenize encoded = tokenizer.tokenize_clauses([clause], config.max_sequence_length) input_ids = encoded['input_ids'].to(config.device) attention_mask = encoded['attention_mask'].to(config.device) # Predict with torch.no_grad(): outputs = model.forward_single_clause(input_ids, attention_mask) # Get probabilities risk_probs = torch.softmax(outputs['calibrated_logits'], dim=-1) predicted_risk = torch.argmax(risk_probs, dim=-1) confidence = torch.max(risk_probs, dim=-1)[0] return { 'clause': clause, 'predicted_risk_id': predicted_risk.cpu().item(), 'confidence': confidence.cpu().item(), 'risk_probabilities': risk_probs.cpu().numpy().tolist(), 'severity_score': outputs['severity_score'].cpu().item(), 'importance_score': outputs['importance_score'].cpu().item() } def predict_document( model: HierarchicalLegalBERT, tokenizer: LegalBertTokenizer, document: List[List[str]], config: LegalBertConfig ) -> Dict[str, Any]: """ Predict risks for a full document with context Args: document: List of sections, each containing list of clauses Example: [ ['clause1', 'clause2'], # Section 1 ['clause3', 'clause4'], # Section 2 ] """ print(f"šŸ“„ Analyzing document with {len(document)} sections...") # Tokenize document structure doc_structure = [] clause_texts = [] for section_idx, section in enumerate(document): section_tokens = [] for clause_idx, clause in enumerate(section): encoded = tokenizer.tokenize_clauses([clause], config.max_sequence_length) section_tokens.append({ 'input_ids': encoded['input_ids'][0], 'attention_mask': encoded['attention_mask'][0] }) clause_texts.append({ 'section': section_idx, 'clause': clause_idx, 'text': clause }) doc_structure.append(section_tokens) # Predict with context results = model.predict_document(doc_structure) # Merge predictions with clause texts for i, pred in enumerate(results['clauses']): pred['text'] = clause_texts[i]['text'] return results def format_prediction_output( prediction: Dict[str, Any], risk_patterns: Dict[str, Any] ) -> str: """Format prediction for display""" risk_id = prediction['predicted_risk_id'] pattern_names = list(risk_patterns.keys()) # Handle both string and integer pattern names if risk_id < len(pattern_names): risk_name = str(pattern_names[risk_id]) risk_info = risk_patterns[pattern_names[risk_id]] # Extract keywords from pattern info if isinstance(risk_info, dict): keywords = ', '.join(risk_info.get('keywords', risk_info.get('top_words', []))[:5]) else: keywords = "N/A" else: risk_name = f"Risk Pattern {risk_id}" keywords = "N/A" output = f""" {'='*70} šŸ“‹ CLAUSE ANALYSIS {'='*70} šŸ“ Clause: {prediction.get('text', prediction.get('clause', 'N/A'))} šŸŽÆ Risk Classification: Pattern: {risk_name} Confidence: {prediction['confidence']:.1%} Keywords: {keywords} šŸ“Š Risk Scores: Severity: {prediction['severity_score']:.2f}/10 Importance: {prediction['importance_score']:.2f}/10 šŸ” Probability Distribution: """ # Show top 3 risk probabilities probs = prediction['risk_probabilities'] # Handle nested list structure (e.g., [[prob1, prob2, ...]]) if isinstance(probs, list) and len(probs) > 0 and isinstance(probs[0], list): probs = probs[0] top_3_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:3] for idx in top_3_indices: if idx < len(pattern_names): # Convert pattern name to string and truncate if needed pattern_str = str(pattern_names[idx]) if len(pattern_str) > 40: pattern_str = pattern_str[:37] + "..." output += f" {pattern_str:40s} {probs[idx]:.1%}\n" else: output += f" Risk Pattern {idx:2d} {probs[idx]:.1%}\n" return output def main(): """Main inference function""" parser = argparse.ArgumentParser(description='Legal-BERT Risk Analysis Inference') parser.add_argument('--checkpoint', type=str, default='models/legal_bert/final_model.pt', help='Path to model checkpoint') parser.add_argument('--clause', type=str, help='Single clause to analyze') parser.add_argument('--document', type=str, help='Path to JSON file with document structure') parser.add_argument('--output', type=str, help='Path to save results (JSON)') args = parser.parse_args() print("=" * 70) print("šŸ›ļø LEGAL-BERT RISK ANALYSIS INFERENCE") print("=" * 70) # Initialize config config = LegalBertConfig() print(f"\nšŸ“‹ Configuration:") print(f" Device: {config.device}") print(f" Max sequence length: {config.max_sequence_length}") # Load model model, risk_patterns = load_trained_model(args.checkpoint, config) tokenizer = LegalBertTokenizer(config.bert_model_name) print(f"\nšŸ” Discovered Risk Patterns ({len(risk_patterns)}):") pattern_names = list(risk_patterns.keys()) for name in pattern_names[:5]: # Convert to string for display display_name = str(name) print(f" • {display_name}") if len(risk_patterns) > 5: print(f" ... and {len(risk_patterns) - 5} more") results = [] # Single clause mode if args.clause: print(f"\n" + "="*70) print("MODE: Single Clause Analysis") print("="*70) prediction = predict_single_clause(model, tokenizer, args.clause, config) print(format_prediction_output(prediction, risk_patterns)) results.append(prediction) # Document mode elif args.document: print(f"\n" + "="*70) print("MODE: Full Document Analysis (with context)") print("="*70) # Load document with open(args.document, 'r') as f: doc_data = json.load(f) # Expected format: {"sections": [["clause1", "clause2"], ["clause3"]]} document = doc_data.get('sections', []) prediction = predict_document(model, tokenizer, document, config) print(f"\nšŸ“Š Document Summary:") print(f" Sections: {prediction['summary']['num_sections']}") print(f" Clauses: {prediction['summary']['num_clauses']}") print(f" Average Severity: {prediction['summary']['avg_severity']:.2f}/10") print(f" High Risk Clauses: {prediction['summary']['high_risk_count']}") print(f"\nšŸ“‹ Clause-by-Clause Analysis:") for clause_pred in prediction['clauses']: print(format_prediction_output(clause_pred, risk_patterns)) results = prediction # Demo mode (no arguments) else: print(f"\n" + "="*70) print("MODE: Demo Analysis") print("="*70) print("\nšŸ’” Running demo with sample clauses...") demo_clauses = [ "The party shall indemnify and hold harmless all damages and losses.", "This agreement shall be governed by the laws of the state of California.", "Payment must be made within thirty days of invoice date.", "The licensee must not disclose confidential information to third parties.", "Company shall comply with all applicable laws and regulations." ] for clause in demo_clauses: prediction = predict_single_clause(model, tokenizer, clause, config) print(format_prediction_output(prediction, risk_patterns)) results.append(prediction) # Save results if output path provided if args.output: with open(args.output, 'w') as f: json.dump(results, f, indent=2) print(f"\nšŸ’¾ Results saved to: {args.output}") print("\n" + "="*70) print("āœ… INFERENCE COMPLETE") print("="*70) # Usage tips if not args.clause and not args.document: print(f"\nšŸ’” Usage Examples:") print(f'\n Single clause:') print(f' python3 inference.py --clause "The party shall indemnify..."') print(f'\n Full document:') print(f' python3 inference.py --document contract.json') print(f'\n Save results:') print(f' python3 inference.py --clause "..." --output results.json') if __name__ == "__main__": main()