|
|
"""
Smoke test for the LDARiskDiscovery.get_risk_labels() fix.

Run directly (``python3 <this file>``). The script:

1. builds an ``LDARiskDiscovery`` with ``N_CLUSTERS`` clusters,
2. runs ``discover_risk_patterns()`` on a few sample clauses,
3. checks that ``get_risk_labels()`` returns exactly one plain ``int``
   label per input clause, each in ``[0, N_CLUSTERS)``,
4. calls ``get_topic_distribution()`` and prints its shape and row sums.

Any failure prints a diagnostic and exits with status 1.
"""

import sys
import traceback

# Cluster count passed to the model. Every risk label must fall in
# [0, N_CLUSTERS). It lives in one place so the constructor argument and
# the range check below cannot drift apart.
N_CLUSTERS = 3


def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used instead of ``assert`` so the checks still run under ``python -O``,
    which strips assert statements.
    """
    if not condition:
        raise AssertionError(message)


print("=" * 60)
print("Testing LDARiskDiscovery.get_risk_labels() method")
print("=" * 60)

try:
    from risk_discovery import LDARiskDiscovery

    print("\n1. Creating LDARiskDiscovery instance...")
    lda = LDARiskDiscovery(n_clusters=N_CLUSTERS)
    print(" ✅ Instance created successfully")

    print("\n2. Testing with sample clauses...")
    sample_clauses = [
        "The party shall indemnify and hold harmless the other party.",
        "This agreement shall be governed by the laws of the state.",
        "Payment shall be made within 30 days of invoice date.",
    ]

    print("\n3. Discovering risk patterns...")
    # The return value is not inspected. The discovered patterns are read
    # back from the instance on the next line.
    lda.discover_risk_patterns(sample_clauses)
    print(f" ✅ Discovered {len(lda.discovered_patterns)} patterns")

    print("\n4. Testing get_risk_labels() method...")
    test_clauses = [
        "The company agrees to indemnify all damages.",
        "This contract is subject to California law.",
    ]

    labels = lda.get_risk_labels(test_clauses)
    print(f" ✅ get_risk_labels() returned: {labels}")
    print(f" ✅ Labels type: {type(labels)}")
    print(f" ✅ Number of labels: {len(labels)}")

    # Exactly one label per input clause...
    _check(
        len(labels) == len(test_clauses),
        f"Expected {len(test_clauses)} labels, got {len(labels)}",
    )
    # ...each a plain Python int within the configured cluster range.
    for i, label in enumerate(labels):
        _check(isinstance(label, int), f"Label {i} is not an integer: {type(label)}")
        _check(0 <= label < N_CLUSTERS, f"Label {i} out of range: {label}")

    print("\n5. Testing get_topic_distribution() method...")
    dist = lda.get_topic_distribution(test_clauses)
    # NOTE(review): assumes an array-like result with .shape and
    # .sum(axis=...), e.g. a numpy array. Values are printed, not checked.
    print(f" ✅ Distribution shape: {dist.shape}")
    print(f" ✅ Distribution sum per doc: {dist.sum(axis=1)}")

    print("\n" + "=" * 60)
    print("🎉 ALL TESTS PASSED!")
    print("=" * 60)
    print("\n✅ LDARiskDiscovery.get_risk_labels() is working correctly")
    print("✅ Ready to run: python3 train.py")

except ImportError as e:
    print(f"\n❌ Import error: {e}")
    print("Make sure all required modules are installed.")
    sys.exit(1)

except AttributeError as e:
    print(f"\n❌ Attribute error: {e}")
    print("The get_risk_labels method may be missing or incorrect.")
    # The try block is broad, so this may have been raised somewhere other
    # than get_risk_labels. Show exactly where.
    traceback.print_exc()
    sys.exit(1)

except Exception as e:
    # Top-level boundary: report any other failure, including failed
    # _check() calls (AssertionError), with a traceback.
    print(f"\n❌ Test failed: {e}")
    traceback.print_exc()
    sys.exit(1)
|
|
|