hatespeech-detection / tests /test_model_load.py
jl
feat(file upload): added file upload and updated file structure to be readable
198730e
#!/usr/bin/env python
"""Smoke test: check that the "altered" model loads and runs a forward pass.

Loads the model via ``load_model_from_hf("altered")`` (the motivating
concern is ``state_dict`` errors during loading), then feeds one batch of
random token ids through it and prints the shape and value range of the
returned logits.

Run as a script::

    python tests/test_model_load.py

The process exits with status 0 when every step succeeds and 1 otherwise,
so it can be used as a CI gate. The work is guarded by
``if __name__ == "__main__"`` so that merely importing this module (for
example during pytest collection) does not load the model.
"""
import sys
import traceback

import torch

from hatespeech_model import load_model_from_hf

# Shape of the dummy input tensors (assumed to be batch, sequence length).
BATCH_SIZE = 1
SEQ_LEN = 50
# Exclusive upper bound for the random token ids. Assumed to be below the
# model's vocabulary size -- TODO confirm against the model config.
MAX_TOKEN_ID = 1000


def _check_altered_model() -> None:
    """Load the "altered" model and run one forward pass on random input.

    Progress and the resulting logits' shape and range are printed to stdout.

    Raises:
        Exception: anything raised by ``load_model_from_hf`` or by the
            model's forward call is propagated to the caller for reporting.
    """
    print("\n1. Loading Altered Shield model...")
    # load_model_from_hf returns a 5-tuple; only the model is used here.
    altered_model, _tokenizer, _, _, _ = load_model_from_hf("altered")
    print("✅ Altered Shield loaded successfully!")
    print(f" Model type: {type(altered_model).__name__}")

    # NOTE(review): this script never calls altered_model.eval(); if
    # load_model_from_hf does not either, dropout-style layers stay in
    # training mode during the pass below -- confirm.
    test_input = torch.randint(0, MAX_TOKEN_ID, (BATCH_SIZE, SEQ_LEN))
    test_attn_mask = torch.ones(BATCH_SIZE, SEQ_LEN)

    print("\n2. Testing forward pass...")
    with torch.no_grad():
        # The same dummy tensors serve as both the primary and the
        # "additional" inputs. The model returns four outputs; only
        # `logits` is inspected by this smoke test.
        logits, _rationale_probs, _selector_logits, _attns = altered_model(
            input_ids=test_input,
            attention_mask=test_attn_mask,
            additional_input_ids=test_input,
            additional_attention_mask=test_attn_mask,
        )
    print("✅ Forward pass successful!")
    print(f" Logits shape: {logits.shape}")
    print(f" Output range: [{logits.min():.4f}, {logits.max():.4f}]")


def main() -> int:
    """Run the smoke test and report the outcome.

    Returns:
        0 if the model loaded and the forward pass succeeded, 1 otherwise.
        The traceback of any failure is printed to stderr.
    """
    print("Testing model load...")
    print("=" * 60)
    try:
        _check_altered_model()
    except Exception as e:  # top-level boundary: report any failure, don't crash silently
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        print("\n" + "=" * 60)
        print("❌ Tests failed")
        return 1

    print("\n" + "=" * 60)
    print("✅ All tests passed!")
    return 0


if __name__ == "__main__":
    # Propagate the result as the process exit status so CI can detect failures.
    sys.exit(main())