# Trans_for_doctors / tests / test_exceptions.py
# Uploaded by Mintik24 — revision b216c95 ("asd")
"""
Unit tests for common.exceptions module
Tests all custom exception classes and their attributes.
"""
import pytest
from common import (
MedicalTranscriberException,
AudioFileException,
TranscriptionException,
CorrectionException,
ReportGenerationException,
ConfigurationException,
APIException,
ValidationException,
KnowledgeBaseException
)
class TestMedicalTranscriberException:
    """Tests for the root ``MedicalTranscriberException`` class."""

    def test_is_exception(self):
        """The base class must derive from the built-in ``Exception``."""
        assert isinstance(MedicalTranscriberException("Test error"), Exception)

    def test_with_message(self):
        """``str()`` of the base exception is exactly the message it was given."""
        text = "Test error message"
        assert str(MedicalTranscriberException(text)) == text

    def test_inheritance(self):
        """Every domain-specific exception is a ``MedicalTranscriberException``."""
        samples = (
            AudioFileException("path", "message"),
            TranscriptionException("message"),
            CorrectionException("message"),
            ReportGenerationException("message"),
            ConfigurationException("message"),
            APIException("endpoint", 404, "message"),
            ValidationException("field", "value", "message"),
            KnowledgeBaseException("message"),
        )
        for sample in samples:
            # The assertion message names the offending class if one fails.
            assert isinstance(sample, MedicalTranscriberException), type(sample).__name__
class TestAudioFileException:
    """Tests for ``AudioFileException``, which pairs a file path with a message."""

    def test_with_default_message(self):
        """Omitting the message yields a default text that still names the file."""
        rendered = str(AudioFileException("/path/to/file.wav"))
        assert "/path/to/file.wav" in rendered
        assert "Invalid audio file" in rendered

    def test_with_custom_message(self):
        """A caller-supplied message is rendered alongside the file path."""
        rendered = str(AudioFileException("/path/to/file.wav", "File is corrupted"))
        assert "File is corrupted" in rendered
        assert "/path/to/file.wav" in rendered

    def test_file_path_attribute(self):
        """The offending path is exposed unchanged as ``file_path``."""
        path = "/path/to/file.wav"
        assert AudioFileException(path, "Test").file_path == path

    def test_message_attribute(self):
        """The ``message`` attribute contains the caller's message text."""
        error = AudioFileException("/path/to/file.wav", "Test message")
        assert "Test message" in error.message
class TestTranscriptionException:
    """Tests for ``TranscriptionException``."""

    def test_basic_usage(self):
        """The constructor message is included in ``str()``."""
        text = "Transcription failed"
        assert text in str(TranscriptionException(text))

    def test_inheritance_chain(self):
        """``TranscriptionException`` is a ``MedicalTranscriberException``."""
        assert isinstance(TranscriptionException("Test"), MedicalTranscriberException)
class TestCorrectionException:
    """Tests for ``CorrectionException``."""

    def test_basic_usage(self):
        """The constructor message is included in ``str()``."""
        text = "LLM correction failed"
        assert text in str(CorrectionException(text))

    def test_inheritance_chain(self):
        """``CorrectionException`` is a ``MedicalTranscriberException``."""
        assert isinstance(CorrectionException("Test"), MedicalTranscriberException)
class TestReportGenerationException:
    """Tests for ``ReportGenerationException``."""

    def test_basic_usage(self):
        """The constructor message is included in ``str()``."""
        text = "Report generation failed"
        assert text in str(ReportGenerationException(text))

    def test_inheritance_chain(self):
        """``ReportGenerationException`` is a ``MedicalTranscriberException``."""
        assert isinstance(ReportGenerationException("Test"), MedicalTranscriberException)
class TestConfigurationException:
    """Tests for ``ConfigurationException``."""

    def test_basic_usage(self):
        """The constructor message is included in ``str()``."""
        text = "Invalid configuration"
        assert text in str(ConfigurationException(text))

    def test_inheritance_chain(self):
        """``ConfigurationException`` is a ``MedicalTranscriberException``."""
        assert isinstance(ConfigurationException("Test"), MedicalTranscriberException)
