code2-repo / test_lda_get_labels.py
Deepu1965's picture
Upload folder using huggingface_hub
9b1c753 verified
"""
Test LDARiskDiscovery.get_risk_labels() fix
"""
import sys
import traceback

# Number of topics requested from LDARiskDiscovery. The same value is the
# exclusive upper bound used when range-checking the returned labels, so the
# two can never drift apart.
N_CLUSTERS = 3

print("=" * 60)
print("Testing LDARiskDiscovery.get_risk_labels() method")
print("=" * 60)

# Smoke test: fit on a few sample clauses, then check that get_risk_labels()
# returns valid integer labels and that get_topic_distribution() can be called.
# Everything runs inside one try block so any failure is reported with a
# readable message and a non-zero exit status.
try:
    from risk_discovery import LDARiskDiscovery

    print("\n1. Creating LDARiskDiscovery instance...")
    lda = LDARiskDiscovery(n_clusters=N_CLUSTERS)
    print(" ✅ Instance created successfully")

    print("\n2. Testing with sample clauses...")
    sample_clauses = [
        "The party shall indemnify and hold harmless the other party.",
        "This agreement shall be governed by the laws of the state.",
        "Payment shall be made within 30 days of invoice date.",
    ]

    print("\n3. Discovering risk patterns...")
    # The return value is not inspected; the step is checked through the
    # `discovered_patterns` attribute read on the next line.
    lda.discover_risk_patterns(sample_clauses)
    print(f" ✅ Discovered {len(lda.discovered_patterns)} patterns")

    print("\n4. Testing get_risk_labels() method...")
    test_clauses = [
        "The company agrees to indemnify all damages.",
        "This contract is subject to California law.",
    ]
    labels = lda.get_risk_labels(test_clauses)
    print(f" ✅ get_risk_labels() returned: {labels}")
    print(f" ✅ Labels type: {type(labels)}")
    print(f" ✅ Number of labels: {len(labels)}")

    # Explicit raises instead of `assert`: assert statements are stripped
    # under `python -O`, which would let this script report success without
    # checking anything.
    if len(labels) == 0:
        # Without this guard the per-label loop below passes vacuously.
        raise AssertionError("get_risk_labels() returned no labels")

    # Verify labels are integers inside [0, N_CLUSTERS).
    for i, label in enumerate(labels):
        if not isinstance(label, int):
            raise AssertionError(f"Label {i} is not an integer: {type(label)}")
        if not 0 <= label < N_CLUSTERS:
            raise AssertionError(f"Label {i} out of range: {label}")

    print("\n5. Testing get_topic_distribution() method...")
    dist = lda.get_topic_distribution(test_clauses)
    # Shape and row sums are printed for manual inspection only; nothing is
    # asserted about them here.
    print(f" ✅ Distribution shape: {dist.shape}")
    print(f" ✅ Distribution sum per doc: {dist.sum(axis=1)}")

    print("\n" + "=" * 60)
    print("🎉 ALL TESTS PASSED!")
    print("=" * 60)
    print("\n✅ LDARiskDiscovery.get_risk_labels() is working correctly")
    print("✅ Ready to run: python3 train.py")
except ImportError as e:
    print(f"\n❌ Import error: {e}")
    print("Make sure all required modules are installed.")
    # sys.exit rather than the `exit` builtin, which only exists when the
    # `site` module has been loaded.
    sys.exit(1)
except AttributeError as e:
    print(f"\n❌ Attribute error: {e}")
    print("The get_risk_labels method may be missing or incorrect.")
    sys.exit(1)
except Exception as e:
    # Catch-all boundary for the script: covers the AssertionError raised by
    # the checks above as well as any unexpected runtime error.
    print(f"\n❌ Test failed: {e}")
    traceback.print_exc()
    sys.exit(1)