| """ |
| Automated validation script for rmtariq/multilingual-emotion-classifier |
| This script runs automated tests and generates a validation report. |
| |
| Usage: |
| python validate_model.py |
| python validate_model.py --output report.txt |
| |
| Author: rmtariq |
| """ |
|
|
| import argparse |
| import json |
| import time |
| from datetime import datetime |
| from transformers import pipeline |
| import torch |
|
|
def validate_model(model_name="rmtariq/multilingual-emotion-classifier", output_file=None):
    """Run comprehensive validation and return the results as a dict."""

    print("🔍 AUTOMATED MODEL VALIDATION")
    print("=" * 60)
    print(f"Model: {model_name}")
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    validation_results = {
        "model_name": model_name,
        "timestamp": datetime.now().isoformat(),
        "device": "GPU" if torch.cuda.is_available() else "CPU",
        "tests": {},
        "overall_status": "UNKNOWN",
    }

    try:
        print("📥 Loading model...")
        # device=0 selects the first CUDA GPU; device=-1 forces CPU.
        classifier = pipeline(
            "text-classification",
            model=model_name,
            device=0 if torch.cuda.is_available() else -1,
        )
        print(f"✅ Model loaded on {validation_results['device']}")

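        # A single-text call returns a list with one dict per input, e.g.
        # [{'label': 'happy', 'score': 0.98}], so the tests below read result[0].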
| print("\nπ§ͺ Test 1: Critical Functionality") |
| print("-" * 40) |
| |
| critical_cases = [ |
| ("I am happy", "happy"), |
| ("I am angry", "anger"), |
| ("I love this", "love"), |
| ("I am scared", "fear"), |
| ("I am sad", "sadness"), |
| ("What a surprise", "surprise") |
| ] |
| |
| critical_correct = 0 |
| for text, expected in critical_cases: |
| result = classifier(text) |
| predicted = result[0]['label'].lower() |
| is_correct = predicted == expected |
| if is_correct: |
| critical_correct += 1 |
| |
| status = "β
" if is_correct else "β" |
| print(f" {status} '{text}' β {predicted}") |
| |
| critical_accuracy = critical_correct / len(critical_cases) |
| validation_results["tests"]["critical_functionality"] = { |
| "accuracy": critical_accuracy, |
| "passed": critical_accuracy >= 0.8, |
| "details": f"{critical_correct}/{len(critical_cases)} correct" |
| } |
| |
| print(f" π Critical Accuracy: {critical_accuracy:.1%}") |
| |
| |
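        # The expected labels above assume this model's six-emotion label set
        # (happy, anger, love, fear, sadness, surprise); adjust them if a
        # different checkpoint maps to other label names.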
| print("\nπ§ͺ Test 2: Malay Fixes Validation") |
| print("-" * 40) |
| |
| malay_fixes = [ |
| ("Ini adalah hari jadi terbaik", "happy"), |
| ("Terbaik!", "happy"), |
| ("Ini adalah hari yang baik", "happy"), |
| ("Pengalaman terbaik", "happy") |
| ] |
| |
| malay_correct = 0 |
| for text, expected in malay_fixes: |
| result = classifier(text) |
| predicted = result[0]['label'].lower() |
| is_correct = predicted == expected |
| if is_correct: |
| malay_correct += 1 |
| |
| status = "β
" if is_correct else "β" |
| print(f" {status} '{text}' β {predicted}") |
| |
| malay_accuracy = malay_correct / len(malay_fixes) |
| validation_results["tests"]["malay_fixes"] = { |
| "accuracy": malay_accuracy, |
| "passed": malay_accuracy >= 0.8, |
| "details": f"{malay_correct}/{len(malay_fixes)} correct" |
| } |
| |
| print(f" π Malay Fixes Accuracy: {malay_accuracy:.1%}") |
| |
| |
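        # Optional warm-up: the first prediction after loading can pay one-off
        # costs (CUDA kernel launches, tokenizer caching) that would skew the
        # benchmark below, so run one throwaway prediction first.
        _ = classifier("warm up")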
| print("\nπ§ͺ Test 3: Performance Benchmark") |
| print("-" * 40) |
| |
| benchmark_texts = ["I am happy"] * 20 |
| |
| start_time = time.time() |
| for text in benchmark_texts: |
| _ = classifier(text) |
| end_time = time.time() |
| |
| total_time = end_time - start_time |
| predictions_per_second = len(benchmark_texts) / total_time |
| |
| validation_results["tests"]["performance"] = { |
| "predictions_per_second": predictions_per_second, |
| "passed": predictions_per_second >= 3.0, |
| "details": f"{predictions_per_second:.1f} predictions/second" |
| } |
| |
| print(f" β‘ Speed: {predictions_per_second:.1f} predictions/second") |
| |
| |
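        # Note: transformers pipelines also accept a list of texts, e.g.
        # classifier(benchmark_texts), which usually batches better than a
        # Python loop; the loop above deliberately measures per-call latency.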
| print("\nπ§ͺ Test 4: Confidence Validation") |
| print("-" * 40) |
| |
| confidence_cases = [ |
| "I am extremely happy today!", |
| "I absolutely love this!", |
| "I am terrified!", |
| "Saya sangat gembira!", |
| "Terbaik!" |
| ] |
| |
| high_confidence_count = 0 |
| total_confidence = 0 |
| |
| for text in confidence_cases: |
| result = classifier(text) |
| confidence = result[0]['score'] |
| total_confidence += confidence |
| |
| if confidence > 0.8: |
| high_confidence_count += 1 |
| |
| print(f" π '{text[:30]}...' β {confidence:.1%}") |
| |
| avg_confidence = total_confidence / len(confidence_cases) |
| high_confidence_rate = high_confidence_count / len(confidence_cases) |
| |
| validation_results["tests"]["confidence"] = { |
| "average_confidence": avg_confidence, |
| "high_confidence_rate": high_confidence_rate, |
| "passed": avg_confidence >= 0.7 and high_confidence_rate >= 0.6, |
| "details": f"Avg: {avg_confidence:.1%}, High: {high_confidence_rate:.1%}" |
| } |
| |
| print(f" π Average Confidence: {avg_confidence:.1%}") |
| print(f" π High Confidence Rate: {high_confidence_rate:.1%}") |
| |
| |
| print("\nπ― VALIDATION SUMMARY") |
| print("=" * 60) |
| |
| all_tests_passed = all(test["passed"] for test in validation_results["tests"].values()) |
| |
| if all_tests_passed: |
| validation_results["overall_status"] = "PASS" |
| print("π VALIDATION PASSED!") |
| print("β
All tests passed successfully") |
| print("β
Model is ready for production use") |
| else: |
| validation_results["overall_status"] = "FAIL" |
| print("β VALIDATION FAILED!") |
| print("β οΈ Some tests did not meet requirements") |
| |
| failed_tests = [name for name, test in validation_results["tests"].items() if not test["passed"]] |
| print(f"β Failed tests: {', '.join(failed_tests)}") |
| |
| |
| print("\nπ DETAILED RESULTS:") |
| for test_name, test_result in validation_results["tests"].items(): |
| status = "β
PASS" if test_result["passed"] else "β FAIL" |
| print(f" {status} {test_name.replace('_', ' ').title()}: {test_result['details']}") |
| |
| |
        if output_file:
            with open(output_file, 'w') as f:
                json.dump(validation_results, f, indent=2)
            print(f"\n💾 Results saved to: {output_file}")

        return validation_results

    except Exception as e:
        print(f"❌ Validation failed with error: {e}")
        validation_results["overall_status"] = "ERROR"
        validation_results["error"] = str(e)
        return validation_results

def main():
    """Parse arguments, run validation, and return a process exit code."""
    parser = argparse.ArgumentParser(description="Validate the multilingual emotion classifier")
    parser.add_argument(
        "--model",
        default="rmtariq/multilingual-emotion-classifier",
        help="Model name or path to validate"
    )
    parser.add_argument(
        "--output",
        help="Output file for validation results (JSON format)"
    )

    args = parser.parse_args()

    results = validate_model(args.model, args.output)

    # Non-zero exit for FAIL or ERROR so callers (e.g. CI) can detect problems.
    if results["overall_status"] == "PASS":
        return 0
    else:
        return 1


if __name__ == "__main__":
    sys.exit(main())
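# Example invocation (the local model path below is hypothetical):
#   python validate_model.py --model ./my-finetuned-checkpoint --output report.json
# The script exits 0 on PASS and 1 on FAIL or ERROR, so it can gate a CI step:
#   python validate_model.py && echo "deploy" || echo "blocked"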
|
|