class TestAPIException:
    """Tests for ``APIException``: endpoint, HTTP status code and message."""

    @staticmethod
    def _endpoint_error(status_code, message):
        """Build an ``APIException`` for the shared test endpoint and render it."""
        error = APIException("api/endpoint", status_code, message)
        return error, str(error)

    def test_with_status_code_and_message(self):
        """Endpoint and status code are stored; code and message are rendered."""
        endpoint = "https://api.example.com/v1/chat"
        error = APIException(endpoint, 429, "Rate limit exceeded")
        assert error.endpoint == endpoint
        assert error.status_code == 429
        rendered = str(error)
        assert "429" in rendered
        assert "Rate limit" in rendered

    def test_error_400(self):
        """A 400 Bad Request keeps its status code and shows it in ``str()``."""
        error, rendered = self._endpoint_error(400, "Bad request format")
        assert error.status_code == 400
        assert "400" in rendered

    def test_error_401(self):
        """A 401 Unauthorized shows both the code and the message in ``str()``."""
        error, rendered = self._endpoint_error(401, "Invalid API key")
        assert error.status_code == 401
        assert "401" in rendered
        assert "Invalid API key" in rendered

    def test_error_429(self):
        """A 429 rate-limit response keeps its status code and shows it in ``str()``."""
        error, rendered = self._endpoint_error(429, "Rate limit exceeded")
        assert error.status_code == 429
        assert "429" in rendered

    def test_error_500(self):
        """A 500 Internal Server Error keeps its status code and shows it in ``str()``."""
        error, rendered = self._endpoint_error(500, "Server error")
        assert error.status_code == 500
        assert "500" in rendered

    def test_message_attribute(self):
        """``message`` carries the full context: status code and endpoint."""
        error = APIException("api/endpoint", 404, "Not found")
        assert "404" in error.message
        assert "api/endpoint" in error.message

    def test_inheritance_chain(self):
        """``APIException`` is a ``MedicalTranscriberException``."""
        assert isinstance(APIException("api", 500, "error"), MedicalTranscriberException)
class TestValidationException:
    """Tests for ``ValidationException``: field name, offending value and reason."""

    def test_with_field_name(self):
        """Field and value are stored as attributes; the field name is rendered."""
        error = ValidationException("email", "invalid@", "Invalid email format")
        assert error.field == "email"
        assert error.value == "invalid@"
        assert "email" in str(error)

    def test_default_reason(self):
        """An empty reason string falls back to text containing "Invalid value"."""
        assert "Invalid value" in str(ValidationException("username", "ab", ""))

    def test_custom_reason(self):
        """A non-empty reason is rendered as given."""
        reason = "Age must be positive"
        assert reason in str(ValidationException("age", "-5", reason))

    def test_audio_file_field(self):
        """The ``audio_file`` field name and its value both appear in ``str()``."""
        rendered = str(ValidationException("audio_file", "test.xyz", "Unsupported format"))
        assert "audio_file" in rendered
        assert "test.xyz" in rendered

    def test_api_key_field(self):
        """The ``api_key`` field name appears in ``str()``."""
        rendered = str(ValidationException("api_key", "***", "API key is too short"))
        assert "api_key" in rendered

    def test_inheritance_chain(self):
        """``ValidationException`` is a ``MedicalTranscriberException``."""
        error = ValidationException("field", "value", "reason")
        assert isinstance(error, MedicalTranscriberException)
class TestKnowledgeBaseException:
    """Tests for ``KnowledgeBaseException``."""

    def test_basic_usage(self):
        """The constructor message is included in ``str()``."""
        text = "Knowledge base not found"
        assert text in str(KnowledgeBaseException(text))

    def test_inheritance_chain(self):
        """``KnowledgeBaseException`` is a ``MedicalTranscriberException``."""
        assert isinstance(KnowledgeBaseException("Test"), MedicalTranscriberException)
class TestExceptionHandling:
    """Integration-style tests for raising and catching the exception hierarchy.

    These use ``pytest.raises`` instead of ``try``/``except`` blocks with
    ``assert True`` / ``assert False``. That pattern passes silently when the
    ``try`` body raises nothing, and it gives no useful diagnostics when an
    unexpected exception type escapes.
    """

    def test_catch_api_exception_by_status_code(self):
        """A caught ``APIException`` exposes its status code for dispatching."""
        with pytest.raises(APIException) as excinfo:
            raise APIException("api/chat", 429, "Rate limit")
        assert excinfo.value.status_code == 429

    def test_catch_specific_exceptions(self):
        """Every exception can be caught by its own concrete type."""
        cases = [
            (AudioFileException("/path", "test"), AudioFileException),
            (TranscriptionException("test"), TranscriptionException),
            (CorrectionException("test"), CorrectionException),
            (ReportGenerationException("test"), ReportGenerationException),
            (ConfigurationException("test"), ConfigurationException),
            (APIException("api", 500, "error"), APIException),
            (ValidationException("field", "value", "reason"), ValidationException),
            (KnowledgeBaseException("test"), KnowledgeBaseException),
        ]
        for exc, exc_type in cases:
            with pytest.raises(exc_type):
                raise exc

    def test_catch_all_as_medical_transcriber_exception(self):
        """Every exception, including all eight subclasses, is caught by the base type."""
        exceptions = [
            AudioFileException("/path", "test"),
            TranscriptionException("test"),
            CorrectionException("test"),
            ReportGenerationException("test"),
            ConfigurationException("test"),
            APIException("api", 500, "error"),
            ValidationException("field", "value", "reason"),
            KnowledgeBaseException("test"),
        ]
        for exc in exceptions:
            with pytest.raises(MedicalTranscriberException):
                raise exc

    def test_exception_chain_preservation(self):
        """``raise ... from err`` keeps the original error as ``__cause__``."""
        with pytest.raises(AudioFileException) as excinfo:
            try:
                raise ValueError("Original error")
            except ValueError as err:
                raise AudioFileException("/path", str(err)) from err
        wrapped = excinfo.value
        assert str(wrapped.file_path) == "/path"
        # The actual chain check: the original ValueError must survive as the cause.
        assert isinstance(wrapped.__cause__, ValueError)
        assert str(wrapped.__cause__) == "Original error"

    def test_multiple_exception_handlers(self):
        """Each raised type reaches a handler for its own type with attributes intact."""
        # Named without a ``test_`` prefix so they are clearly helpers, not tests.
        def raise_api_error():
            raise APIException("api", 429, "Rate limited")

        def raise_validation_error():
            raise ValidationException("field", "value", "Invalid")

        with pytest.raises(APIException) as api_info:
            raise_api_error()
        assert api_info.value.status_code == 429

        with pytest.raises(ValidationException) as validation_info:
            raise_validation_error()
        assert validation_info.value.field == "field"
class TestExceptionStringRepresentation:
    """Checks that ``str()`` of each contextual exception carries its context."""

    def test_audio_file_exception_string(self):
        """``str()`` mentions both the reason and the file path."""
        rendered = str(AudioFileException("/path/to/audio.wav", "Corrupted file"))
        for fragment in ("Corrupted file", "/path/to/audio.wav"):
            assert fragment in rendered

    def test_api_exception_string(self):
        """``str()`` mentions the status code, the endpoint and the message."""
        rendered = str(APIException("api/v1/chat", 401, "Unauthorized"))
        for fragment in ("401", "api/v1/chat", "Unauthorized"):
            assert fragment in rendered

    def test_validation_exception_string(self):
        """``str()`` mentions the field, the offending value and the reason."""
        rendered = str(ValidationException("username", "admin", "Username reserved"))
        for fragment in ("username", "admin", "Username reserved"):
            assert fragment in rendered